提交 98eb7d80 编写于 作者: J Jiri Simsa 提交者: TensorFlower Gardener

[tf.data] Adding `tf.data.experimental.cardinality()` which provides...

[tf.data] Adding `tf.data.experimental.cardinality()` which provides information about dataset cardinality.

PiperOrigin-RevId: 224030418
上级 4bc86a6d
op {
graph_op_name: "ExperimentalDatasetCardinality"
visibility: HIDDEN
in_arg {
name: "input_dataset"
description: <<END
A variant tensor representing the dataset to return cardinality for.
END
}
out_arg {
name: "cardinality"
description: <<END
The cardinality of `input_dataset`. Named constants are used to represent
infinite and unknown cardinality.
END
}
summary: "Returns the cardinality of `input_dataset`."
description: <<END
Returns the cardinality of `input_dataset`.
END
}
...@@ -53,6 +53,9 @@ namespace data { ...@@ -53,6 +53,9 @@ namespace data {
// A constant that can be used to enable auto-tuning. // A constant that can be used to enable auto-tuning.
constexpr int kAutoTune = -1; constexpr int kAutoTune = -1;
// Sentinel cardinality values returned by DatasetBase::Cardinality():
// kInfiniteCardinality for datasets that never terminate, and
// kUnknownCardinality when the element count cannot be determined statically.
constexpr int kInfiniteCardinality = -1;
constexpr int kUnknownCardinality = -2;
class DatasetBase; class DatasetBase;
class SerializationContext; class SerializationContext;
...@@ -587,6 +590,9 @@ class DatasetBase : public core::RefCounted { ...@@ -587,6 +590,9 @@ class DatasetBase : public core::RefCounted {
// A human-readable debug string for this dataset. // A human-readable debug string for this dataset.
virtual string DebugString() const = 0; virtual string DebugString() const = 0;
// Returns the number of elements this dataset produces, or
// kInfiniteCardinality / kUnknownCardinality when the dataset is infinite or
// the count cannot be determined statically. Defaults to unknown so that
// subclasses only override when they can report something more precise.
virtual int64 Cardinality() const { return kUnknownCardinality; }
// Serializes the dataset and writes it to the `writer`. // Serializes the dataset and writes it to the `writer`.
virtual Status Save(SerializationContext* ctx, virtual Status Save(SerializationContext* ctx,
IteratorStateWriter* writer) const; IteratorStateWriter* writer) const;
...@@ -601,7 +607,6 @@ class DatasetBase : public core::RefCounted { ...@@ -601,7 +607,6 @@ class DatasetBase : public core::RefCounted {
const DatasetBase* dataset, Node** output); const DatasetBase* dataset, Node** output);
}; };
// TODO(jsimsa): Consolidate overloading into a single method.
virtual Status AsGraphDefInternal(SerializationContext* ctx, virtual Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
Node** node) const = 0; Node** node) const = 0;
......
...@@ -95,6 +95,15 @@ class BatchDatasetOp : public UnaryDatasetOpKernel { ...@@ -95,6 +95,15 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
return strings::StrCat("BatchDatasetOp(", batch_size_, ")::Dataset"); return strings::StrCat("BatchDatasetOp(", batch_size_, ")::Dataset");
} }
// Number of batches produced: infinite/unknown input cardinalities propagate
// unchanged; otherwise the count of full batches, plus one trailing partial
// batch unless drop_remainder_ suppresses it.
int64 Cardinality() const override {
  const int64 num_elements = input_->Cardinality();
  if (num_elements == kInfiniteCardinality ||
      num_elements == kUnknownCardinality) {
    return num_elements;
  }
  const int64 full_batches = num_elements / batch_size_;
  const bool emits_partial_batch =
      !drop_remainder_ && (num_elements % batch_size_ != 0);
  return emits_partial_batch ? full_batches + 1 : full_batches;
}
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -84,6 +84,8 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { ...@@ -84,6 +84,8 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
return "CacheDatasetOp::FileDataset"; return "CacheDatasetOp::FileDataset";
} }
int64 Cardinality() const override { return input_->Cardinality(); }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
...@@ -588,6 +590,8 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { ...@@ -588,6 +590,8 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
return "CacheDatasetOp::MemoryDataset"; return "CacheDatasetOp::MemoryDataset";
} }
int64 Cardinality() const override { return input_->Cardinality(); }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -79,6 +79,18 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel { ...@@ -79,6 +79,18 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
return "ConcatenateDatasetOp::Dataset"; return "ConcatenateDatasetOp::Dataset";
} }
// Cardinality of the concatenation: an infinite input dominates, then an
// unknown input, otherwise the sum of the two finite counts.
int64 Cardinality() const override {
  const int64 first = input_->Cardinality();
  const int64 second = to_concatenate_->Cardinality();
  if (first == kInfiniteCardinality || second == kInfiniteCardinality) {
    return kInfiniteCardinality;
  }
  if (first == kUnknownCardinality || second == kUnknownCardinality) {
    return kUnknownCardinality;
  }
  return first + second;
}
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -46,8 +46,25 @@ class DatasetToGraphOp : public OpKernel { ...@@ -46,8 +46,25 @@ class DatasetToGraphOp : public OpKernel {
} }
}; };
// Op kernel backing `ExperimentalDatasetCardinality`: emits the cardinality
// of the input dataset as a scalar int64 (the kInfiniteCardinality /
// kUnknownCardinality sentinels encode infinite and unknown counts).
class DatasetCardinalityOp : public OpKernel {
 public:
  explicit DatasetCardinalityOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
  void Compute(OpKernelContext* ctx) override {
    // Input 0 is a variant tensor wrapping a DatasetBase (not owned here).
    DatasetBase* dataset;
    OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
    // Allocate the scalar output and fill it with the dataset's cardinality.
    Tensor* result;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &result));
    result->scalar<int64>()() = dataset->Cardinality();
  }
};
REGISTER_KERNEL_BUILDER(Name("DatasetToGraph").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("DatasetToGraph").Device(DEVICE_CPU),
DatasetToGraphOp); DatasetToGraphOp);
REGISTER_KERNEL_BUILDER(
Name("ExperimentalDatasetCardinality").Device(DEVICE_CPU),
DatasetCardinalityOp);
} // namespace data } // namespace data
} // namespace tensorflow } // namespace tensorflow
...@@ -76,6 +76,8 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel { ...@@ -76,6 +76,8 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel {
return "AssertNextDatasetOp::Dataset"; return "AssertNextDatasetOp::Dataset";
} }
int64 Cardinality() const override { return input_->Cardinality(); }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -114,6 +114,14 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel { ...@@ -114,6 +114,14 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
")::Dataset"); ")::Dataset");
} }
// Number of sparse batches: ceil(input cardinality / batch size); infinite
// and unknown input cardinalities propagate unchanged.
int64 Cardinality() const override {
  const int64 num_elements = input_->Cardinality();
  if (num_elements == kInfiniteCardinality ||
      num_elements == kUnknownCardinality) {
    return num_elements;
  }
  const int64 remainder = num_elements % batch_size_;
  return num_elements / batch_size_ + (remainder != 0 ? 1 : 0);
}
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -60,6 +60,8 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { ...@@ -60,6 +60,8 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
return "IgnoreErrorsDatasetOp::Dataset"; return "IgnoreErrorsDatasetOp::Dataset";
} }
int64 Cardinality() const override { return input_->Cardinality(); }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -195,6 +195,17 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { ...@@ -195,6 +195,17 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
return "MapAndBatchDatasetOp::Dataset"; return "MapAndBatchDatasetOp::Dataset";
} }
// TODO(b/120482302): Note that this is inaccurate until MapDataset is
// modified to preserve cardinality.
// Number of batches produced: infinite/unknown inputs propagate; otherwise
// full batches plus a trailing partial batch unless drop_remainder_ is set.
int64 Cardinality() const override {
  const int64 num_elements = input_->Cardinality();
  if (num_elements == kInfiniteCardinality ||
      num_elements == kUnknownCardinality) {
    return num_elements;
  }
  const bool emits_partial_batch =
      !drop_remainder_ && (num_elements % batch_size_ != 0);
  return num_elements / batch_size_ + (emits_partial_batch ? 1 : 0);
}
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -75,6 +75,8 @@ class NonSerializableDatasetOp : public UnaryDatasetOpKernel { ...@@ -75,6 +75,8 @@ class NonSerializableDatasetOp : public UnaryDatasetOpKernel {
return errors::Unimplemented(DebugString(), "::AsGraphDefInternal"); return errors::Unimplemented(DebugString(), "::AsGraphDefInternal");
} }
int64 Cardinality() const override { return input_->Cardinality(); }
private: private:
class Iterator : public DatasetIterator<Dataset> { class Iterator : public DatasetIterator<Dataset> {
public: public:
......
...@@ -133,6 +133,15 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel { ...@@ -133,6 +133,15 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel {
return "NumaMapAndBatchDatasetOp::Dataset"; return "NumaMapAndBatchDatasetOp::Dataset";
} }
// Number of batches produced: infinite/unknown inputs propagate; otherwise
// full batches plus a trailing partial batch unless drop_remainder_ is set.
int64 Cardinality() const override {
  const int64 num_elements = input_->Cardinality();
  if (num_elements == kInfiniteCardinality ||
      num_elements == kUnknownCardinality) {
    return num_elements;
  }
  const int64 full_batches = num_elements / batch_size_;
  const bool emits_partial_batch =
      !drop_remainder_ && (num_elements % batch_size_ != 0);
  return emits_partial_batch ? full_batches + 1 : full_batches;
}
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -202,6 +202,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { ...@@ -202,6 +202,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
return "ParseExampleDatasetOp::Dataset"; return "ParseExampleDatasetOp::Dataset";
} }
int64 Cardinality() const override { return input_->Cardinality(); }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -76,6 +76,8 @@ class RandomDatasetOp : public DatasetOpKernel { ...@@ -76,6 +76,8 @@ class RandomDatasetOp : public DatasetOpKernel {
")::Dataset"); ")::Dataset");
} }
int64 Cardinality() const override { return kInfiniteCardinality; }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -93,6 +93,10 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { ...@@ -93,6 +93,10 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "ScanDatasetOp::Dataset"; } string DebugString() const override { return "ScanDatasetOp::Dataset"; }
// TODO(b/120482302): Note that this is inaccurate until MapDataset is
// modified to preserve cardinality.
int64 Cardinality() const override { return input_->Cardinality(); }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -129,6 +129,8 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { ...@@ -129,6 +129,8 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
return "SetStatsAggregatorDatasetOp::Dataset"; return "SetStatsAggregatorDatasetOp::Dataset";
} }
int64 Cardinality() const override { return input_->Cardinality(); }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -68,6 +68,8 @@ class SleepDatasetOp : public UnaryDatasetOpKernel { ...@@ -68,6 +68,8 @@ class SleepDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "SleepDatasetOp::Dataset"; } string DebugString() const override { return "SleepDatasetOp::Dataset"; }
int64 Cardinality() const override { return input_->Cardinality(); }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -103,6 +103,14 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel { ...@@ -103,6 +103,14 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel {
window_shift_, ", ", window_stride_, ")::Dataset"); window_shift_, ", ", window_stride_, ")::Dataset");
} }
// Approximates the number of sliding windows as n / window_shift_.
// NOTE(review): this ignores window_size_ / window_stride_ and any trailing
// partial window, so it can be off by a small constant — confirm against the
// iterator's termination condition.
int64 Cardinality() const override {
  int64 n = input_->Cardinality();
  if (n == kInfiniteCardinality || n == kUnknownCardinality) {
    return n;
  }
  return n / window_shift_;
}
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -78,6 +78,8 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel { ...@@ -78,6 +78,8 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel {
return "LatencyStatsDatasetOp::Dataset"; return "LatencyStatsDatasetOp::Dataset";
} }
int64 Cardinality() const override { return input_->Cardinality(); }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
...@@ -186,6 +188,8 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel { ...@@ -186,6 +188,8 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
return "BytesProducedStatsDatasetOp::Dataset"; return "BytesProducedStatsDatasetOp::Dataset";
} }
int64 Cardinality() const override { return input_->Cardinality(); }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -169,6 +169,8 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { ...@@ -169,6 +169,8 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
return "ThreadPoolDatasetOp::Dataset"; return "ThreadPoolDatasetOp::Dataset";
} }
int64 Cardinality() const override { return input_->Cardinality(); }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
...@@ -274,6 +276,8 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel { ...@@ -274,6 +276,8 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
return "MaxIntraOpParallelismDatasetOp::Dataset"; return "MaxIntraOpParallelismDatasetOp::Dataset";
} }
int64 Cardinality() const override { return input_->Cardinality(); }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
...@@ -383,6 +387,8 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel { ...@@ -383,6 +387,8 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
return "PrivateThreadPoolDatasetOp::Dataset"; return "PrivateThreadPoolDatasetOp::Dataset";
} }
int64 Cardinality() const override { return input_->Cardinality(); }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -128,6 +128,10 @@ class MapDatasetOp : public UnaryDatasetOpKernel { ...@@ -128,6 +128,10 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "MapDatasetOp::Dataset"; } string DebugString() const override { return "MapDatasetOp::Dataset"; }
// TODO(b/120482302): Note that this is inaccurate until MapDataset is
// modified to preserve cardinality.
int64 Cardinality() const override { return input_->Cardinality(); }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -60,6 +60,8 @@ class ModelDatasetOp : public UnaryDatasetOpKernel { ...@@ -60,6 +60,8 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "ModelDatasetOp::Dataset"; } string DebugString() const override { return "ModelDatasetOp::Dataset"; }
int64 Cardinality() const override { return input_->Cardinality(); }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -162,6 +162,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { ...@@ -162,6 +162,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "OptimizeDatasetOp::Dataset"; } string DebugString() const override { return "OptimizeDatasetOp::Dataset"; }
int64 Cardinality() const override { return input_->Cardinality(); }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -152,6 +152,15 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { ...@@ -152,6 +152,15 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
")::Dataset"); ")::Dataset");
} }
// Number of padded batches: infinite/unknown inputs propagate; otherwise
// full batches plus a trailing partial batch unless drop_remainder_ is set.
int64 Cardinality() const override {
  const int64 num_elements = input_->Cardinality();
  if (num_elements == kInfiniteCardinality ||
      num_elements == kUnknownCardinality) {
    return num_elements;
  }
  const bool emits_partial_batch =
      !drop_remainder_ && (num_elements % batch_size_ != 0);
  return num_elements / batch_size_ + (emits_partial_batch ? 1 : 0);
}
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -118,6 +118,10 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { ...@@ -118,6 +118,10 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
return "ParallelMapDatasetOp::Dataset"; return "ParallelMapDatasetOp::Dataset";
} }
// TODO(b/120482302): Note that this is inaccurate until MapDataset is
// modified to preserve cardinality.
int64 Cardinality() const override { return input_->Cardinality(); }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -56,6 +56,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { ...@@ -56,6 +56,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
string DebugString() const override { return "PrefetchDatasetOp::Dataset"; } string DebugString() const override { return "PrefetchDatasetOp::Dataset"; }
int64 Cardinality() const override { return input_->Cardinality(); }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -73,6 +73,14 @@ class RangeDatasetOp : public DatasetOpKernel { ...@@ -73,6 +73,14 @@ class RangeDatasetOp : public DatasetOpKernel {
step_, ")::Dataset"); step_, ")::Dataset");
} }
// Returns the number of elements range(start_, stop_, step_) produces.
// Empty ranges are handled explicitly: C++ integer division truncates toward
// zero, so the closed-form `(stop - start - 1) / step + 1` evaluates to 1 for
// some empty ranges (e.g. start == stop with |step| > 1), and the previous
// std::max(0LL, ...) clamp could not mask that case.
int64 Cardinality() const override {
  if (step_ > 0) {
    if (stop_ <= start_) {
      return 0;
    }
    return (stop_ - start_ - 1) / step_ + 1;
  }
  if (start_ <= stop_) {
    return 0;
  }
  return (start_ - stop_ - 1) / (-step_) + 1;
}
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -71,6 +71,23 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel { ...@@ -71,6 +71,23 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "RepeatDatasetOp::Dataset"; } string DebugString() const override { return "RepeatDatasetOp::Dataset"; }
// Cardinality of repeating the input count_ times. count_ < 0 means repeat
// forever: that is infinite for any non-empty input, 0 for an empty input,
// and — fix over the previous version — *unknown* when the input cardinality
// is unknown (an empty unknown input would repeat to an empty dataset, so
// reporting infinite was wrong).
int64 Cardinality() const override {
  int64 n = input_->Cardinality();
  if (count_ < 0) {
    if (n == 0) {
      return 0;
    }
    if (n == kUnknownCardinality) {
      return kUnknownCardinality;
    }
    return kInfiniteCardinality;
  }
  if (count_ == 0) {
    return 0;
  }
  if (n == kInfiniteCardinality || n == kUnknownCardinality) {
    return n;
  }
  return count_ * n;
}
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -61,6 +61,8 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel { ...@@ -61,6 +61,8 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
return input_->output_shapes(); return input_->output_shapes();
} }
int64 Cardinality() const override { return input_->Cardinality(); }
protected: protected:
template <class T> template <class T>
class Iterator : public DatasetIterator<T> { class Iterator : public DatasetIterator<T> {
......
...@@ -67,6 +67,14 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { ...@@ -67,6 +67,14 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "SkipDatasetOp::Dataset"; } string DebugString() const override { return "SkipDatasetOp::Dataset"; }
// Cardinality after skipping count_ elements. Infinite/unknown inputs
// propagate unchanged. count_ < 0 means skip(-1), i.e. discard everything:
// the previous `std::max(0LL, n - count_)` evaluated to n + 1 in that case
// instead of 0.
int64 Cardinality() const override {
  int64 n = input_->Cardinality();
  if (n == kInfiniteCardinality || n == kUnknownCardinality) {
    return n;
  }
  if (count_ < 0) {
    return 0;
  }
  return std::max(int64{0}, n - count_);
}
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -54,6 +54,8 @@ class Dataset : public DatasetBase { ...@@ -54,6 +54,8 @@ class Dataset : public DatasetBase {
return "SparseTensorSliceDatasetOp::Dataset"; return "SparseTensorSliceDatasetOp::Dataset";
} }
int64 Cardinality() const override { return sparse_tensor_.shape()[0]; }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -68,6 +68,17 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { ...@@ -68,6 +68,17 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
string DebugString() const override { return "TakeDatasetOp::Dataset"; } string DebugString() const override { return "TakeDatasetOp::Dataset"; }
// Cardinality after taking count_ elements. count_ < 0 means take(-1),
// i.e. keep the whole input: the previous `std::min(n, count_)` returned -1
// (reported as infinite) for any finite input in that case. An unknown
// input stays unknown; an infinite input is capped at count_ when count_ is
// non-negative.
int64 Cardinality() const override {
  int64 n = input_->Cardinality();
  if (n == kUnknownCardinality) {
    return kUnknownCardinality;
  }
  if (count_ < 0) {
    // Also correct for infinite inputs: the whole input is kept.
    return n;
  }
  if (n == kInfiniteCardinality) {
    return count_;
  }
  return std::min(n, count_);
}
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -61,6 +61,8 @@ class TensorDatasetOp : public DatasetOpKernel { ...@@ -61,6 +61,8 @@ class TensorDatasetOp : public DatasetOpKernel {
string DebugString() const override { return "TensorDatasetOp::Dataset"; } string DebugString() const override { return "TensorDatasetOp::Dataset"; }
int64 Cardinality() const override { return 1LL; }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -84,6 +84,8 @@ class TensorSliceDatasetOp : public DatasetOpKernel { ...@@ -84,6 +84,8 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
return "TensorSliceDatasetOp::Dataset"; return "TensorSliceDatasetOp::Dataset";
} }
int64 Cardinality() const override { return tensors_[0].dim_size(0); }
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -43,6 +43,8 @@ class WindowDataset : public DatasetBase { ...@@ -43,6 +43,8 @@ class WindowDataset : public DatasetBase {
string DebugString() const override { return "WindowDataset"; } string DebugString() const override { return "WindowDataset"; }
int64 Cardinality() const override { return elements_.size(); }
protected: protected:
// TODO(b/110981596): Support checkpointing. // TODO(b/110981596): Support checkpointing.
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
......
...@@ -98,6 +98,15 @@ class WindowDatasetOp : public UnaryDatasetOpKernel { ...@@ -98,6 +98,15 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
window_stride_, drop_remainder_, ")::Dataset"); window_stride_, drop_remainder_, ")::Dataset");
} }
// Number of windows: one per window_shift_ stride over the input, plus a
// trailing partial window unless drop_remainder_ suppresses it.
// Infinite/unknown input cardinalities propagate unchanged.
int64 Cardinality() const override {
  const int64 num_elements = input_->Cardinality();
  if (num_elements == kInfiniteCardinality ||
      num_elements == kUnknownCardinality) {
    return num_elements;
  }
  const bool emits_partial_window =
      !drop_remainder_ && (num_elements % window_shift_ != 0);
  return num_elements / window_shift_ + (emits_partial_window ? 1 : 0);
}
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -76,6 +76,21 @@ class ZipDatasetOp : public DatasetOpKernel { ...@@ -76,6 +76,21 @@ class ZipDatasetOp : public DatasetOpKernel {
string DebugString() const override { return "ZipDatasetOp::Dataset"; } string DebugString() const override { return "ZipDatasetOp::Dataset"; }
// A zipped dataset is as long as its shortest input: any unknown input makes
// the result unknown; otherwise the minimum finite cardinality, or infinite
// when every input is infinite.
int64 Cardinality() const override {
  int64 shortest = kInfiniteCardinality;
  for (const auto& input : inputs_) {
    const int64 n = input->Cardinality();
    if (n == kUnknownCardinality) {
      return kUnknownCardinality;
    }
    if (n == kInfiniteCardinality) {
      continue;  // An infinite input never bounds the zip.
    }
    if (shortest == kInfiniteCardinality || n < shortest) {
      shortest = n;
    }
  }
  return shortest;
}
protected: protected:
Status AsGraphDefInternal(SerializationContext* ctx, Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b, DatasetGraphDefBuilder* b,
......
...@@ -71,6 +71,11 @@ REGISTER_OP("ExperimentalCSVDataset") ...@@ -71,6 +71,11 @@ REGISTER_OP("ExperimentalCSVDataset")
return shape_inference::ScalarShape(c); return shape_inference::ScalarShape(c);
}); });
REGISTER_OP("ExperimentalDatasetCardinality")
.Input("input_dataset: variant")
.Output("cardinality: int64")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ExperimentalDatasetToTFRecord") REGISTER_OP("ExperimentalDatasetToTFRecord")
.Input("input_dataset: variant") .Input("input_dataset: variant")
.Input("filename: string") .Input("filename: string")
......
...@@ -35,8 +35,8 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. ...@@ -35,8 +35,8 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@TFRecordWriter @@TFRecordWriter
@@ThreadingOptions @@ThreadingOptions
@@assume_finite
@@bucket_by_sequence_length @@bucket_by_sequence_length
@@cardinality
@@choose_from_datasets @@choose_from_datasets
@@copy_to_device @@copy_to_device
@@dense_to_sparse_batch @@dense_to_sparse_batch
...@@ -63,6 +63,8 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview. ...@@ -63,6 +63,8 @@ See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
@@unique @@unique
@@AUTOTUNE @@AUTOTUNE
@@INFINITE_CARDINALITY
@@UNKNOWN_CARDINALITY
""" """
from __future__ import absolute_import from __future__ import absolute_import
...@@ -74,6 +76,9 @@ from __future__ import print_function ...@@ -74,6 +76,9 @@ from __future__ import print_function
from tensorflow.python.data.experimental.ops.batching import dense_to_sparse_batch from tensorflow.python.data.experimental.ops.batching import dense_to_sparse_batch
from tensorflow.python.data.experimental.ops.batching import map_and_batch from tensorflow.python.data.experimental.ops.batching import map_and_batch
from tensorflow.python.data.experimental.ops.batching import unbatch from tensorflow.python.data.experimental.ops.batching import unbatch
from tensorflow.python.data.experimental.ops.cardinality import cardinality
from tensorflow.python.data.experimental.ops.cardinality import INFINITE as INFINITE_CARDINALITY
from tensorflow.python.data.experimental.ops.cardinality import UNKNOWN as UNKNOWN_CARDINALITY
from tensorflow.python.data.experimental.ops.counter import Counter from tensorflow.python.data.experimental.ops.counter import Counter
from tensorflow.python.data.experimental.ops.enumerate_ops import enumerate_dataset from tensorflow.python.data.experimental.ops.enumerate_ops import enumerate_dataset
from tensorflow.python.data.experimental.ops.error_ops import ignore_errors from tensorflow.python.data.experimental.ops.error_ops import ignore_errors
...@@ -83,7 +88,6 @@ from tensorflow.python.data.experimental.ops.grouping import bucket_by_sequence_ ...@@ -83,7 +88,6 @@ from tensorflow.python.data.experimental.ops.grouping import bucket_by_sequence_
from tensorflow.python.data.experimental.ops.grouping import group_by_reducer from tensorflow.python.data.experimental.ops.grouping import group_by_reducer
from tensorflow.python.data.experimental.ops.grouping import group_by_window from tensorflow.python.data.experimental.ops.grouping import group_by_window
from tensorflow.python.data.experimental.ops.grouping import Reducer from tensorflow.python.data.experimental.ops.grouping import Reducer
from tensorflow.python.data.experimental.ops.has_indefinite_repeat import assume_finite
from tensorflow.python.data.experimental.ops.interleave_ops import choose_from_datasets from tensorflow.python.data.experimental.ops.interleave_ops import choose_from_datasets
from tensorflow.python.data.experimental.ops.interleave_ops import parallel_interleave from tensorflow.python.data.experimental.ops.interleave_ops import parallel_interleave
from tensorflow.python.data.experimental.ops.interleave_ops import sample_from_datasets from tensorflow.python.data.experimental.ops.interleave_ops import sample_from_datasets
......
...@@ -211,18 +211,6 @@ py_test( ...@@ -211,18 +211,6 @@ py_test(
], ],
) )
py_test(
name = "has_indefinite_repeat_test",
srcs = ["has_indefinite_repeat_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python/data/experimental/ops:has_indefinite_repeat",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
)
py_test( py_test(
name = "ignore_errors_test", name = "ignore_errors_test",
srcs = ["ignore_errors_test.py"], srcs = ["ignore_errors_test.py"],
...@@ -377,6 +365,18 @@ py_test( ...@@ -377,6 +365,18 @@ py_test(
], ],
) )
py_test(
name = "cardinality_test",
srcs = ["cardinality_test.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python/data/experimental/ops:cardinality",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
)
py_test( py_test(
name = "override_threadpool_test", name = "override_threadpool_test",
size = "small", size = "small",
......
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental.cardinality()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
class NumElementsTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `tf.data.experimental.cardinality()`.

  Each named parameter pairs a zero-argument factory building a dataset with
  the cardinality expected for it: either a non-negative element count, or
  one of the named constants `cardinality.INFINITE` (the pipeline repeats
  forever) / `cardinality.UNKNOWN` (static analysis cannot determine the
  count, e.g. after `filter` or `flat_map`).
  """

  @parameterized.named_parameters(
      # pylint: disable=g-long-lambda
      # Batching: `drop_remainder=True` floors 5/2, `False` ceils it.
      ("Batch1",
       lambda: dataset_ops.Dataset.range(5).batch(2, drop_remainder=True), 2),
      ("Batch2",
       lambda: dataset_ops.Dataset.range(5).batch(2, drop_remainder=False), 3),
      ("Batch3",
       lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).batch(2),
       cardinality.UNKNOWN),
      ("Batch4", lambda: dataset_ops.Dataset.range(5).repeat().batch(2),
       cardinality.INFINITE),
      ("Cache1", lambda: dataset_ops.Dataset.range(5).cache(), 5),
      ("Cache2", lambda: dataset_ops.Dataset.range(5).cache("foo"), 5),
      # Concatenation: INFINITE if either input is infinite; otherwise
      # UNKNOWN if either input is unknown; otherwise the sum.
      ("Concatenate1", lambda: dataset_ops.Dataset.range(5).concatenate(
          dataset_ops.Dataset.range(5)), 10),
      ("Concatenate2",
       lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).concatenate(
           dataset_ops.Dataset.range(5)), cardinality.UNKNOWN),
      ("Concatenate3", lambda: dataset_ops.Dataset.range(5).repeat().
       concatenate(dataset_ops.Dataset.range(5)),
       cardinality.INFINITE),
      ("Concatenate4", lambda: dataset_ops.Dataset.range(5).concatenate(
          dataset_ops.Dataset.range(5).filter(lambda _: True)),
       cardinality.UNKNOWN),
      ("Concatenate5",
       lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).concatenate(
           dataset_ops.Dataset.range(5).filter(lambda _: True)),
       cardinality.UNKNOWN),
      ("Concatenate6", lambda: dataset_ops.Dataset.range(5).repeat().
       concatenate(dataset_ops.Dataset.range(5).filter(lambda _: True)),
       cardinality.INFINITE),
      ("Concatenate7", lambda: dataset_ops.Dataset.range(5).concatenate(
          dataset_ops.Dataset.range(5).repeat()), cardinality.INFINITE),
      ("Concatenate8",
       lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).concatenate(
           dataset_ops.Dataset.range(5).repeat()), cardinality.INFINITE),
      ("Concatenate9",
       lambda: dataset_ops.Dataset.range(5).repeat().concatenate(
           dataset_ops.Dataset.range(5).repeat()), cardinality.INFINITE),
      # User-defined functions may produce any number of elements, so the
      # analysis reports UNKNOWN for flat_map / filter / interleave.
      ("FlatMap", lambda: dataset_ops.Dataset.range(5).flat_map(
          lambda _: dataset_ops.Dataset.from_tensors(0)),
       cardinality.UNKNOWN),
      ("Filter", lambda: dataset_ops.Dataset.range(5).filter(lambda _: True),
       cardinality.UNKNOWN),
      ("FromTensors1", lambda: dataset_ops.Dataset.from_tensors(0), 1),
      ("FromTensors2", lambda: dataset_ops.Dataset.from_tensors((0, 1)), 1),
      ("FromTensorSlices1",
       lambda: dataset_ops.Dataset.from_tensor_slices([0, 0, 0]), 3),
      ("FromTensorSlices2",
       lambda: dataset_ops.Dataset.from_tensor_slices(([0, 0, 0], [1, 1, 1])),
       3),
      ("Interleave1", lambda: dataset_ops.Dataset.range(5).interleave(
          lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1),
       cardinality.UNKNOWN),
      ("Interleave2", lambda: dataset_ops.Dataset.range(5).interleave(
          lambda _: dataset_ops.Dataset.from_tensors(0),
          cycle_length=1,
          num_parallel_calls=1), cardinality.UNKNOWN),
      # One-to-one transformations (map, prefetch, shuffle) preserve the
      # input cardinality.
      ("Map1", lambda: dataset_ops.Dataset.range(5).map(lambda x: x), 5),
      ("Map2", lambda: dataset_ops.Dataset.range(5).map(
          lambda x: x, num_parallel_calls=1), 5),
      ("PaddedBatch1", lambda: dataset_ops.Dataset.range(5).padded_batch(
          2, [], drop_remainder=True), 2),
      ("PaddedBatch2", lambda: dataset_ops.Dataset.range(5).padded_batch(
          2, [], drop_remainder=False), 3),
      ("PaddedBatch3", lambda: dataset_ops.Dataset.range(5).filter(
          lambda _: True).padded_batch(2, []), cardinality.UNKNOWN),
      ("PaddedBatch4",
       lambda: dataset_ops.Dataset.range(5).repeat().padded_batch(2, []),
       cardinality.INFINITE),
      ("Prefetch", lambda: dataset_ops.Dataset.range(5).prefetch(buffer_size=1),
       5),
      # Ranges, including empty (Range1/Range4) and strided (Range5/Range6).
      ("Range1", lambda: dataset_ops.Dataset.range(0), 0),
      ("Range2", lambda: dataset_ops.Dataset.range(5), 5),
      ("Range3", lambda: dataset_ops.Dataset.range(5, 10), 5),
      ("Range4", lambda: dataset_ops.Dataset.range(10, 5), 0),
      ("Range5", lambda: dataset_ops.Dataset.range(5, 10, 2), 3),
      ("Range6", lambda: dataset_ops.Dataset.range(10, 5, -2), 3),
      # Repeat: an empty input repeated forever is still empty (Repeat5).
      ("Repeat1", lambda: dataset_ops.Dataset.range(0).repeat(0), 0),
      ("Repeat2", lambda: dataset_ops.Dataset.range(1).repeat(0), 0),
      ("Repeat3", lambda: dataset_ops.Dataset.range(0).repeat(5), 0),
      ("Repeat4", lambda: dataset_ops.Dataset.range(1).repeat(5), 5),
      ("Repeat5", lambda: dataset_ops.Dataset.range(0).repeat(), 0),
      ("Repeat6", lambda: dataset_ops.Dataset.range(1).repeat(),
       cardinality.INFINITE),
      ("Shuffle", lambda: dataset_ops.Dataset.range(5).shuffle(buffer_size=1),
       5),
      # Skip/take clamp at the input size; an infinite input skipped a finite
      # number of times remains infinite (Skip4), but taking a finite number
      # from an infinite input is finite (Take4).
      ("Skip1", lambda: dataset_ops.Dataset.range(5).skip(2), 3),
      ("Skip2", lambda: dataset_ops.Dataset.range(5).skip(8), 0),
      ("Skip3",
       lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).skip(2),
       cardinality.UNKNOWN),
      ("Skip4", lambda: dataset_ops.Dataset.range(5).repeat().skip(2),
       cardinality.INFINITE),
      ("Take1", lambda: dataset_ops.Dataset.range(5).take(2), 2),
      ("Take2", lambda: dataset_ops.Dataset.range(5).take(8), 5),
      ("Take3",
       lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).take(2),
       cardinality.UNKNOWN),
      ("Take4", lambda: dataset_ops.Dataset.range(5).repeat().take(2), 2),
      ("Window1", lambda: dataset_ops.Dataset.range(5).window(
          size=2, shift=2, drop_remainder=True), 2),
      ("Window2", lambda: dataset_ops.Dataset.range(5).window(
          size=2, shift=2, drop_remainder=False), 3),
      # Zip: the result length is the minimum of the inputs; INFINITE only
      # when every input is infinite, UNKNOWN if any input is unknown.
      ("Zip1", lambda: dataset_ops.Dataset.zip(dataset_ops.Dataset.range(5)),
       5),
      ("Zip2", lambda: dataset_ops.Dataset.zip(
          (dataset_ops.Dataset.range(5), dataset_ops.Dataset.range(3))), 3),
      ("Zip3", lambda: dataset_ops.Dataset.zip(
          (dataset_ops.Dataset.range(5),
           dataset_ops.Dataset.range(3).repeat())), 5),
      ("Zip4", lambda: dataset_ops.Dataset.zip(
          (dataset_ops.Dataset.range(5).repeat(),
           dataset_ops.Dataset.range(3).repeat())), cardinality.INFINITE),
      ("Zip5", lambda: dataset_ops.Dataset.zip(
          (dataset_ops.Dataset.range(5),
           dataset_ops.Dataset.range(3).filter(lambda _: True))),
       cardinality.UNKNOWN),
      # pylint: enable=g-long-lambda
  )
  def testNumElements(self, dataset_fn, expected_result):
    # The cardinality op returns a scalar tensor, so it is evaluated in a
    # session and compared against the expected count / named constant.
    with self.cached_session() as sess:
      self.assertEqual(
          sess.run(cardinality.cardinality(dataset_fn())), expected_result)


if __name__ == "__main__":
  test.main()
...@@ -4,6 +4,16 @@ licenses(["notice"]) # Apache 2.0 ...@@ -4,6 +4,16 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"]) exports_files(["LICENSE"])
py_library(
name = "cardinality",
srcs = ["cardinality.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:tensor_util",
],
)
py_library( py_library(
name = "counter", name = "counter",
srcs = ["counter.py"], srcs = ["counter.py"],
...@@ -28,16 +38,6 @@ py_library( ...@@ -28,16 +38,6 @@ py_library(
], ],
) )
py_library(
name = "has_indefinite_repeat",
srcs = ["has_indefinite_repeat.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:tensor_util",
"//tensorflow/python/data/ops:dataset_ops",
],
)
py_library( py_library(
name = "iterator_ops", name = "iterator_ops",
srcs = [ srcs = [
...@@ -433,13 +433,13 @@ py_library( ...@@ -433,13 +433,13 @@ py_library(
name = "dataset_ops", name = "dataset_ops",
deps = [ deps = [
":batching", ":batching",
":cardinality",
":counter", ":counter",
":enumerate_ops", ":enumerate_ops",
":error_ops", ":error_ops",
":filter_for_shard_ops", ":filter_for_shard_ops",
":get_single_element", ":get_single_element",
":grouping", ":grouping",
":has_indefinite_repeat",
":indexed_dataset_ops", ":indexed_dataset_ops",
":interleave_ops", ":interleave_ops",
":map_defun", ":map_defun",
......
...@@ -12,42 +12,39 @@ ...@@ -12,42 +12,39 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Tests for `tf.data.experimental.has_indefinite_repeat()`.""" """Cardinality analysis of `Dataset` objects."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from absl.testing import parameterized from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.util.tf_export import tf_export
from tensorflow.python.data.experimental.ops import has_indefinite_repeat
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops INFINITE = -1
from tensorflow.python.platform import test UNKNOWN = -2
tf_export("data.experimental.INFINITE_CARDINALITY").export_constant(
__name__, "INFINITE")
class HasIndefiniteRepeat(test_base.DatasetTestBase, parameterized.TestCase): tf_export("data.experimental.UNKNOWN_CARDINALITY").export_constant(
"""Tests for `tf.data.experimental.has_indefinite_repeat()`.""" __name__, "UNKNOWN")
@parameterized.named_parameters(
("NoRepeat", dataset_ops.Dataset.range(10), False), @tf_export("data.experimental.cardinality")
("FiniteRepeat", dataset_ops.Dataset.range(10).repeat(2), False), def cardinality(dataset):
("FiniteRepeatNotAtEnd", dataset_ops.Dataset.range(10).repeat(2).skip(1), """Returns the cardinality of `dataset`, if known.
False),
("InfiniteRepeat", dataset_ops.Dataset.range(10).repeat(), True), The operation returns the cardinality of `dataset`. The operation may return
("InfiniteRepeatNotAtEnd", dataset_ops.Dataset.range(10).repeat().skip(1), `tf.data.experimental.INFINITE_CARDINALITY` if `dataset` contains an infinite
True), number of elements or `tf.data.experimental.UNKNOWN_CARDINALITY` if the
("InfiniteRepeatThenFiniteRepeat", analysis fails to determine the number of elements in `dataset` (e.g. when the
dataset_ops.Dataset.range(10).repeat().repeat(2), True), dataset source is a file).
("ConcatenateFiniteAndInfinite",
dataset_ops.Dataset.range(10).repeat(2).concatenate( Args:
dataset_ops.Dataset.range(10).repeat()), True), dataset: A `tf.data.Dataset` for which to determine cardinality.
("AssumeFinite", dataset_ops.Dataset.range(10).repeat().apply(
has_indefinite_repeat.assume_finite()), False), Returns:
) A scalar `tf.int64` `Tensor` representing the cardinality of `dataset`. If
def testHasIndefiniteRepeat(self, dataset, expected_result): the cardinality is infinite or unknown, the operation returns the named
self.assertEqual( constant `INFINITE_CARDINALITY` and `UNKNOWN_CARDINALITY` respectively.
has_indefinite_repeat.has_indefinite_repeat(dataset), expected_result) """
return ged_ops.experimental_dataset_cardinality(dataset._as_variant_tensor()) # pylint: disable=protected-access
if __name__ == "__main__":
test.main()
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Finiteness analysis of `Dataset` objects."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.util.tf_export import tf_export
class _AssumeFiniteDataset(dataset_ops.UnaryUnchangedStructureDataset):
  """Identity wrapper that marks its input dataset as finite.

  The wrapper does not change the data in any way; it exists purely as a
  sentinel that `has_indefinite_repeat` recognizes and treats as finite.
  """

  def __init__(self, input_dataset):
    # Remember the wrapped dataset, then delegate to the base class.
    self._input_dataset = input_dataset
    super(_AssumeFiniteDataset, self).__init__(input_dataset)

  def _as_variant_tensor(self):
    # Pure pass-through: forward the wrapped dataset's variant tensor.
    wrapped = self._input_dataset
    return wrapped._as_variant_tensor()  # pylint: disable=protected-access
@tf_export("data.experimental.assume_finite")
def assume_finite():
  """Assume that the input is finite, even if it contains `Dataset.repeat()`.

  Training libraries may analyze a `tf.data.Dataset` to determine whether it
  is finite or infinite (for example, because it contains an indefinite
  `tf.data.Dataset.repeat` transformation). Because that analysis may be
  imprecise, this transformation lets the user explicitly annotate a dataset
  as finite.

  Returns:
    A `Dataset` transformation function, which can be passed to
    `tf.data.Dataset.apply`.
  """
  # Wrapping the dataset in `_AssumeFiniteDataset` is the whole
  # transformation; the wrapper carries the finiteness annotation.
  return lambda dataset: _AssumeFiniteDataset(dataset)
def has_indefinite_repeat(dataset):
  """Returns `True` if `dataset` or any of its inputs is `Dataset.repeat()`.

  NOTE: For simplicity, this analysis does not attempt to analyze nested
  datasets (e.g. in a function passed to `tf.data.Dataset.flat_map`). If the
  analysis is incorrect, you can apply `tf.data.experimental.assume_finite()`
  to the dataset to override it.

  Args:
    dataset: A `tf.data.Dataset`.

  Returns:
    `True` if `dataset` or any of its inputs is repeated indefinitely.
  """
  # pylint: disable=protected-access
  if isinstance(dataset, _AssumeFiniteDataset):
    # The user explicitly annotated this pipeline as finite; stop analyzing.
    return False
  if isinstance(dataset, dataset_ops.DatasetV1Adapter):
    # Look through the V1 compatibility wrapper at the underlying dataset.
    return has_indefinite_repeat(dataset._dataset)
  if isinstance(dataset, dataset_ops.RepeatDataset):
    repeat_count = tensor_util.constant_value(dataset._count)
    if repeat_count == -1:
      # `repeat()` with no argument (count of -1) repeats forever.
      return True
    # A finite (or statically unknown) repeat count: the result is
    # indefinite only if the repeated input itself is.
    return has_indefinite_repeat(dataset._inputs()[0])
  # Any other transformation: recurse into each of its input datasets.
  return any(
      has_indefinite_repeat(upstream) for upstream in dataset._inputs())
...@@ -12,6 +12,10 @@ tf_module { ...@@ -12,6 +12,10 @@ tf_module {
name: "CsvDataset" name: "CsvDataset"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
} }
member {
name: "INFINITE_CARDINALITY"
mtype: "<type \'int\'>"
}
member { member {
name: "OptimizationOptions" name: "OptimizationOptions"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
...@@ -48,18 +52,22 @@ tf_module { ...@@ -48,18 +52,22 @@ tf_module {
name: "ThreadingOptions" name: "ThreadingOptions"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
} }
member {
name: "UNKNOWN_CARDINALITY"
mtype: "<type \'int\'>"
}
member_method { member_method {
name: "Counter" name: "Counter"
argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], " argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "
} }
member_method {
name: "assume_finite"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method { member_method {
name: "bucket_by_sequence_length" name: "bucket_by_sequence_length"
argspec: "args=[\'element_length_func\', \'bucket_boundaries\', \'bucket_batch_sizes\', \'padded_shapes\', \'padding_values\', \'pad_to_bucket_boundary\', \'no_padding\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\'], " argspec: "args=[\'element_length_func\', \'bucket_boundaries\', \'bucket_batch_sizes\', \'padded_shapes\', \'padding_values\', \'pad_to_bucket_boundary\', \'no_padding\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\'], "
} }
member_method {
name: "cardinality"
argspec: "args=[\'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method { member_method {
name: "choose_from_datasets" name: "choose_from_datasets"
argspec: "args=[\'datasets\', \'choice_dataset\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'datasets\', \'choice_dataset\'], varargs=None, keywords=None, defaults=None"
......
...@@ -12,6 +12,10 @@ tf_module { ...@@ -12,6 +12,10 @@ tf_module {
name: "CsvDataset" name: "CsvDataset"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
} }
member {
name: "INFINITE_CARDINALITY"
mtype: "<type \'int\'>"
}
member { member {
name: "OptimizationOptions" name: "OptimizationOptions"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
...@@ -48,18 +52,22 @@ tf_module { ...@@ -48,18 +52,22 @@ tf_module {
name: "ThreadingOptions" name: "ThreadingOptions"
mtype: "<type \'type\'>" mtype: "<type \'type\'>"
} }
member {
name: "UNKNOWN_CARDINALITY"
mtype: "<type \'int\'>"
}
member_method { member_method {
name: "Counter" name: "Counter"
argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], " argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "
} }
member_method {
name: "assume_finite"
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method { member_method {
name: "bucket_by_sequence_length" name: "bucket_by_sequence_length"
argspec: "args=[\'element_length_func\', \'bucket_boundaries\', \'bucket_batch_sizes\', \'padded_shapes\', \'padding_values\', \'pad_to_bucket_boundary\', \'no_padding\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\'], " argspec: "args=[\'element_length_func\', \'bucket_boundaries\', \'bucket_batch_sizes\', \'padded_shapes\', \'padding_values\', \'pad_to_bucket_boundary\', \'no_padding\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\'], "
} }
member_method {
name: "cardinality"
argspec: "args=[\'dataset\'], varargs=None, keywords=None, defaults=None"
}
member_method { member_method {
name: "choose_from_datasets" name: "choose_from_datasets"
argspec: "args=[\'datasets\', \'choice_dataset\'], varargs=None, keywords=None, defaults=None" argspec: "args=[\'datasets\', \'choice_dataset\'], varargs=None, keywords=None, defaults=None"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册