提交 201f19e4 编写于 作者: T TensorFlower Gardener

Merge pull request #34685 from ljwh:master

PiperOrigin-RevId: 286253056
Change-Id: I9928d21effe92bed1a9f131a47d9e172cd970330
......@@ -823,6 +823,14 @@ RangeDatasetParams::RangeDatasetParams(int64 start, int64 stop, int64 step)
stop_(stop),
step_(step) {}
RangeDatasetParams::RangeDatasetParams(int64 start, int64 stop, int64 step,
DataTypeVector output_dtypes)
: DatasetParams(std::move(output_dtypes), {PartialTensorShape({})},
"range_dataset"),
start_(start),
stop_(stop),
step_(step) {}
std::vector<Tensor> RangeDatasetParams::GetInputTensors() const {
Tensor start_tensor = CreateTensor<int64>(TensorShape({}), {start_});
Tensor stop_tensor = CreateTensor<int64>(TensorShape({}), {stop_});
......
......@@ -172,6 +172,9 @@ class RangeDatasetParams : public DatasetParams {
RangeDatasetParams(int64 start, int64 stop, int64 step);
RangeDatasetParams(int64 start, int64 stop, int64 step,
DataTypeVector output_dtypes);
std::vector<Tensor> GetInputTensors() const override;
Status GetInputNames(std::vector<string>* input_names) const override;
......
......@@ -36,11 +36,13 @@ constexpr char kNext[] = "next";
class RangeDatasetOp::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 start, int64 stop, int64 step)
Dataset(OpKernelContext* ctx, int64 start, int64 stop, int64 step,
DataTypeVector output_dtypes)
: DatasetBase(DatasetContext(ctx)),
start_(start),
stop_(stop),
step_(step) {}
step_(step),
output_dtypes_(output_dtypes) {}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
......@@ -49,8 +51,7 @@ class RangeDatasetOp::Dataset : public DatasetBase {
}
const DataTypeVector& output_dtypes() const override {
static DataTypeVector* dtypes = new DataTypeVector({DT_INT64});
return *dtypes;
return output_dtypes_;
}
const std::vector<PartialTensorShape>& output_shapes() const override {
......@@ -106,7 +107,20 @@ class RangeDatasetOp::Dataset : public DatasetBase {
return Status::OK();
}
out_tensors->reserve(1);
out_tensors->emplace_back(next_);
Tensor result(dataset()->output_dtypes()[0]);
switch (dataset()->output_dtypes()[0]) {
#define HANDLE_TYPE(type) \
case DataTypeToEnum<type>::value: { \
out_tensors->emplace_back(static_cast<type>(next_)); \
break; \
}
TF_CALL_NUMBER_TYPES(HANDLE_TYPE);
#undef HANDLE_TYPE
default:
return errors::InvalidArgument(
"Unsupported data type: ",
DataTypeString(dataset()->output_dtypes()[0]));
}
*end_of_sequence = false;
next_ += dataset()->step_;
......@@ -140,10 +154,13 @@ class RangeDatasetOp::Dataset : public DatasetBase {
const int64 start_;
const int64 stop_;
const int64 step_;
const DataTypeVector output_dtypes_;
};
RangeDatasetOp::RangeDatasetOp(OpKernelConstruction* ctx)
: DatasetOpKernel(ctx) {}
: DatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr(kOutputTypes, &output_types_));
}
void RangeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) {
int64 start;
......@@ -157,7 +174,7 @@ void RangeDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) {
OP_REQUIRES(ctx, step != 0,
errors::InvalidArgument("step must be a non-zero integer."));
*output = new Dataset(ctx, start, stop, step);
*output = new Dataset(ctx, start, stop, step, output_types_);
}
namespace {
......
......@@ -36,6 +36,7 @@ class RangeDatasetOp : public DatasetOpKernel {
private:
class Dataset;
DataTypeVector output_types_;
};
} // namespace data
......
......@@ -34,6 +34,16 @@ RangeDatasetParams ZeroStepRangeDatasetParams() {
return RangeDatasetParams(/*start=*/10, /*stop=*/0, /*step=*/0);
}
RangeDatasetParams RangeDatasetParams1() {
return RangeDatasetParams(/*start=*/0, /*stop=*/10, /*step=*/3,
/*output_dtypes=*/{DT_INT32});
}
RangeDatasetParams RangeDatasetParams2() {
return RangeDatasetParams(/*start=*/0, /*stop=*/10, /*step=*/3,
/*output_dtypes=*/{DT_INT64});
}
std::vector<GetNextTestCase<RangeDatasetParams>> GetNextTestCases() {
return {{/*dataset_params=*/PositiveStepRangeDatasetParams(),
/*expected_outputs=*/
......@@ -59,12 +69,17 @@ TEST_F(RangeDatasetOpTest, DatasetTypeString) {
CheckDatasetTypeString(name_utils::OpName(RangeDatasetOp::kDatasetType)));
}
TEST_F(RangeDatasetOpTest, DatasetOutputDtypes) {
auto range_dataset_params = PositiveStepRangeDatasetParams();
TF_ASSERT_OK(Initialize(range_dataset_params));
TF_ASSERT_OK(CheckDatasetOutputDtypes({DT_INT64}));
std::vector<DatasetOutputDtypesTestCase<RangeDatasetParams>>
DatasetOutputDtypesTestCases() {
return {{/*dataset_params=*/RangeDatasetParams1(),
/*expected_output_dtypes=*/{DT_INT32}},
{/*dataset_params=*/RangeDatasetParams2(),
/*expected_output_dtypes=*/{DT_INT64}}};
}
DATASET_OUTPUT_DTYPES_TEST_P(RangeDatasetOpTest, RangeDatasetParams,
DatasetOutputDtypesTestCases())
TEST_F(RangeDatasetOpTest, DatasetOutputShapes) {
auto range_dataset_params = PositiveStepRangeDatasetParams();
TF_ASSERT_OK(Initialize(range_dataset_params));
......@@ -81,12 +96,17 @@ std::vector<CardinalityTestCase<RangeDatasetParams>> CardinalityTestCases() {
DATASET_CARDINALITY_TEST_P(RangeDatasetOpTest, RangeDatasetParams,
CardinalityTestCases())
TEST_F(RangeDatasetOpTest, IteratorOutputDtypes) {
auto range_dataset_params = PositiveStepRangeDatasetParams();
TF_ASSERT_OK(Initialize(range_dataset_params));
TF_ASSERT_OK(CheckIteratorOutputDtypes({DT_INT64}));
std::vector<IteratorOutputDtypesTestCase<RangeDatasetParams>>
IteratorOutputDtypesTestCases() {
return {{/*dataset_params=*/RangeDatasetParams1(),
/*expected_output_dtypes=*/{DT_INT32}},
{/*dataset_params=*/RangeDatasetParams2(),
/*expected_output_dtypes=*/{DT_INT64}}};
}
ITERATOR_OUTPUT_DTYPES_TEST_P(RangeDatasetOpTest, RangeDatasetParams,
IteratorOutputDtypesTestCases())
TEST_F(RangeDatasetOpTest, IteratorOutputShapes) {
auto range_dataset_params = PositiveStepRangeDatasetParams();
TF_ASSERT_OK(Initialize(range_dataset_params));
......
......@@ -18,63 +18,129 @@ from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
class RangeTest(test_base.DatasetTestBase, parameterized.TestCase):
@combinations.generate(test_base.default_test_combinations())
def testStop(self):
dataset = dataset_ops.Dataset.range(5)
self.assertDatasetProduces(dataset, expected_output=range(5))
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(output_type=[
dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
])))
def testStop(self, output_type):
stop = 5
dataset = dataset_ops.Dataset.range(stop, output_type=output_type)
expected_output = np.arange(stop, dtype=output_type.as_numpy_dtype)
self.assertDatasetProduces(dataset, expected_output=expected_output)
self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset))
@combinations.generate(test_base.default_test_combinations())
def testStartStop(self):
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(output_type=[
dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
])))
def testStartStop(self, output_type):
start, stop = 2, 5
dataset = dataset_ops.Dataset.range(start, stop)
self.assertDatasetProduces(dataset, expected_output=range(2, 5))
dataset = dataset_ops.Dataset.range(start, stop, output_type=output_type)
expected_output = np.arange(start, stop, dtype=output_type.as_numpy_dtype)
self.assertDatasetProduces(dataset, expected_output=expected_output)
self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset))
@combinations.generate(test_base.default_test_combinations())
def testStartStopStep(self):
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(output_type=[
dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
])))
def testStartStopStep(self, output_type):
start, stop, step = 2, 10, 2
dataset = dataset_ops.Dataset.range(start, stop, step)
self.assertDatasetProduces(dataset, expected_output=range(2, 10, 2))
dataset = dataset_ops.Dataset.range(
start, stop, step, output_type=output_type)
expected_output = np.arange(
start, stop, step, dtype=output_type.as_numpy_dtype)
self.assertDatasetProduces(dataset, expected_output=expected_output)
self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset))
@combinations.generate(test_base.default_test_combinations())
def testZeroStep(self):
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(output_type=[
dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
])))
def testZeroStep(self, output_type):
start, stop, step = 2, 10, 0
with self.assertRaises(errors.InvalidArgumentError):
dataset = dataset_ops.Dataset.range(start, stop, step)
dataset = dataset_ops.Dataset.range(
start, stop, step, output_type=output_type)
self.evaluate(dataset._variant_tensor)
@combinations.generate(test_base.default_test_combinations())
def testNegativeStep(self):
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(output_type=[
dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
])))
def testNegativeStep(self, output_type):
start, stop, step = 2, 10, -1
dataset = dataset_ops.Dataset.range(start, stop, step)
self.assertDatasetProduces(dataset, expected_output=range(2, 10, -1))
dataset = dataset_ops.Dataset.range(
start, stop, step, output_type=output_type)
expected_output = np.arange(
start, stop, step, dtype=output_type.as_numpy_dtype)
self.assertDatasetProduces(dataset, expected_output=expected_output)
self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset))
@combinations.generate(test_base.default_test_combinations())
def testStopLessThanStart(self):
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(output_type=[
dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
])))
def testStopLessThanStart(self, output_type):
start, stop = 10, 2
dataset = dataset_ops.Dataset.range(start, stop)
self.assertDatasetProduces(dataset, expected_output=range(10, 2))
dataset = dataset_ops.Dataset.range(start, stop, output_type=output_type)
expected_output = np.arange(start, stop, dtype=output_type.as_numpy_dtype)
self.assertDatasetProduces(dataset, expected_output=expected_output)
self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset))
@combinations.generate(test_base.default_test_combinations())
def testStopLessThanStartWithPositiveStep(self):
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(output_type=[
dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
])))
def testStopLessThanStartWithPositiveStep(self, output_type):
start, stop, step = 10, 2, 2
dataset = dataset_ops.Dataset.range(start, stop, step)
self.assertDatasetProduces(dataset, expected_output=range(10, 2, 2))
dataset = dataset_ops.Dataset.range(
start, stop, step, output_type=output_type)
expected_output = np.arange(
start, stop, step, dtype=output_type.as_numpy_dtype)
self.assertDatasetProduces(dataset, expected_output=expected_output)
self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset))
@combinations.generate(test_base.default_test_combinations())
def testStopLessThanStartWithNegativeStep(self):
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(output_type=[
dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64
])))
def testStopLessThanStartWithNegativeStep(self, output_type):
start, stop, step = 10, 2, -1
dataset = dataset_ops.Dataset.range(start, stop, step)
self.assertDatasetProduces(dataset, expected_output=range(10, 2, -1))
dataset = dataset_ops.Dataset.range(
start, stop, step, output_type=output_type)
expected_output = np.arange(
start, stop, step, dtype=output_type.as_numpy_dtype)
self.assertDatasetProduces(dataset, expected_output=expected_output)
self.assertEqual(output_type, dataset_ops.get_legacy_output_types(dataset))
if __name__ == "__main__":
......
......@@ -870,7 +870,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
return id_dataset.flat_map(flat_map_fn)
@staticmethod
def range(*args):
def range(*args, **kwargs):
"""Creates a `Dataset` of a step-separated range of values.
>>> list(Dataset.range(5).as_numpy_iterator())
......@@ -885,12 +885,18 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
[]
>>> list(Dataset.range(5, 1, -2).as_numpy_iterator())
[5, 3]
>>> list(Dataset.range(2, 5, output_type=tf.int32).as_numpy_iterator())
[2, 3, 4]
>>> list(Dataset.range(1, 5, 2, output_type=tf.float32).as_numpy_iterator())
[1.0, 3.0]
Args:
*args: follows the same semantics as python's xrange.
len(args) == 1 -> start = 0, stop = args[0], step = 1
len(args) == 2 -> start = args[0], stop = args[1], step = 1
len(args) == 3 -> start = args[0], stop = args[1, stop = args[2]
**kwargs:
- output_type: Its expected dtype. (Optional, default: `tf.int64`).
Returns:
Dataset: A `RangeDataset`.
......@@ -898,7 +904,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
Raises:
ValueError: if len(args) == 0.
"""
return RangeDataset(*args)
return RangeDataset(*args, **kwargs)
@staticmethod
def zip(datasets):
......@@ -2228,8 +2234,8 @@ class DatasetV1(DatasetV2):
@staticmethod
@functools.wraps(DatasetV2.range)
def range(*args):
return DatasetV1Adapter(DatasetV2.range(*args))
def range(*args, **kwargs):
return DatasetV1Adapter(DatasetV2.range(*args, **kwargs))
@staticmethod
@functools.wraps(DatasetV2.zip)
......@@ -3344,10 +3350,10 @@ class RepeatDataset(UnaryUnchangedStructureDataset):
class RangeDataset(DatasetSource):
"""A `Dataset` of a step separated range of values."""
def __init__(self, *args):
def __init__(self, *args, **kwargs):
"""See `Dataset.range()` for details."""
self._parse_args(*args)
self._structure = tensor_spec.TensorSpec([], dtypes.int64)
self._parse_args(*args, **kwargs)
self._structure = tensor_spec.TensorSpec([], self._output_type)
variant_tensor = gen_dataset_ops.range_dataset(
start=self._start,
stop=self._stop,
......@@ -3355,7 +3361,7 @@ class RangeDataset(DatasetSource):
**self._flat_structure)
super(RangeDataset, self).__init__(variant_tensor)
def _parse_args(self, *args):
def _parse_args(self, *args, **kwargs):
"""Parse arguments according to the same rules as the `range()` builtin."""
if len(args) == 1:
self._start = self._build_tensor(0, "start")
......@@ -3371,6 +3377,10 @@ class RangeDataset(DatasetSource):
self._step = self._build_tensor(args[2], "step")
else:
raise ValueError("Invalid arguments to RangeDataset: %s" % str(args))
if "output_type" in kwargs:
self._output_type = kwargs["output_type"]
else:
self._output_type = dtypes.int64
def _build_tensor(self, int64_value, name):
return ops.convert_to_tensor(int64_value, dtype=dtypes.int64, name=name)
......
......@@ -115,7 +115,7 @@ tf_class {
}
member_method {
name: "range"
argspec: "args=[], varargs=args, keywords=None, defaults=None"
argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "reduce"
......
......@@ -117,7 +117,7 @@ tf_class {
}
member_method {
name: "range"
argspec: "args=[], varargs=args, keywords=None, defaults=None"
argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "reduce"
......
......@@ -117,7 +117,7 @@ tf_class {
}
member_method {
name: "range"
argspec: "args=[], varargs=args, keywords=None, defaults=None"
argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "reduce"
......
......@@ -117,7 +117,7 @@ tf_class {
}
member_method {
name: "range"
argspec: "args=[], varargs=args, keywords=None, defaults=None"
argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "reduce"
......
......@@ -117,7 +117,7 @@ tf_class {
}
member_method {
name: "range"
argspec: "args=[], varargs=args, keywords=None, defaults=None"
argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "reduce"
......
......@@ -117,7 +117,7 @@ tf_class {
}
member_method {
name: "range"
argspec: "args=[], varargs=args, keywords=None, defaults=None"
argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "reduce"
......
......@@ -117,7 +117,7 @@ tf_class {
}
member_method {
name: "range"
argspec: "args=[], varargs=args, keywords=None, defaults=None"
argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "reduce"
......
......@@ -82,7 +82,7 @@ tf_class {
}
member_method {
name: "range"
argspec: "args=[], varargs=args, keywords=None, defaults=None"
argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "reduce"
......
......@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "range"
argspec: "args=[], varargs=args, keywords=None, defaults=None"
argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "reduce"
......
......@@ -83,7 +83,7 @@ tf_class {
}
member_method {
name: "range"
argspec: "args=[], varargs=args, keywords=None, defaults=None"
argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "reduce"
......
......@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "range"
argspec: "args=[], varargs=args, keywords=None, defaults=None"
argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "reduce"
......
......@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "range"
argspec: "args=[], varargs=args, keywords=None, defaults=None"
argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "reduce"
......
......@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "range"
argspec: "args=[], varargs=args, keywords=None, defaults=None"
argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "reduce"
......
......@@ -84,7 +84,7 @@ tf_class {
}
member_method {
name: "range"
argspec: "args=[], varargs=args, keywords=None, defaults=None"
argspec: "args=[], varargs=args, keywords=kwargs, defaults=None"
}
member_method {
name: "reduce"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册