Commit e13cd053 authored by Martin Wicke, committed by GitHub

Merge pull request #2891 from mrry/r0.9-cherrypick

Cherry-picking stability and doc fixes for r0.9
......@@ -569,11 +569,15 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
bool success =
cm->RegisterCallback(token, [&calls]() { calls.StartCancel(); });
if (!success) {
return errors::Cancelled("Step was cancelled");
calls.StartCancel();
}
calls.Wait();
cm->DeregisterCallback(token);
call_opts->ClearCancelCallback();
if (success) {
cm->DeregisterCallback(token);
} else {
return errors::Cancelled("Step was cancelled");
}
// Collects fetches.
Status status = calls.status();
......
......@@ -24,17 +24,6 @@ limitations under the License.
#include "grpc++/impl/codegen/service_type.h"
#include "grpc++/impl/codegen/sync_stream.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h"
// Contains potentially large GraphDef.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::CreateSessionRequest);
// Contains potentially large GraphDef.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::ExtendSessionRequest);
// Contains potentially large TensorProto.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RunStepRequest);
// Contains potentially large StepStats, TensorProto.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RunStepResponse);
namespace tensorflow {
namespace grpc {
......
......@@ -25,8 +25,18 @@ limitations under the License.
#include "grpc++/impl/codegen/stub_options.h"
#include "grpc++/impl/codegen/sync_stream.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h"
#include "tensorflow/core/protobuf/master.pb.h"
// Contains potentially large GraphDef.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::CreateSessionRequest);
// Contains potentially large GraphDef.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::ExtendSessionRequest);
// Contains potentially large TensorProto.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RunStepRequest);
// Contains potentially large StepStats, TensorProto.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RunStepResponse);
namespace grpc {
class CompletionQueue;
class Channel;
......
......@@ -169,6 +169,9 @@ class GrpcRemoteWorker : public WorkerInterface {
AsyncMethod<RequestMessage, ResponseMessage> async_method,
StatusCallback done, CallOptions* call_opts = nullptr) {
::grpc::ClientContext* context = new ::grpc::ClientContext;
// The initialization and recovery protocols rely on blocking
// until we get a response.
context->set_fail_fast(false);
if (call_opts) {
call_opts->SetCancelCallback([context]() { context->TryCancel(); });
}
......
......@@ -152,7 +152,10 @@ class UnlimitedSizeProtoSerializationTraits {
bool* own_buffer) {
*own_buffer = true;
int byte_size = msg.ByteSize();
if (byte_size <= tensorflow_helper::kGrpcBufferWriterMaxBufferLength) {
if (byte_size < 0) {
return Status(StatusCode::INTERNAL, "Message length was negative");
} else if (byte_size <=
tensorflow_helper::kGrpcBufferWriterMaxBufferLength) {
gpr_slice slice = g_core_codegen_interface->gpr_slice_malloc(byte_size);
GPR_CODEGEN_ASSERT(
GPR_SLICE_END_PTR(slice) ==
......
......@@ -24,17 +24,6 @@ limitations under the License.
#include "grpc++/impl/codegen/service_type.h"
#include "grpc++/impl/codegen/sync_stream.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h"
// Contains potentially large GraphDef.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RegisterGraphRequest);
// Contains potentially large TensorProto.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RunGraphRequest);
// Contains potentially large StepStats, TensorProto.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RunGraphResponse);
// Contains potentially large TensorProto.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RecvTensorResponse);
namespace tensorflow {
namespace grpc {
......
......@@ -25,8 +25,18 @@ limitations under the License.
#include "grpc++/impl/codegen/stub_options.h"
#include "grpc++/impl/codegen/sync_stream.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h"
#include "tensorflow/core/protobuf/worker.pb.h"
// Contains potentially large GraphDef.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RegisterGraphRequest);
// Contains potentially large TensorProto.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RunGraphRequest);
// Contains potentially large StepStats, TensorProto.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RunGraphResponse);
// Contains potentially large TensorProto.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE(tensorflow::RecvTensorResponse);
namespace grpc {
class CompletionQueue;
class Channel;
......
......@@ -126,6 +126,71 @@ Status TensorSliceWriter::Finish() {
return s;
}
/* static */
size_t TensorSliceWriter::MaxBytesPerElement(DataType dt) {
  // Conservative upper bound on the serialized size of one element of the
  // given dtype inside a TensorProto. Varint-encoded integer types can take
  // up to 10 bytes on the wire regardless of their in-memory width.
  switch (dt) {
    case DT_BOOL:
      return 1;
    // One payload byte plus tag/varint overhead.
    case DT_UINT8:
    case DT_QUINT8:
      return 2;
    // Two payload bytes (worst-case varint) plus overhead.
    case DT_QUINT16:
    case DT_UINT16:
    case DT_HALF:
      return 3;
    case DT_FLOAT:
      return 4;
    case DT_DOUBLE:
    case DT_COMPLEX64:
      return 8;
    // Signed/quantized integer types are varint-encoded; a negative value
    // sign-extends to the maximal 10-byte encoding.
    case DT_INT32:
    case DT_INT16:
    case DT_INT8:
    case DT_INT64:
    case DT_QINT8:
    case DT_QINT32:
    case DT_QINT16:
      return 10;
    case DT_COMPLEX128:
      return 16;
    case DT_INVALID:
    case DT_STRING:
    case DT_BFLOAT16:
    default:
      CHECK(false) << "MaxBytesPerElement not implemented for dtype: " << dt;
  }
  return 0;
}
// Specialization for string tensors. Unlike the fixed-width overload, the
// size bound must include the actual byte length of every string; each
// element additionally costs at most MaxBytesPerElement(DT_INT32) bytes of
// tag/length varint overhead.
template <>
Status TensorSliceWriter::SaveData(const string* data, int num_elements,
                                   SavedSlice* ss) {
  size_t projected_bytes = ss->ByteSize() + kTensorProtoHeaderBytes +
                           (num_elements * MaxBytesPerElement(DT_INT32));
  for (int idx = 0; idx < num_elements; ++idx) {
    projected_bytes += data[idx].size();
  }
  if (projected_bytes > kMaxMessageBytes) {
    return errors::InvalidArgument(
        "Tensor slice is too large to serialize (conservative estimate: ",
        projected_bytes, " bytes)");
  }
  Fill(data, num_elements, ss->mutable_data());
  DCHECK_GE(ss->ByteSize(), 0);
  DCHECK_LE(ss->ByteSize(), projected_bytes);
  return Status::OK();
}
} // namespace checkpoint
} // namespace tensorflow
......@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
......@@ -61,11 +62,24 @@ class TensorSliceWriter {
const TensorSlice& slice, const T* data);
Status Finish();
private:
// Allocate "num_elements" elements in "ss" and save the data in "data"
// there.
template <typename T>
static void SaveData(const T* data, int num_elements, SavedSlice* ss);
static Status SaveData(const T* data, int num_elements, SavedSlice* ss);
static size_t MaxBytesPerElement(DataType dt);
private:
static const size_t kMaxMessageBytes = 1LL << 31;
// Filling in the TensorProto in a SavedSlice will add the following
// header bytes, in addition to the data:
// - 1 byte: TensorProto tag and wire format
// - <= 5 bytes: TensorProto length
// - 1 byte: Repeated *_val tag and wire format
// - <= 5 bytes: *_val length
// However, we add 1KB of slack, to be conservative and guard
// against other additions to the TensorProto.
static const size_t kTensorProtoHeaderBytes = 1 << 10;
const string filename_;
const CreateBuilderFunction create_builder_;
......@@ -132,7 +146,7 @@ Status TensorSliceWriter::Add(const string& name, const TensorShape& shape,
TensorShape saved_shape(ssm->shape());
TensorShape sliced_shape;
TF_RETURN_IF_ERROR(slice.SliceTensorShape(saved_shape, &sliced_shape));
SaveData(data, sliced_shape.num_elements(), ss);
TF_RETURN_IF_ERROR(SaveData(data, sliced_shape.num_elements(), ss));
string key = EncodeTensorNameSlice(name, slice);
// TODO(yangke): consider doing a two-pass thing where the first pass just
// list the tensor slices we want to save and then another pass to actually
......@@ -148,11 +162,26 @@ Status TensorSliceWriter::Add(const string& name, const TensorShape& shape,
}
template <typename T>
void TensorSliceWriter::SaveData(const T* data, int num_elements,
SavedSlice* ss) {
Status TensorSliceWriter::SaveData(const T* data, int num_elements,
SavedSlice* ss) {
size_t size_bound =
ss->ByteSize() + kTensorProtoHeaderBytes +
(MaxBytesPerElement(DataTypeToEnum<T>::value) * num_elements);
if (size_bound > kMaxMessageBytes) {
return errors::InvalidArgument(
"Tensor slice is too large to serialize (conservative estimate: ",
size_bound, " bytes)");
}
Fill(data, num_elements, ss->mutable_data());
DCHECK_GE(ss->ByteSize(), 0);
DCHECK_LE(ss->ByteSize(), size_bound);
return Status::OK();
}
template <>
Status TensorSliceWriter::SaveData(const string* data, int num_elements,
SavedSlice* ss);
// Create a table builder that will write to "filename" in
// tensorflow::io::Table format. If successful, return OK
// and set "*builder" to the allocated builder. Otherwise, return a
......
......@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/core/util/tensor_slice_writer.h"
#include <array>
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/logging.h"
......@@ -263,6 +265,85 @@ void TensorSliceWriteTestHelper::CheckEntries(const string& fname) {
}
}
// Empirically measures the checkpoint bytes-per-element cost of `value` by
// serializing a 1-element buffer and a 1001-element buffer and dividing the
// size difference by the element-count difference (1000).
template <typename DT>
size_t BytesPerElementHelper(DT value) {
  SavedSlice ss;

  std::array<DT, 1> small_buf;
  small_buf.fill(value);
  TensorSliceWriter::SaveData(small_buf.data(), small_buf.size(), &ss);
  const int small_bytes = ss.ByteSize();

  std::array<DT, 1001> big_buf;
  big_buf.fill(value);
  TensorSliceWriter::SaveData(big_buf.data(), big_buf.size(), &ss);
  const int big_bytes = ss.ByteSize();

  return (big_bytes - small_bytes) / (big_buf.size() - small_buf.size());
}
// Checks that MaxBytesPerElement() is a tight per-element bound: for each
// dtype, the measured bytes-per-element of a worst-case value (e.g. -1,
// whose varint encoding is maximal for signed types; the max value for
// unsigned types) must equal the declared maximum.
TEST(TensorSliceWriteTest, CheckpointSize) {
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_BOOL),
            BytesPerElementHelper<bool>(false));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_BOOL),
            BytesPerElementHelper<bool>(true));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_FLOAT),
            BytesPerElementHelper<float>(-1.0));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_DOUBLE),
            BytesPerElementHelper<double>(-1.0));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_COMPLEX64),
            BytesPerElementHelper<complex64>(-1.0));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_COMPLEX128),
            BytesPerElementHelper<complex128>(-1.0));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_INT32),
            BytesPerElementHelper<int32>(-1));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_INT64),
            BytesPerElementHelper<int64>(-1));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_UINT16),
            BytesPerElementHelper<uint16>(std::numeric_limits<uint16>::max()));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_UINT8),
            BytesPerElementHelper<uint8>(std::numeric_limits<uint8>::max()));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_INT8),
            BytesPerElementHelper<int8>(-1));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_INT16),
            BytesPerElementHelper<int16>(-1));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_QINT8),
            BytesPerElementHelper<qint8>(-1));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_QUINT8),
            BytesPerElementHelper<quint8>(std::numeric_limits<uint8>::max()));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_QINT32),
            BytesPerElementHelper<qint32>(-1));
  EXPECT_EQ(TensorSliceWriter::MaxBytesPerElement(DT_HALF),
            BytesPerElementHelper<Eigen::half>(Eigen::half(-1.0)));
}
// Checks that TensorSliceWriter::Add() rejects slices whose conservative
// serialized-size estimate exceeds kMaxMessageBytes, returning
// INVALID_ARGUMENT with the expected message for both fixed-width and
// string dtypes.
TEST(TensorSliceWriteTest, SizeErrors) {
  const string filename = io::JoinPath(testing::TmpDir(), "checkpoint");
  TensorSliceWriter writer(filename, CreateTableTensorSliceBuilder);

  // Add a 300MB int8 tensor slice, which will fail because it expands to 3GB
  // under the 10-byte-per-element varint bound.
  {
    TensorShape shape({300, 1000000});
    TensorSlice slice = TensorSlice::ParseOrDie("-:-");
    const std::vector<int8> data(300000000, -1);
    Status s = writer.Add("test1", shape, slice, data.data());
    EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
    EXPECT_TRUE(StringPiece(s.error_message())
                    .contains("Tensor slice is too large to serialize"));
  }

  // Add a large string tensor slice, which will fail: 100M copies of a
  // 14-byte string exceed the limit on payload bytes alone.
  {
    TensorShape shape({100, 1000000});
    TensorSlice slice = TensorSlice::ParseOrDie("-:-");
    const std::vector<string> data(100000000, "rhubarbrhubarb");
    Status s = writer.Add("test2", shape, slice, data.data());
    EXPECT_EQ(s.code(), error::INVALID_ARGUMENT);
    EXPECT_TRUE(StringPiece(s.error_message())
                    .contains("Tensor slice is too large to serialize"));
  }
}
} // namespace checkpoint
} // namespace tensorflow
......@@ -276,6 +276,15 @@ class QueueBase(object):
If the queue is full when this operation executes, it will block
until the element has been enqueued.
At runtime, this operation may raise an error if the queue is
[closed](#QueueBase.close) before or during its execution. If the
queue is closed before this operation runs,
`tf.errors.AbortedError` will be raised. If this operation is
blocked, and either (i) the queue is closed by a close operation
with `cancel_pending_enqueues=True`, or (ii) the session is
[closed](../../api_docs/python/client.md#Session.close),
`tf.errors.CancelledError` will be raised.
Args:
vals: A tensor, a list or tuple of tensors, or a dictionary containing
the values to enqueue.
......@@ -305,6 +314,15 @@ class QueueBase(object):
If the queue is full when this operation executes, it will block
until all of the elements have been enqueued.
At runtime, this operation may raise an error if the queue is
[closed](#QueueBase.close) before or during its execution. If the
queue is closed before this operation runs,
`tf.errors.AbortedError` will be raised. If this operation is
blocked, and either (i) the queue is closed by a close operation
with `cancel_pending_enqueues=True`, or (ii) the session is
[closed](../../api_docs/python/client.md#Session.close),
`tf.errors.CancelledError` will be raised.
Args:
vals: A tensor, a list or tuple of tensors, or a dictionary
from which the queue elements are taken.
......@@ -357,6 +375,14 @@ class QueueBase(object):
If the queue is empty when this operation executes, it will block
until there is an element to dequeue.
At runtime, this operation may raise an error if the queue is
[closed](#QueueBase.close) before or during its execution. If the
queue is closed, the queue is empty, and there are no pending
enqueue operations that can fulfil this request,
`tf.errors.OutOfRangeError` will be raised. If the session is
[closed](../../api_docs/python/client.md#Session.close),
`tf.errors.CancelledError` will be raised.
Args:
name: A name for the operation (optional).
......@@ -386,6 +412,14 @@ class QueueBase(object):
If the queue is closed and there are less than `n` elements left, then an
`OutOfRange` exception is raised.
At runtime, this operation may raise an error if the queue is
[closed](#QueueBase.close) before or during its execution. If the
queue is closed, the queue contains fewer than `n` elements, and
there are no pending enqueue operations that can fulfil this
request, `tf.errors.OutOfRangeError` will be raised. If the
session is [closed](../../api_docs/python/client.md#Session.close),
`tf.errors.CancelledError` will be raised.
Args:
n: A scalar `Tensor` containing the number of elements to dequeue.
name: A name for the operation (optional).
......@@ -412,18 +446,20 @@ class QueueBase(object):
"""Dequeues and concatenates `n` elements from this queue.
**Note** This operation is not supported by all queues. If a queue does not
support DequeueUpTo, then an Unimplemented exception is raised.
This operation concatenates queue-element component tensors along the
0th dimension to make a single component tensor. All of the components
in the dequeued tuple will have size `n` in the 0th dimension.
If the queue is closed and there are more than `0` but less than `n`
elements remaining, then instead of raising an `OutOfRange` exception like
`dequeue_many`, the remaining elements are returned immediately.
If the queue is closed and there are `0` elements left in the queue, then
an `OutOfRange` exception is raised just like in `dequeue_many`.
Otherwise the behavior is identical to `dequeue_many`:
support DequeueUpTo, then a `tf.errors.UnimplementedError` is raised.
This operation concatenates queue-element component tensors along
the 0th dimension to make a single component tensor. If the queue
has not been closed, all of the components in the dequeued tuple
will have size `n` in the 0th dimension.
If the queue is closed and there are more than `0` but fewer than
`n` elements remaining, then instead of raising a
`tf.errors.OutOfRangeError` like [`dequeue_many`](#QueueBase.dequeue_many),
the remaining elements are returned immediately. If the queue is
closed and there are `0` elements left in the queue, then a
`tf.errors.OutOfRangeError` is raised just like in `dequeue_many`.
Otherwise the behavior is identical to `dequeue_many`.
Args:
n: A scalar `Tensor` containing the number of elements to dequeue.
......
......@@ -287,6 +287,33 @@ class SaverTest(tf.test.TestCase):
expected_save_path = "%s-%d" % (save_path, global_step_int)
self.assertEqual(expected_save_path, val)
def testLargeVariable(self):
  """Saving a variable whose slice serializes past 2GB must raise.

  Exercises the TensorSliceWriter size check: both a >2GB int8 tensor and
  an exactly-2GB bool tensor (which overflows once proto header metadata
  is added) must fail with InvalidArgumentError.
  """
  save_path = os.path.join(self.get_temp_dir(), "large_variable")
  with tf.Session("", graph=tf.Graph()) as sess:
    # Declare a variable larger than 2GB (300M int8 elements; the
    # conservative serialized-size estimate is larger still).
    with tf.device("/cpu:0"):
      var = tf.Variable(tf.constant(-1, shape=[300, 1000000], dtype=tf.int8))
    save = tf.train.Saver({var.op.name: var})
    var.initializer.run()
    with self.assertRaisesRegexp(
        tf.errors.InvalidArgumentError,
        "Tensor slice is too large to serialize"):
      save.save(sess, save_path)
  with tf.Session("", graph=tf.Graph()) as sess:
    # Declare a variable that is exactly 2GB. This should fail,
    # because a serialized checkpoint includes other header
    # metadata.
    with tf.device("/cpu:0"):
      var = tf.Variable(
          tf.constant(False, shape=[2, 1024, 1024, 1024], dtype=tf.bool))
    save = tf.train.Saver({var.op.name: var})
    var.initializer.run()
    with self.assertRaisesRegexp(
        tf.errors.InvalidArgumentError,
        "Tensor slice is too large to serialize"):
      save.save(sess, save_path)
class SaveRestoreShardedTest(tf.test.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册