Commit 0088fbb3 authored by Frank Chen, committed by TensorFlower Gardener

Add support for overriding snapshot operation modes, graph hashes and run ids.

PiperOrigin-RevId: 286281695
Change-Id: I021eec024331e65b2d6243e14bb477957ce0e2f6
Parent ce9c895d
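For orientation, here is a minimal usage sketch of the two new attributes this change introduces, based on the Python `snapshot.snapshot` transformation and the tests further down in this diff (assuming the usual import path for the experimental snapshot module). The directory path and snapshot name are placeholder values.

```python
from tensorflow.python.data.experimental.ops import snapshot
from tensorflow.python.data.ops import dataset_ops

# Force writer mode and pin the snapshot under a fixed, human-readable name
# ("/tmp/snapshots" and "my_custom_snapshot" are placeholders).
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.apply(
    snapshot.snapshot(
        "/tmp/snapshots", mode="write", snapshot_name="my_custom_snapshot"))

# Later, force reader mode against that same named snapshot, regardless of
# what snapshot.metadata currently points to.
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.apply(
    snapshot.snapshot(
        "/tmp/snapshots", mode="read", snapshot_name="my_custom_snapshot"))
```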
......@@ -71,6 +71,11 @@ const int64 kSnappyReaderOutputBufferSizeBytes = 16 << 20; // 16 MiB
const size_t kHeaderSize = sizeof(uint64);
constexpr char kModeAuto[] = "auto";
constexpr char kModeWrite[] = "write";
constexpr char kModeRead[] = "read";
constexpr char kModePassthrough[] = "passthrough";
constexpr char kSnapshotFilename[] = "snapshot.metadata";
constexpr char kSnapshotReaderWorkerPool[] = "snapshot_reader_worker_pool";
constexpr char kSnapshotWriterWorkerPool[] = "snapshot_writer_worker_pool";
......@@ -304,10 +309,29 @@ Status DumpDatasetGraph(const std::string& path, uint64 hash,
return WriteTextProto(Env::Default(), graph_file, graph);
}
Status DetermineOpState(const Status& file_status,
Status DetermineOpState(const std::string& mode_string,
const Status& file_status,
const experimental::SnapshotMetadataRecord& metadata,
const uint64 pending_snapshot_expiry_seconds,
SnapshotMode* mode) {
if (mode_string == kModeRead) {
LOG(INFO) << "Overriding mode to reader.";
*mode = READER;
return Status::OK();
}
if (mode_string == kModeWrite) {
LOG(INFO) << "Overriding mode to writer.";
*mode = WRITER;
return Status::OK();
}
if (mode_string == kModePassthrough) {
LOG(INFO) << "Overriding mode to passthrough.";
*mode = PASSTHROUGH;
return Status::OK();
}
if (errors::IsNotFound(file_status)) {
*mode = WRITER;
return Status::OK();
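A rough Python sketch of the precedence the new `DetermineOpState` signature gives the explicit mode string: any non-auto mode wins outright, and only "auto" falls through to the pre-existing file-status logic (of which only the not-found branch is visible in this hunk). This is an illustration, not the kernel code.

```python
def determine_op_state(mode_string, metadata_file_missing):
    """Illustrative mirror of the override order in the C++ change above."""
    if mode_string == "read":
        return "READER"
    if mode_string == "write":
        return "WRITER"
    if mode_string == "passthrough":
        return "PASSTHROUGH"
    # mode_string == "auto": fall back to the original behaviour, e.g. a
    # missing metadata file means a fresh snapshot should be written.
    if metadata_file_missing:
        return "WRITER"
    # The remaining auto-mode cases (valid/stale metadata) are collapsed in
    # this diff and are unchanged by this commit.
    raise NotImplementedError
```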
......@@ -365,6 +389,16 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("seed", &seed_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("seed2", &seed2_));
mode_ = kModeAuto;
if (ctx->HasAttr("mode")) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_));
}
snapshot_name_ = "";
if (ctx->HasAttr("snapshot_name")) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("snapshot_name", &snapshot_name_));
}
if (shard_size_bytes_ == -1) shard_size_bytes_ = kDefaultShardSizeBytes;
// Default to 1 day expiry for snapshots.
......@@ -389,6 +423,13 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
ctx, pending_snapshot_expiry_seconds_ >= 1,
errors::InvalidArgument(
"pending_snapshot_expiry_seconds must be at least 1 second."));
OP_REQUIRES(ctx,
mode_ == kModeAuto || mode_ == kModeRead ||
mode_ == kModeWrite || mode_ == kModePassthrough,
errors::InvalidArgument("mode must be either '", kModeAuto,
"', '", kModeRead, "', '", kModeWrite,
"', or '", kModePassthrough, "'."));
}
protected:
......@@ -417,15 +458,16 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
<< dump_status.ToString();
}
LOG(INFO) << "Graph def serialized to hash: " << hash;
std::string graph_hash =
strings::StrCat(strings::Hex(hash, strings::kZeroPad16));
LOG(INFO) << "Graph def serialized to hash: " << graph_hash;
*output = new Dataset(
ctx, input, path,
strings::StrCat(strings::Hex(hash, strings::kZeroPad16)),
reader_path_prefix_, writer_path_prefix_, compression_,
shard_size_bytes_, pending_snapshot_expiry_seconds_,
num_reader_threads_, reader_buffer_size_, num_writer_threads_,
writer_buffer_size_, shuffle_on_read_, seed_, seed2_);
*output = new Dataset(ctx, input, path, graph_hash, reader_path_prefix_,
writer_path_prefix_, compression_, shard_size_bytes_,
pending_snapshot_expiry_seconds_, num_reader_threads_,
reader_buffer_size_, num_writer_threads_,
writer_buffer_size_, shuffle_on_read_, seed_, seed2_,
mode_, snapshot_name_);
}
private:
......@@ -438,7 +480,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
const uint64 pending_snapshot_expiry_seconds,
const uint64 num_reader_threads, const uint64 reader_buffer_size,
const uint64 num_writer_threads, const uint64 writer_buffer_size,
const bool shuffle_on_read, const uint64 seed, const uint64 seed2)
const bool shuffle_on_read, const uint64 seed, const uint64 seed2,
const std::string& mode, const std::string& snapshot_name)
: DatasetBase(DatasetContext(ctx)),
input_(input),
dir_(path),
......@@ -454,7 +497,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
writer_buffer_size_(writer_buffer_size),
shuffle_on_read_(shuffle_on_read),
seed_(seed),
seed2_(seed2) {
seed2_(seed2),
mode_(mode),
snapshot_name_(snapshot_name) {
input_->Ref();
}
......@@ -529,6 +574,12 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
AttrValue seed2_attr;
b->BuildAttrValue<int64>(seed2_, &seed2_attr);
AttrValue mode_attr;
b->BuildAttrValue(mode_, &mode_attr);
AttrValue snapshot_name_attr;
b->BuildAttrValue(snapshot_name_, &snapshot_name_attr);
TF_RETURN_IF_ERROR(b->AddDataset(
this,
/*inputs=*/
......@@ -548,7 +599,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
{"writer_buffer_size", writer_buffer_size_attr},
{"shuffle_on_read", shuffle_on_read_attr},
{"seed", seed_attr},
{"seed2", seed2_attr}},
{"seed2", seed2_attr},
{"mode", mode_attr},
{"snapshot_name", snapshot_name_attr}},
output));
return Status::OK();
}
......@@ -558,7 +611,13 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {
hash_dir_ = io::JoinPath(dataset()->dir_, dataset()->graph_hash_);
if (dataset()->snapshot_name_.empty()) {
hash_dir_ = io::JoinPath(dataset()->dir_, dataset()->graph_hash_);
} else {
hash_dir_ = io::JoinPath(
dataset()->dir_,
strings::StrCat("custom-", dataset()->snapshot_name_));
}
}
// We have a somewhat non traditional pattern for iterator initialization
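To make the resulting on-disk layout concrete: by default the snapshot directory is keyed by the serialized graph hash and each run gets a random hex id, whereas a named snapshot uses a fixed "custom-<name>" directory and, as the run-id logic later in this diff shows, a fixed "custom" run id. A small sketch with illustrative values:

```python
import os

snapshot_path = "/tmp/snapshots"       # user-supplied `path` (placeholder)
graph_hash = "0f1e2d3c4b5a6978"        # example hex graph hash
snapshot_name = "my_custom_snapshot"   # example snapshot_name

# Default naming: <path>/<graph_hash>/<random_run_id>/
default_run_dir = os.path.join(snapshot_path, graph_hash, "1f3e")

# Named snapshot: <path>/custom-<snapshot_name>/custom/
named_run_dir = os.path.join(
    snapshot_path, "custom-" + snapshot_name, "custom")
```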
......@@ -581,8 +640,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
experimental::SnapshotMetadataRecord metadata;
Status s = ReadMetadataFile(hash_dir_, &metadata);
TF_RETURN_IF_ERROR(DetermineOpState(
s, metadata, dataset()->pending_snapshot_expiry_seconds_,
&state_));
dataset()->mode_, s, metadata,
dataset()->pending_snapshot_expiry_seconds_, &state_));
TF_RETURN_IF_ERROR(InitializeIterator(ctx, metadata));
}
return iterator_->GetNext(ctx, out_tensors, end_of_sequence);
......@@ -626,18 +685,32 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
IteratorContext* ctx,
const experimental::SnapshotMetadataRecord& metadata)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
std::string run_id = "";
if (!dataset()->snapshot_name_.empty()) {
// We have overridden the snapshot with a custom name, so we don't
// generate random run ids, but just use the same one.
run_id = "custom";
}
switch (state_) {
case WRITER:
iterator_ = absl::make_unique<SnapshotWriterIterator>(
SnapshotWriterIterator::Params{
dataset(), absl::StrCat(prefix(), "WriterImpl")},
hash_dir_);
hash_dir_, run_id);
break;
case READER:
if (run_id.empty() && metadata.run_id().empty()) {
return errors::NotFound(
"Could not find a valid snapshot to read.");
}
if (run_id.empty()) {
run_id = metadata.run_id();
}
iterator_ = absl::make_unique<SnapshotReaderIterator>(
SnapshotReaderIterator::Params{
dataset(), absl::StrCat(prefix(), "ReaderImpl")},
hash_dir_, metadata);
hash_dir_, run_id);
break;
case PASSTHROUGH:
iterator_ = absl::make_unique<SnapshotPassthroughIterator>(
......@@ -653,12 +726,12 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
public:
static constexpr const char* const kParse = "Parse";
explicit SnapshotReaderIterator(
const Params& params, const string& hash_dir,
const experimental::SnapshotMetadataRecord& metadata)
explicit SnapshotReaderIterator(const Params& params,
const string& hash_dir,
const string& run_id)
: DatasetIterator<Dataset>(params),
hash_dir_(hash_dir),
metadata_(metadata) {}
run_id_(run_id) {}
~SnapshotReaderIterator() override {
mutex_lock l(mu_);
......@@ -673,7 +746,6 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
mutex_lock l(mu_);
thread_pool_ = ctx->CreateThreadPool(kSnapshotReaderWorkerPool,
dataset()->num_reader_threads_);
run_id_ = metadata_.run_id();
run_dir_ = io::JoinPath(hash_dir_, run_id_);
// Get all the files in the run_dir.
std::vector<std::string> filenames_str;
......@@ -683,8 +755,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
std::copy(filenames_str.begin(), filenames_str.end(),
filenames_.begin());
if (filenames_.empty()) {
return errors::InvalidArgument("Could not find any files in dir: ",
run_dir_);
return errors::NotFound("Could not find any files in dir: ",
run_dir_);
}
if (dataset()->shuffle_on_read_) {
......@@ -1071,8 +1143,11 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
"ProcessOneElement";
explicit SnapshotWriterIterator(const Params& params,
const string& hash_dir)
: DatasetIterator<Dataset>(params), hash_dir_(hash_dir) {}
const string& hash_dir,
const string& run_id)
: DatasetIterator<Dataset>(params),
hash_dir_(hash_dir),
run_id_(run_id) {}
~SnapshotWriterIterator() override {
mutex_lock l(mu_);
......@@ -1087,8 +1162,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
mutex_lock l(mu_);
thread_pool_ = ctx->CreateThreadPool(kSnapshotWriterWorkerPool,
dataset()->num_writer_threads_);
run_id_ = strings::StrCat(
strings::Hex(random::New64(), strings::kZeroPad4));
if (run_id_.empty()) {
run_id_ = strings::StrCat(
strings::Hex(random::New64(), strings::kZeroPad4));
}
run_dir_ =
io::JoinPath(dataset()->writer_path_prefix_, hash_dir_, run_id_);
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
......@@ -1619,6 +1696,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
const uint64 seed_;
const uint64 seed2_;
const std::string mode_;
const std::string snapshot_name_;
};
const int graph_def_version_;
......@@ -1639,6 +1719,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
int64 seed_;
int64 seed2_;
std::string mode_;
std::string snapshot_name_;
};
REGISTER_KERNEL_BUILDER(Name("SnapshotDataset").Device(DEVICE_CPU),
......
......@@ -823,6 +823,8 @@ REGISTER_OP("SnapshotDataset")
.Attr("shuffle_on_read: bool = false")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.Attr("mode: string = 'auto'")
.Attr("snapshot_name: string = ''")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
// snapshot_path should be a scalar.
......
......@@ -18,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import os
import shutil
import time
from absl.testing import parameterized
......@@ -27,6 +28,7 @@ from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
......@@ -164,6 +166,93 @@ class SnapshotDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
@combinations.generate(test_base.default_test_combinations())
def testSpecifySnapshotNameWriteAndRead(self):
tmpdir = self.makeSnapshotDirectory()
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.apply(
snapshot.snapshot(tmpdir, snapshot_name="my_custom_snapshot"))
dataset = dataset.repeat(10)
self.assertDatasetProduces(dataset, list(range(10)) * 10)
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
self.assertTrue(
os.path.exists(os.path.join(tmpdir, "custom-my_custom_snapshot")))
self.assertTrue(
os.path.exists(
os.path.join(tmpdir, "custom-my_custom_snapshot", "custom")))
@combinations.generate(test_base.default_test_combinations())
def testForcePassthroughMode(self):
tmpdir = self.makeSnapshotDirectory()
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.apply(snapshot.snapshot(tmpdir, mode="passthrough"))
dataset = dataset.repeat(10)
self.assertDatasetProduces(dataset, list(range(10)) * 10)
self.assertSnapshotDirectoryContains(tmpdir, 0, 0, 0)
@combinations.generate(test_base.default_test_combinations())
def testForceWriteMode(self):
tmpdir = self.makeSnapshotDirectory()
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.apply(snapshot.snapshot(tmpdir, mode="write"))
dataset = dataset.repeat(10)
self.assertDatasetProduces(dataset, list(range(10)) * 10)
# We will end up writing 10 different runs.
self.assertSnapshotDirectoryContains(tmpdir, 1, 10, 1)
@combinations.generate(test_base.default_test_combinations())
def testForceReadMode(self):
tmpdir = self.makeSnapshotDirectory()
# We write a copy of the snapshot first.
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.apply(
snapshot.snapshot(
tmpdir, mode="write", snapshot_name="my_custom_snapshot"))
self.assertDatasetProduces(dataset, list(range(10)))
# We move the run to a new name.
shutil.move(
os.path.join(tmpdir, "custom-my_custom_snapshot"),
os.path.join(tmpdir, "custom-my_custom_snapshot_2"))
# Even though the snapshot.metadata is pointing to the old run that no
# longer exists after we moved, we force it to read from the run we specify.
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.apply(
snapshot.snapshot(
tmpdir, mode="read", snapshot_name="my_custom_snapshot_2"))
self.assertDatasetProduces(dataset, list(range(10)))
# We should still have one snapshot and one run.
self.assertSnapshotDirectoryContains(tmpdir, 1, 1, 1)
@combinations.generate(test_base.default_test_combinations())
def testForceReadNonexistentSnapshot(self):
tmpdir = self.makeSnapshotDirectory()
dataset = dataset_ops.Dataset.range(10)
with self.assertRaises(errors.NotFoundError):
dataset = dataset.apply(snapshot.snapshot(tmpdir, mode="read"))
get_next = self.getNext(dataset)
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testForceReadNonexistentNamedSnapshot(self):
tmpdir = self.makeSnapshotDirectory()
dataset = dataset_ops.Dataset.range(10)
with self.assertRaises(errors.NotFoundError):
dataset = dataset.apply(
snapshot.snapshot(
tmpdir, mode="read", snapshot_name="my_nonexistent_snapshot"))
get_next = self.getNext(dataset)
self.evaluate(get_next())
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
......
......@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.compat import compat
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
......@@ -45,7 +46,9 @@ class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
num_writer_threads=None,
writer_buffer_size=None,
shuffle_on_read=None,
seed=None):
seed=None,
mode=None,
snapshot_name=None):
self._compression = compression if compression is not None else ""
self._reader_path_prefix = (
......@@ -67,28 +70,51 @@ class _SnapshotDataset(dataset_ops.UnaryUnchangedStructureDataset):
writer_buffer_size if writer_buffer_size is not None else -1)
self._shuffle_on_read = (
shuffle_on_read if shuffle_on_read is not None else False)
self._mode = (mode if mode is not None else "auto")
self._snapshot_name = (snapshot_name if snapshot_name is not None else "")
self._seed, self._seed2 = random_seed.get_seed(seed)
self._input_dataset = input_dataset
self._path = ops.convert_to_tensor(path, dtype=dtypes.string, name="path")
variant_tensor = ged_ops.snapshot_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
path=self._path,
compression=self._compression,
reader_path_prefix=self._reader_path_prefix,
writer_path_prefix=self._writer_path_prefix,
shard_size_bytes=self._shard_size_bytes,
pending_snapshot_expiry_seconds=self._pending_snapshot_expiry_seconds,
num_reader_threads=self._num_reader_threads,
reader_buffer_size=self._reader_buffer_size,
num_writer_threads=self._num_writer_threads,
writer_buffer_size=self._writer_buffer_size,
shuffle_on_read=self._shuffle_on_read,
seed=self._seed,
seed2=self._seed2,
**self._flat_structure)
if compat.forward_compatible(2020, 1, 10) or mode or snapshot_name:
variant_tensor = ged_ops.snapshot_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
path=self._path,
compression=self._compression,
reader_path_prefix=self._reader_path_prefix,
writer_path_prefix=self._writer_path_prefix,
shard_size_bytes=self._shard_size_bytes,
pending_snapshot_expiry_seconds=self._pending_snapshot_expiry_seconds,
num_reader_threads=self._num_reader_threads,
reader_buffer_size=self._reader_buffer_size,
num_writer_threads=self._num_writer_threads,
writer_buffer_size=self._writer_buffer_size,
shuffle_on_read=self._shuffle_on_read,
seed=self._seed,
seed2=self._seed2,
mode=self._mode,
snapshot_name=self._snapshot_name,
**self._flat_structure)
else:
variant_tensor = ged_ops.snapshot_dataset(
self._input_dataset._variant_tensor, # pylint: disable=protected-access
path=self._path,
compression=self._compression,
reader_path_prefix=self._reader_path_prefix,
writer_path_prefix=self._writer_path_prefix,
shard_size_bytes=self._shard_size_bytes,
pending_snapshot_expiry_seconds=self._pending_snapshot_expiry_seconds,
num_reader_threads=self._num_reader_threads,
reader_buffer_size=self._reader_buffer_size,
num_writer_threads=self._num_writer_threads,
writer_buffer_size=self._writer_buffer_size,
shuffle_on_read=self._shuffle_on_read,
seed=self._seed,
seed2=self._seed2,
**self._flat_structure)
super(_SnapshotDataset, self).__init__(input_dataset, variant_tensor)
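The two near-identical branches above follow TensorFlow's forward-compatibility pattern: the new `mode` and `snapshot_name` attrs are only passed to the op once the compatibility window opens (2020-01-10) or when the caller explicitly set either of them, so older runtimes that do not know the attrs keep working. A condensed, hypothetical sketch of that gating (the helper name is made up for illustration):

```python
from tensorflow.python.compat import compat

def _maybe_pass_new_attrs(op_fn, common_kwargs, mode, snapshot_name):
  """Hypothetical helper illustrating the gating used in __init__ above."""
  if compat.forward_compatible(2020, 1, 10) or mode or snapshot_name:
    # New enough (or explicitly requested): include the new attributes.
    return op_fn(mode=mode, snapshot_name=snapshot_name, **common_kwargs)
  # Otherwise call the op with its old attribute set only.
  return op_fn(**common_kwargs)
```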
......@@ -103,7 +129,9 @@ def snapshot(path,
num_writer_threads=None,
writer_buffer_size=None,
shuffle_on_read=None,
seed=None):
seed=None,
mode=None,
snapshot_name=None):
"""Writes to/reads from a snapshot of a dataset.
This function attempts to determine whether a valid snapshot exists at the
......@@ -122,39 +150,56 @@ def snapshot(path,
Defaults to None.
shard_size_bytes: The size of each shard to be written by the snapshot
dataset op. Defaults to 10 GiB.
pending_snapshot_expiry_seconds: How long to wait (in seconds) before
the snapshot op considers a previously unfinished snapshot to be stale.
pending_snapshot_expiry_seconds: How long to wait (in seconds) before the
snapshot op considers a previously unfinished snapshot to be stale.
num_reader_threads: Number of threads to parallelize reading from snapshot.
Especially useful if compression is turned on since the decompression
operation tends to be intensive. Defaults to 1. If > 1, then this might
introduce non-determinism i.e. the order in which the elements are
read from the snapshot are different from the order they're written.
introduce non-determinism i.e. the order in which the elements are read
from the snapshot are different from the order they're written.
reader_buffer_size: Maximum number of elements we can prefetch reading from
the snapshot. Defaults to 1. Increasing this might improve performance
but will increase memory consumption.
the snapshot. Defaults to 1. Increasing this might improve performance but
will increase memory consumption.
num_writer_threads: Number of threads to parallelize writing from snapshot.
We'll open up `num_writer_threads` files and write to them in parallel.
Especially useful if compression is turned on since the compression
operation tends to be intensive. Defaults to 1. If > 1, then this might
introduce non-determinism i.e. the order in which the elements are
read from the upstream iterator are different from the order they're
written.
introduce non-determinism i.e. the order in which the elements are read
from the upstream iterator are different from the order they're written.
writer_buffer_size: Maximum number of pipeline elements to fill up the
buffer before writing them out using `num_writer_threads`.
shuffle_on_read: If this is True, then the order in which examples are
produced when reading from a snapshot will be random. Defaults to False.
seed: If seed is set, the random number generator is seeded by the given
seed. Otherwise, it is seeded by a random seed.
mode: The mode at which snapshot should operate. Valid options are "auto",
"read", "write", and "passthrough". The default mode is "auto", where the
snapshot op will automatically determine what mode to operate in.
snapshot_name: If set, use the supplied string as a named snapshot name
instead of introspecting the data pipeline and automatically generating a
unique identifier for the snapshot.
Returns:
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
def _apply_fn(dataset):
return _SnapshotDataset(dataset, path, compression, reader_path_prefix,
writer_path_prefix, shard_size_bytes,
pending_snapshot_expiry_seconds, num_reader_threads,
reader_buffer_size, num_writer_threads,
writer_buffer_size, shuffle_on_read, seed)
return _SnapshotDataset(
input_dataset=dataset,
path=path,
compression=compression,
reader_path_prefix=reader_path_prefix,
writer_path_prefix=writer_path_prefix,
shard_size_bytes=shard_size_bytes,
pending_snapshot_expiry_seconds=pending_snapshot_expiry_seconds,
num_reader_threads=num_reader_threads,
reader_buffer_size=reader_buffer_size,
num_writer_threads=num_writer_threads,
writer_buffer_size=writer_buffer_size,
shuffle_on_read=shuffle_on_read,
seed=seed,
mode=mode,
snapshot_name=snapshot_name)
return _apply_fn
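As a complementary example to the `mode` argument documented above, "passthrough" disables snapshotting entirely, which the testForcePassthroughMode case earlier in this diff exercises; a minimal sketch (the path is a placeholder):

```python
from tensorflow.python.data.experimental.ops import snapshot
from tensorflow.python.data.ops import dataset_ops

# Passthrough mode: elements flow straight through the transformation and
# nothing is written to or read from the snapshot directory.
dataset = dataset_ops.Dataset.range(10)
dataset = dataset.apply(
    snapshot.snapshot("/tmp/snapshots", mode="passthrough"))
```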
......@@ -3834,7 +3834,7 @@ tf_module {
}
member_method {
name: "SnapshotDataset"
argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'compression\', \'reader_path_prefix\', \'writer_path_prefix\', \'shard_size_bytes\', \'pending_snapshot_expiry_seconds\', \'num_reader_threads\', \'reader_buffer_size\', \'num_writer_threads\', \'writer_buffer_size\', \'shuffle_on_read\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'10737418240\', \'86400\', \'1\', \'1\', \'1\', \'1\', \'False\', \'0\', \'0\', \'None\'], "
argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'compression\', \'reader_path_prefix\', \'writer_path_prefix\', \'shard_size_bytes\', \'pending_snapshot_expiry_seconds\', \'num_reader_threads\', \'reader_buffer_size\', \'num_writer_threads\', \'writer_buffer_size\', \'shuffle_on_read\', \'seed\', \'seed2\', \'mode\', \'snapshot_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'10737418240\', \'86400\', \'1\', \'1\', \'1\', \'1\', \'False\', \'0\', \'0\', \'auto\', \'\', \'None\'], "
}
member_method {
name: "Softmax"
......
......@@ -3834,7 +3834,7 @@ tf_module {
}
member_method {
name: "SnapshotDataset"
argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'compression\', \'reader_path_prefix\', \'writer_path_prefix\', \'shard_size_bytes\', \'pending_snapshot_expiry_seconds\', \'num_reader_threads\', \'reader_buffer_size\', \'num_writer_threads\', \'writer_buffer_size\', \'shuffle_on_read\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'10737418240\', \'86400\', \'1\', \'1\', \'1\', \'1\', \'False\', \'0\', \'0\', \'None\'], "
argspec: "args=[\'input_dataset\', \'path\', \'output_types\', \'output_shapes\', \'compression\', \'reader_path_prefix\', \'writer_path_prefix\', \'shard_size_bytes\', \'pending_snapshot_expiry_seconds\', \'num_reader_threads\', \'reader_buffer_size\', \'num_writer_threads\', \'writer_buffer_size\', \'shuffle_on_read\', \'seed\', \'seed2\', \'mode\', \'snapshot_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'10737418240\', \'86400\', \'1\', \'1\', \'1\', \'1\', \'False\', \'0\', \'0\', \'auto\', \'\', \'None\'], "
}
member_method {
name: "Softmax"
......