Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
ac96df2d
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
ac96df2d
编写于
4月 07, 2016
作者:
S
Sherry Moore
提交者:
TensorFlower Gardener
4月 07, 2016
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added is_variable_initialized(variable) function.
Change: 119321281
上级
a0bc7959
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
71 addition
and
1 deletion
+71
-1
tensorflow/core/kernels/variable_ops.cc
tensorflow/core/kernels/variable_ops.cc
+8
-1
tensorflow/core/kernels/variable_ops.h
tensorflow/core/kernels/variable_ops.h
+16
-0
tensorflow/core/ops/state_ops.cc
tensorflow/core/ops/state_ops.cc
+14
-0
tensorflow/python/kernel_tests/variable_ops_test.py
tensorflow/python/kernel_tests/variable_ops_test.py
+8
-0
tensorflow/python/ops/state_ops.py
tensorflow/python/ops/state_ops.py
+3
-0
tensorflow/python/ops/variables.py
tensorflow/python/ops/variables.py
+12
-0
tensorflow/python/training/session_manager_test.py
tensorflow/python/training/session_manager_test.py
+10
-0
未找到文件。
tensorflow/core/kernels/variable_ops.cc
浏览文件 @
ac96df2d
...
...
@@ -28,6 +28,8 @@ REGISTER_KERNEL_BUILDER(Name("TemporaryVariable").Device(DEVICE_CPU),
TemporaryVariableOp
);
REGISTER_KERNEL_BUILDER
(
Name
(
"DestroyTemporaryVariable"
).
Device
(
DEVICE_CPU
),
DestroyTemporaryVariableOp
);
REGISTER_KERNEL_BUILDER
(
Name
(
"IsVariableInitialized"
).
Device
(
DEVICE_CPU
),
IsVariableInitializedOp
);
#if GOOGLE_CUDA
// Only register 'Variable' on GPU for the subset of types also supported by
...
...
@@ -43,7 +45,12 @@ REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("DestroyTemporaryVariable") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("T"), \
DestroyTemporaryVariableOp);
DestroyTemporaryVariableOp); \
REGISTER_KERNEL_BUILDER(Name("IsVariableInitialized") \
.Device(DEVICE_GPU) \
.TypeConstraint<type>("dtype") \
.HostMemory("is_initialized"), \
IsVariableInitializedOp);
TF_CALL_GPU_NUMBER_TYPES
(
REGISTER_GPU_KERNELS
);
#undef REGISTER_GPU_KERNELS
...
...
tensorflow/core/kernels/variable_ops.h
浏览文件 @
ac96df2d
...
...
@@ -158,6 +158,22 @@ class DestroyTemporaryVariableOp : public OpKernel {
string
var_name_
;
};
class
IsVariableInitializedOp
:
public
OpKernel
{
public:
IsVariableInitializedOp
(
OpKernelConstruction
*
context
)
:
OpKernel
(
context
)
{}
void
Compute
(
OpKernelContext
*
context
)
override
{
// Get a mutable input tensor of the Ref input.
const
Tensor
&
input_tensor
=
context
->
mutable_input
(
0
,
false
);
Tensor
*
output
=
nullptr
;
OP_REQUIRES_OK
(
context
,
context
->
allocate_output
(
0
,
TensorShape
({}),
&
output
));
auto
output_tensor
=
output
->
tensor
<
bool
,
0
>
();
bool
result
=
input_tensor
.
IsInitialized
();
output_tensor
()
=
result
;
}
};
}
// namespace tensorflow
#endif // TENSORFLOW_KERNELS_VARIABLE_OPS_H_
tensorflow/core/ops/state_ops.cc
浏览文件 @
ac96df2d
...
...
@@ -40,6 +40,20 @@ shared_name: If non-empty, this variable is named in the given bucket
with this shared_name. Otherwise, the node name is used instead.
)doc"
);
REGISTER_OP
(
"IsVariableInitialized"
)
.
Output
(
"is_initialized: bool"
)
.
Input
(
"ref: Ref(dtype)"
)
.
Attr
(
"dtype: type"
)
.
SetAllowsUninitializedInput
()
.
Doc
(
R"doc(
Checks whether a tensor has been initialized.
Outputs boolean scalar indicating whether the tensor has been initialized.
ref: Should be from a `Variable` node. May be uninitialized.
dtype: The type of elements in the variable tensor.
)doc"
);
REGISTER_OP
(
"TemporaryVariable"
)
.
Output
(
"ref: Ref(dtype)"
)
.
Attr
(
"shape: shape"
)
...
...
tensorflow/python/kernel_tests/variable_ops_test.py
浏览文件 @
ac96df2d
...
...
@@ -237,6 +237,14 @@ class VariableOpTest(tf.test.TestCase):
result
=
tf
.
mul
(
var
,
var
)
self
.
assertAllClose
([
4.0
],
result
.
eval
())
def
testIsVariableInitialized
(
self
):
for
use_gpu
in
[
True
,
False
]:
with
self
.
test_session
(
use_gpu
=
use_gpu
):
v0
=
state_ops
.
variable_op
([
1
,
2
],
tf
.
float32
)
self
.
assertEqual
(
False
,
tf
.
is_variable_initialized
(
v0
).
eval
())
tf
.
assign
(
v0
,
[[
2.0
,
3.0
]]).
eval
()
self
.
assertEqual
(
True
,
tf
.
is_variable_initialized
(
v0
).
eval
())
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
tensorflow/python/ops/state_ops.py
浏览文件 @
ac96df2d
...
...
@@ -30,6 +30,7 @@ collected in the graph.
@@initialize_all_variables
@@initialize_variables
@@initialize_local_variables
@@is_variable_initialized
@@assert_variables_initialized
## Saving and Restoring Variables
...
...
@@ -134,6 +135,8 @@ def variable_op(shape, dtype, name="Variable", set_shape=True, container="",
# NOTE(mrry): Shapes are conditionally set in the Python wrapper.
ops
.
RegisterShape
(
"Variable"
)(
common_shapes
.
unknown_shape
)
ops
.
RegisterShape
(
"IsVariableInitialized"
)(
common_shapes
.
scalar_shape
)
@
ops
.
RegisterShape
(
"TemporaryVariable"
)
def
_TemporaryVariableShape
(
op
):
...
...
tensorflow/python/ops/variables.py
浏览文件 @
ac96df2d
...
...
@@ -798,6 +798,18 @@ def initialize_local_variables():
return
initialize_variables
(
local_variables
())
def
is_variable_initialized
(
variable
):
"""Returns an Op to check if a variable has been initialized.
Args:
variable: A `Variable`.
Returns:
An operation to check whether a variable has been initialized.
"""
return
state_ops
.
is_variable_initialized
(
variable
)
def
assert_variables_initialized
(
var_list
=
None
):
"""Returns an Op to check if variables are initialized.
...
...
tensorflow/python/training/session_manager_test.py
浏览文件 @
ac96df2d
...
...
@@ -71,6 +71,8 @@ class SessionManagerTest(tf.test.TestCase):
os
.
rename
(
checkpoint_dir
,
checkpoint_dir2
)
gfile
.
MakeDirs
(
checkpoint_dir
)
v
=
tf
.
Variable
([
6.0
,
7.0
,
8.0
],
name
=
"v"
)
with
self
.
test_session
():
self
.
assertEqual
(
False
,
tf
.
is_variable_initialized
(
v
).
eval
())
tf
.
train
.
SessionManager
(
ready_op
=
tf
.
assert_variables_initialized
())
saver
=
tf
.
train
.
Saver
({
"v"
:
v
})
# This should fail as there's no checkpoint within 2 seconds.
...
...
@@ -85,6 +87,9 @@ class SessionManagerTest(tf.test.TestCase):
sess
=
sm
.
prepare_session
(
""
,
init_op
=
None
,
saver
=
saver
,
checkpoint_dir
=
checkpoint_dir
,
wait_for_checkpoint
=
True
,
max_wait_secs
=
2
)
self
.
assertEqual
(
True
,
tf
.
is_variable_initialized
(
sess
.
graph
.
get_tensor_by_name
(
"v:0"
)).
eval
(
session
=
sess
))
def
testRecoverSession
(
self
):
# Create a checkpoint.
...
...
@@ -109,11 +114,16 @@ class SessionManagerTest(tf.test.TestCase):
# Create a new Graph and SessionManager and recover.
with
tf
.
Graph
().
as_default
():
v
=
tf
.
Variable
(
2
,
name
=
"v"
)
with
self
.
test_session
():
self
.
assertEqual
(
False
,
tf
.
is_variable_initialized
(
v
).
eval
())
sm2
=
tf
.
train
.
SessionManager
(
ready_op
=
tf
.
assert_variables_initialized
())
saver
=
tf
.
train
.
Saver
({
"v"
:
v
})
sess
,
initialized
=
sm2
.
recover_session
(
""
,
saver
=
saver
,
checkpoint_dir
=
checkpoint_dir
)
self
.
assertTrue
(
initialized
)
self
.
assertEqual
(
True
,
tf
.
is_variable_initialized
(
sess
.
graph
.
get_tensor_by_name
(
"v:0"
)).
eval
(
session
=
sess
))
self
.
assertEquals
(
1
,
sess
.
run
(
v
))
def
testWaitForSessionReturnsNoneAfterTimeout
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录