Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
wux_labs
Tensorflow
提交
e13cd053
T
Tensorflow
项目概览
wux_labs
/
Tensorflow
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
Tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
e13cd053
编写于
6月 15, 2016
作者:
M
Martin Wicke
提交者:
GitHub
6月 15, 2016
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2891 from mrry/r0.9-cherrypick
Cherry-picking stability and doc fixes for r0.9
上级
ea00daa3
0d0f37ae
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
288 addition
and
42 deletion
+288
-42
tensorflow/core/distributed_runtime/master_session.cc
tensorflow/core/distributed_runtime/master_session.cc
+6
-2
tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
.../core/distributed_runtime/rpc/grpc_master_service_impl.cc
+0
-11
tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
...w/core/distributed_runtime/rpc/grpc_master_service_impl.h
+10
-0
tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
...orflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
+3
-0
tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h
.../core/distributed_runtime/rpc/grpc_serialization_traits.h
+4
-1
tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc
.../core/distributed_runtime/rpc/grpc_worker_service_impl.cc
+0
-11
tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
...w/core/distributed_runtime/rpc/grpc_worker_service_impl.h
+10
-0
tensorflow/core/util/tensor_slice_writer.cc
tensorflow/core/util/tensor_slice_writer.cc
+65
-0
tensorflow/core/util/tensor_slice_writer.h
tensorflow/core/util/tensor_slice_writer.h
+34
-5
tensorflow/core/util/tensor_slice_writer_test.cc
tensorflow/core/util/tensor_slice_writer_test.cc
+81
-0
tensorflow/python/ops/data_flow_ops.py
tensorflow/python/ops/data_flow_ops.py
+48
-12
tensorflow/python/training/saver_test.py
tensorflow/python/training/saver_test.py
+27
-0
未找到文件。
tensorflow/core/distributed_runtime/master_session.cc
浏览文件 @
e13cd053
...
...
@@ -569,11 +569,15 @@ Status MasterSession::ReffedClientGraph::RunPartitions(
bool
success
=
cm
->
RegisterCallback
(
token
,
[
&
calls
]()
{
calls
.
StartCancel
();
});
if
(
!
success
)
{
return
errors
::
Cancelled
(
"Step was cancelled"
);
calls
.
StartCancel
(
);
}
calls
.
Wait
();
cm
->
DeregisterCallback
(
token
);
call_opts
->
ClearCancelCallback
();
if
(
success
)
{
cm
->
DeregisterCallback
(
token
);
}
else
{
return
errors
::
Cancelled
(
"Step was cancelled"
);
}
// Collects fetches.
Status
status
=
calls
.
status
();
...
...
tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.cc
浏览文件 @
e13cd053
...
...
@@ -24,17 +24,6 @@ limitations under the License.
#include "grpc++/impl/codegen/service_type.h"
#include "grpc++/impl/codegen/sync_stream.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h"
// Contains potentially large GraphDef.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE
(
tensorflow
::
CreateSessionRequest
);
// Contains potentially large GraphDef.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE
(
tensorflow
::
ExtendSessionRequest
);
// Contains potentially large TensorProto.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE
(
tensorflow
::
RunStepRequest
);
// Contains potentially large StepStats, TensorProto.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE
(
tensorflow
::
RunStepResponse
);
namespace
tensorflow
{
namespace
grpc
{
...
...
tensorflow/core/distributed_runtime/rpc/grpc_master_service_impl.h
浏览文件 @
e13cd053
...
...
@@ -25,8 +25,18 @@ limitations under the License.
#include "grpc++/impl/codegen/stub_options.h"
#include "grpc++/impl/codegen/sync_stream.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h"
#include "tensorflow/core/protobuf/master.pb.h"
// Contains potentially large GraphDef.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE
(
tensorflow
::
CreateSessionRequest
);
// Contains potentially large GraphDef.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE
(
tensorflow
::
ExtendSessionRequest
);
// Contains potentially large TensorProto.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE
(
tensorflow
::
RunStepRequest
);
// Contains potentially large StepStats, TensorProto.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE
(
tensorflow
::
RunStepResponse
);
namespace
grpc
{
class
CompletionQueue
;
class
Channel
;
...
...
tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc
浏览文件 @
e13cd053
...
...
@@ -169,6 +169,9 @@ class GrpcRemoteWorker : public WorkerInterface {
AsyncMethod
<
RequestMessage
,
ResponseMessage
>
async_method
,
StatusCallback
done
,
CallOptions
*
call_opts
=
nullptr
)
{
::
grpc
::
ClientContext
*
context
=
new
::
grpc
::
ClientContext
;
// The initialization and recovery protocols rely on blocking
// until we get a response.
context
->
set_fail_fast
(
false
);
if
(
call_opts
)
{
call_opts
->
SetCancelCallback
([
context
]()
{
context
->
TryCancel
();
});
}
...
...
tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h
浏览文件 @
e13cd053
...
...
@@ -152,7 +152,10 @@ class UnlimitedSizeProtoSerializationTraits {
bool
*
own_buffer
)
{
*
own_buffer
=
true
;
int
byte_size
=
msg
.
ByteSize
();
if
(
byte_size
<=
tensorflow_helper
::
kGrpcBufferWriterMaxBufferLength
)
{
if
(
byte_size
<
0
)
{
return
Status
(
StatusCode
::
INTERNAL
,
"Message length was negative"
);
}
else
if
(
byte_size
<=
tensorflow_helper
::
kGrpcBufferWriterMaxBufferLength
)
{
gpr_slice
slice
=
g_core_codegen_interface
->
gpr_slice_malloc
(
byte_size
);
GPR_CODEGEN_ASSERT
(
GPR_SLICE_END_PTR
(
slice
)
==
...
...
tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc
浏览文件 @
e13cd053
...
...
@@ -24,17 +24,6 @@ limitations under the License.
#include "grpc++/impl/codegen/service_type.h"
#include "grpc++/impl/codegen/sync_stream.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h"
// Contains potentially large GraphDef.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE
(
tensorflow
::
RegisterGraphRequest
);
// Contains potentially large TensorProto.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE
(
tensorflow
::
RunGraphRequest
);
// Contains potentially large StepStats, TensorProto.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE
(
tensorflow
::
RunGraphResponse
);
// Contains potentially large TensorProto.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE
(
tensorflow
::
RecvTensorResponse
);
namespace
tensorflow
{
namespace
grpc
{
...
...
tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h
浏览文件 @
e13cd053
...
...
@@ -25,8 +25,18 @@ limitations under the License.
#include "grpc++/impl/codegen/stub_options.h"
#include "grpc++/impl/codegen/sync_stream.h"
#include "tensorflow/core/distributed_runtime/rpc/grpc_serialization_traits.h"
#include "tensorflow/core/protobuf/worker.pb.h"
// Contains potentially large GraphDef.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE
(
tensorflow
::
RegisterGraphRequest
);
// Contains potentially large TensorProto.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE
(
tensorflow
::
RunGraphRequest
);
// Contains potentially large StepStats, TensorProto.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE
(
tensorflow
::
RunGraphResponse
);
// Contains potentially large TensorProto.
TF_GRPC_ALLOW_UNLIMITED_MESSAGE_SIZE
(
tensorflow
::
RecvTensorResponse
);
namespace
grpc
{
class
CompletionQueue
;
class
Channel
;
...
...
tensorflow/core/util/tensor_slice_writer.cc
浏览文件 @
e13cd053
...
...
@@ -126,6 +126,71 @@ Status TensorSliceWriter::Finish() {
return
s
;
}
/* static */
size_t
TensorSliceWriter
::
MaxBytesPerElement
(
DataType
dt
)
{
switch
(
dt
)
{
case
DT_FLOAT
:
return
4
;
case
DT_DOUBLE
:
return
8
;
case
DT_INT32
:
return
10
;
case
DT_UINT8
:
return
2
;
case
DT_INT16
:
return
10
;
case
DT_INT8
:
return
10
;
case
DT_COMPLEX64
:
return
8
;
case
DT_INT64
:
return
10
;
case
DT_BOOL
:
return
1
;
case
DT_QINT8
:
return
10
;
case
DT_QUINT8
:
return
2
;
case
DT_QINT32
:
return
10
;
case
DT_QINT16
:
return
10
;
case
DT_QUINT16
:
return
3
;
case
DT_UINT16
:
return
3
;
case
DT_COMPLEX128
:
return
16
;
case
DT_HALF
:
return
3
;
case
DT_INVALID
:
case
DT_STRING
:
case
DT_BFLOAT16
:
default:
CHECK
(
false
)
<<
"MaxBytesPerElement not implemented for dtype: "
<<
dt
;
}
return
0
;
}
template
<
>
Status
TensorSliceWriter
::
SaveData
(
const
string
*
data
,
int
num_elements
,
SavedSlice
*
ss
)
{
size_t
size_bound
=
ss
->
ByteSize
()
+
kTensorProtoHeaderBytes
+
(
num_elements
*
MaxBytesPerElement
(
DT_INT32
));
for
(
int
i
=
0
;
i
<
num_elements
;
++
i
)
{
size_bound
+=
data
[
i
].
size
();
}
if
(
size_bound
>
kMaxMessageBytes
)
{
return
errors
::
InvalidArgument
(
"Tensor slice is too large to serialize (conservative estimate: "
,
size_bound
,
" bytes)"
);
}
Fill
(
data
,
num_elements
,
ss
->
mutable_data
());
DCHECK_GE
(
ss
->
ByteSize
(),
0
);
DCHECK_LE
(
ss
->
ByteSize
(),
size_bound
);
return
Status
::
OK
();
}
}
// namespace checkpoint
}
// namespace tensorflow
tensorflow/core/util/tensor_slice_writer.h
浏览文件 @
e13cd053
...
...
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
...
...
@@ -61,11 +62,24 @@ class TensorSliceWriter {
const
TensorSlice
&
slice
,
const
T
*
data
);
Status
Finish
();
private:
// Allocate "num_elements" elements in "ss" and save the data in "data"
// there.
template
<
typename
T
>
static
void
SaveData
(
const
T
*
data
,
int
num_elements
,
SavedSlice
*
ss
);
static
Status
SaveData
(
const
T
*
data
,
int
num_elements
,
SavedSlice
*
ss
);
static
size_t
MaxBytesPerElement
(
DataType
dt
);
private:
static
const
size_t
kMaxMessageBytes
=
1LL
<<
31
;
// Filling in the TensorProto in a SavedSlice will add the following
// header bytes, in addition to the data:
// - 1 byte: TensorProto tag and wire format
// - <= 5 bytes: TensorProto length
// - 1 byte: Repeated *_val tag and wire format
// - <= 5 bytes: *_val length
// However, we add 1KB of slack, to be conservative and guard
// against other additions to the TensorProto.
static
const
size_t
kTensorProtoHeaderBytes
=
1
<<
10
;
const
string
filename_
;
const
CreateBuilderFunction
create_builder_
;
...
...
@@ -132,7 +146,7 @@ Status TensorSliceWriter::Add(const string& name, const TensorShape& shape,
TensorShape
saved_shape
(
ssm
->
shape
());
TensorShape
sliced_shape
;
TF_RETURN_IF_ERROR
(
slice
.
SliceTensorShape
(
saved_shape
,
&
sliced_shape
));
SaveData
(
data
,
sliced_shape
.
num_elements
(),
ss
);
TF_RETURN_IF_ERROR
(
SaveData
(
data
,
sliced_shape
.
num_elements
(),
ss
)
);
string
key
=
EncodeTensorNameSlice
(
name
,
slice
);
// TODO(yangke): consider doing a two-pass thing where the first pass just
// list the tensor slices we want to save and then another pass to actually
...
...
@@ -148,11 +162,26 @@ Status TensorSliceWriter::Add(const string& name, const TensorShape& shape,
}
template
<
typename
T
>
void
TensorSliceWriter
::
SaveData
(
const
T
*
data
,
int
num_elements
,
SavedSlice
*
ss
)
{
Status
TensorSliceWriter
::
SaveData
(
const
T
*
data
,
int
num_elements
,
SavedSlice
*
ss
)
{
size_t
size_bound
=
ss
->
ByteSize
()
+
kTensorProtoHeaderBytes
+
(
MaxBytesPerElement
(
DataTypeToEnum
<
T
>::
value
)
*
num_elements
);
if
(
size_bound
>
kMaxMessageBytes
)
{
return
errors
::
InvalidArgument
(
"Tensor slice is too large to serialize (conservative estimate: "
,
size_bound
,
" bytes)"
);
}
Fill
(
data
,
num_elements
,
ss
->
mutable_data
());
DCHECK_GE
(
ss
->
ByteSize
(),
0
);
DCHECK_LE
(
ss
->
ByteSize
(),
size_bound
);
return
Status
::
OK
();
}
template
<
>
Status
TensorSliceWriter
::
SaveData
(
const
string
*
data
,
int
num_elements
,
SavedSlice
*
ss
);
// Create a table builder that will write to "filename" in
// tensorflow::io::Table format. If successful, return OK
// and set "*builder" to the allocated builder. Otherwise, return a
...
...
tensorflow/core/util/tensor_slice_writer_test.cc
浏览文件 @
e13cd053
...
...
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/core/util/tensor_slice_writer.h"
#include <array>
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/logging.h"
...
...
@@ -263,6 +265,85 @@ void TensorSliceWriteTestHelper::CheckEntries(const string& fname) {
}
}
template
<
typename
DT
>
size_t
BytesPerElementHelper
(
DT
value
)
{
SavedSlice
ss
;
std
::
array
<
DT
,
1
>
lo_data
;
std
::
fill
(
lo_data
.
begin
(),
lo_data
.
end
(),
value
);
TensorSliceWriter
::
SaveData
(
lo_data
.
data
(),
lo_data
.
size
(),
&
ss
);
int
lo_byte_size
=
ss
.
ByteSize
();
std
::
array
<
DT
,
1001
>
hi_data
;
std
::
fill
(
hi_data
.
begin
(),
hi_data
.
end
(),
value
);
TensorSliceWriter
::
SaveData
(
hi_data
.
data
(),
hi_data
.
size
(),
&
ss
);
int
hi_byte_size
=
ss
.
ByteSize
();
return
(
hi_byte_size
-
lo_byte_size
)
/
(
hi_data
.
size
()
-
lo_data
.
size
());
}
TEST
(
TensorSliceWriteTest
,
CheckpointSize
)
{
EXPECT_EQ
(
TensorSliceWriter
::
MaxBytesPerElement
(
DT_BOOL
),
BytesPerElementHelper
<
bool
>
(
false
));
EXPECT_EQ
(
TensorSliceWriter
::
MaxBytesPerElement
(
DT_BOOL
),
BytesPerElementHelper
<
bool
>
(
true
));
EXPECT_EQ
(
TensorSliceWriter
::
MaxBytesPerElement
(
DT_FLOAT
),
BytesPerElementHelper
<
float
>
(
-
1.0
));
EXPECT_EQ
(
TensorSliceWriter
::
MaxBytesPerElement
(
DT_DOUBLE
),
BytesPerElementHelper
<
double
>
(
-
1.0
));
EXPECT_EQ
(
TensorSliceWriter
::
MaxBytesPerElement
(
DT_COMPLEX64
),
BytesPerElementHelper
<
complex64
>
(
-
1.0
));
EXPECT_EQ
(
TensorSliceWriter
::
MaxBytesPerElement
(
DT_COMPLEX128
),
BytesPerElementHelper
<
complex128
>
(
-
1.0
));
EXPECT_EQ
(
TensorSliceWriter
::
MaxBytesPerElement
(
DT_INT32
),
BytesPerElementHelper
<
int32
>
(
-
1
));
EXPECT_EQ
(
TensorSliceWriter
::
MaxBytesPerElement
(
DT_INT64
),
BytesPerElementHelper
<
int64
>
(
-
1
));
EXPECT_EQ
(
TensorSliceWriter
::
MaxBytesPerElement
(
DT_UINT16
),
BytesPerElementHelper
<
uint16
>
(
std
::
numeric_limits
<
uint16
>::
max
()));
EXPECT_EQ
(
TensorSliceWriter
::
MaxBytesPerElement
(
DT_UINT8
),
BytesPerElementHelper
<
uint8
>
(
std
::
numeric_limits
<
uint8
>::
max
()));
EXPECT_EQ
(
TensorSliceWriter
::
MaxBytesPerElement
(
DT_INT8
),
BytesPerElementHelper
<
int8
>
(
-
1
));
EXPECT_EQ
(
TensorSliceWriter
::
MaxBytesPerElement
(
DT_INT16
),
BytesPerElementHelper
<
int16
>
(
-
1
));
EXPECT_EQ
(
TensorSliceWriter
::
MaxBytesPerElement
(
DT_QINT8
),
BytesPerElementHelper
<
qint8
>
(
-
1
));
EXPECT_EQ
(
TensorSliceWriter
::
MaxBytesPerElement
(
DT_QUINT8
),
BytesPerElementHelper
<
quint8
>
(
std
::
numeric_limits
<
uint8
>::
max
()));
EXPECT_EQ
(
TensorSliceWriter
::
MaxBytesPerElement
(
DT_QINT32
),
BytesPerElementHelper
<
qint32
>
(
-
1
));
EXPECT_EQ
(
TensorSliceWriter
::
MaxBytesPerElement
(
DT_HALF
),
BytesPerElementHelper
<
Eigen
::
half
>
(
Eigen
::
half
(
-
1.0
)));
}
TEST
(
TensorSliceWriteTest
,
SizeErrors
)
{
const
string
filename
=
io
::
JoinPath
(
testing
::
TmpDir
(),
"checkpoint"
);
TensorSliceWriter
writer
(
filename
,
CreateTableTensorSliceBuilder
);
// Add a 300MB int8 tensor slice, which will fail because it expands to 3GB.
{
TensorShape
shape
({
300
,
1000000
});
TensorSlice
slice
=
TensorSlice
::
ParseOrDie
(
"-:-"
);
const
std
::
vector
<
int8
>
data
(
300000000
,
-
1
);
Status
s
=
writer
.
Add
(
"test1"
,
shape
,
slice
,
data
.
data
());
EXPECT_EQ
(
s
.
code
(),
error
::
INVALID_ARGUMENT
);
EXPECT_TRUE
(
StringPiece
(
s
.
error_message
())
.
contains
(
"Tensor slice is too large to serialize"
));
}
// Add a large string tensor slice, which will fail.
{
TensorShape
shape
({
100
,
1000000
});
TensorSlice
slice
=
TensorSlice
::
ParseOrDie
(
"-:-"
);
const
std
::
vector
<
string
>
data
(
100000000
,
"rhubarbrhubarb"
);
Status
s
=
writer
.
Add
(
"test2"
,
shape
,
slice
,
data
.
data
());
EXPECT_EQ
(
s
.
code
(),
error
::
INVALID_ARGUMENT
);
EXPECT_TRUE
(
StringPiece
(
s
.
error_message
())
.
contains
(
"Tensor slice is too large to serialize"
));
}
}
}
// namespace checkpoint
}
// namespace tensorflow
tensorflow/python/ops/data_flow_ops.py
浏览文件 @
e13cd053
...
...
@@ -276,6 +276,15 @@ class QueueBase(object):
If the queue is full when this operation executes, it will block
until the element has been enqueued.
At runtime, this operation may raise an error if the queue is
[closed](#QueueBase.close) before or during its execution. If the
queue is closed before this operation runs,
`tf.errors.AbortedError` will be raised. If this operation is
blocked, and either (i) the queue is closed by a close operation
with `cancel_pending_enqueues=True`, or (ii) the session is
[closed](../../api_docs/python/client.md#Session.close),
`tf.errors.CancelledError` will be raised.
Args:
vals: A tensor, a list or tuple of tensors, or a dictionary containing
the values to enqueue.
...
...
@@ -305,6 +314,15 @@ class QueueBase(object):
If the queue is full when this operation executes, it will block
until all of the elements have been enqueued.
At runtime, this operation may raise an error if the queue is
[closed](#QueueBase.close) before or during its execution. If the
queue is closed before this operation runs,
`tf.errors.AbortedError` will be raised. If this operation is
blocked, and either (i) the queue is closed by a close operation
with `cancel_pending_enqueues=True`, or (ii) the session is
[closed](../../api_docs/python/client.md#Session.close),
`tf.errors.CancelledError` will be raised.
Args:
vals: A tensor, a list or tuple of tensors, or a dictionary
from which the queue elements are taken.
...
...
@@ -357,6 +375,14 @@ class QueueBase(object):
If the queue is empty when this operation executes, it will block
until there is an element to dequeue.
At runtime, this operation may raise an error if the queue is
[closed](#QueueBase.close) before or during its execution. If the
queue is closed, the queue is empty, and there are no pending
enqueue operations that can fulfil this request,
`tf.errors.OutOfRangeError` will be raised. If the session is
[closed](../../api_docs/python/client.md#Session.close),
`tf.errors.CancelledError` will be raised.
Args:
name: A name for the operation (optional).
...
...
@@ -386,6 +412,14 @@ class QueueBase(object):
If the queue is closed and there are less than `n` elements left, then an
`OutOfRange` exception is raised.
At runtime, this operation may raise an error if the queue is
[closed](#QueueBase.close) before or during its execution. If the
queue is closed, the queue contains fewer than `n` elements, and
there are no pending enqueue operations that can fulfil this
request, `tf.errors.OutOfRangeError` will be raised. If the
session is [closed](../../api_docs/python/client.md#Session.close),
`tf.errors.CancelledError` will be raised.
Args:
n: A scalar `Tensor` containing the number of elements to dequeue.
name: A name for the operation (optional).
...
...
@@ -412,18 +446,20 @@ class QueueBase(object):
"""Dequeues and concatenates `n` elements from this queue.
**Note** This operation is not supported by all queues. If a queue does not
support DequeueUpTo, then an Unimplemented exception is raised.
This operation concatenates queue-element component tensors along the
0th dimension to make a single component tensor. All of the components
in the dequeued tuple will have size `n` in the 0th dimension.
If the queue is closed and there are more than `0` but less than `n`
elements remaining, then instead of raising an `OutOfRange` exception like
`dequeue_many`, the remaining elements are returned immediately.
If the queue is closed and there are `0` elements left in the queue, then
an `OutOfRange` exception is raised just like in `dequeue_many`.
Otherwise the behavior is identical to `dequeue_many`:
support DequeueUpTo, then a `tf.errors.UnimplementedError` is raised.
This operation concatenates queue-element component tensors along
the 0th dimension to make a single component tensor. If the queue
has not been closed, all of the components in the dequeued tuple
will have size `n` in the 0th dimension.
If the queue is closed and there are more than `0` but fewer than
`n` elements remaining, then instead of raising a
`tf.errors.OutOfRangeError` like [`dequeue_many`](#QueueBase.dequeue_many),
the remaining elements are returned immediately. If the queue is
closed and there are `0` elements left in the queue, then a
`tf.errors.OutOfRangeError` is raised just like in `dequeue_many`.
Otherwise the behavior is identical to `dequeue_many`.
Args:
n: A scalar `Tensor` containing the number of elements to dequeue.
...
...
tensorflow/python/training/saver_test.py
浏览文件 @
e13cd053
...
...
@@ -287,6 +287,33 @@ class SaverTest(tf.test.TestCase):
expected_save_path
=
"%s-%d"
%
(
save_path
,
global_step_int
)
self
.
assertEqual
(
expected_save_path
,
val
)
def
testLargeVariable
(
self
):
save_path
=
os
.
path
.
join
(
self
.
get_temp_dir
(),
"large_variable"
)
with
tf
.
Session
(
""
,
graph
=
tf
.
Graph
())
as
sess
:
# Declare a variable larger than 2GB.
with
tf
.
device
(
"/cpu:0"
):
var
=
tf
.
Variable
(
tf
.
constant
(
-
1
,
shape
=
[
300
,
1000000
],
dtype
=
tf
.
int8
))
save
=
tf
.
train
.
Saver
({
var
.
op
.
name
:
var
})
var
.
initializer
.
run
()
with
self
.
assertRaisesRegexp
(
tf
.
errors
.
InvalidArgumentError
,
"Tensor slice is too large to serialize"
):
save
.
save
(
sess
,
save_path
)
with
tf
.
Session
(
""
,
graph
=
tf
.
Graph
())
as
sess
:
# Declare a variable that is exactly 2GB. This should fail,
# because a serialized checkpoint includes other header
# metadata.
with
tf
.
device
(
"/cpu:0"
):
var
=
tf
.
Variable
(
tf
.
constant
(
False
,
shape
=
[
2
,
1024
,
1024
,
1024
],
dtype
=
tf
.
bool
))
save
=
tf
.
train
.
Saver
({
var
.
op
.
name
:
var
})
var
.
initializer
.
run
()
with
self
.
assertRaisesRegexp
(
tf
.
errors
.
InvalidArgumentError
,
"Tensor slice is too large to serialize"
):
save
.
save
(
sess
,
save_path
)
class
SaveRestoreShardedTest
(
tf
.
test
.
TestCase
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录