Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
df798f75
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
df798f75
编写于
6月 22, 2016
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
6月 22, 2016
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Use fix graph for dropout if is_training has a constant value.
Change: 125571683
上级
91757d7b
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
58 addition
and
5 deletion
+58
-5
tensorflow/contrib/layers/python/layers/layers.py
tensorflow/contrib/layers/python/layers/layers.py
+12
-5
tensorflow/contrib/layers/python/layers/layers_test.py
tensorflow/contrib/layers/python/layers/layers_test.py
+18
-0
tensorflow/contrib/layers/python/layers/utils.py
tensorflow/contrib/layers/python/layers/utils.py
+28
-0
未找到文件。
tensorflow/contrib/layers/python/layers/layers.py
浏览文件 @
df798f75
...
...
@@ -402,11 +402,18 @@ def dropout(inputs,
a tensor representing the output of the operation.
"""
with
ops
.
op_scope
([
inputs
],
scope
,
'Dropout'
)
as
sc
:
is_training
=
ops
.
convert_to_tensor
(
is_training
)
outputs
=
control_flow_ops
.
cond
(
is_training
,
lambda
:
nn
.
dropout
(
inputs
,
keep_prob
,
noise_shape
),
lambda
:
inputs
)
is_training_value
=
utils
.
constant_value
(
is_training
,
dtypes
.
bool
)
if
is_training_value
is
not
None
:
if
is_training_value
:
outputs
=
nn
.
dropout
(
inputs
,
keep_prob
,
noise_shape
)
else
:
outputs
=
inputs
else
:
def
_dropout
():
return
nn
.
dropout
(
inputs
,
keep_prob
,
noise_shape
)
outputs
=
control_flow_ops
.
cond
(
is_training
,
_dropout
,
lambda
:
inputs
)
return
utils
.
collect_named_outputs
(
outputs_collections
,
sc
,
outputs
)
...
...
tensorflow/contrib/layers/python/layers/layers_test.py
浏览文件 @
df798f75
...
...
@@ -327,6 +327,24 @@ class DropoutTest(tf.test.TestCase):
with
self
.
test_session
():
images
=
tf
.
random_uniform
((
5
,
height
,
width
,
3
),
seed
=
1
)
output
=
tf
.
contrib
.
layers
.
dropout
(
images
)
self
.
assertEquals
(
output
.
op
.
name
,
'Dropout/dropout/mul_1'
)
output
.
get_shape
().
assert_is_compatible_with
(
images
.
get_shape
())
def
testCreateDropoutWithConstant
(
self
):
height
,
width
=
3
,
3
with
self
.
test_session
():
is_training
=
tf
.
constant
(
False
)
images
=
tf
.
random_uniform
((
5
,
height
,
width
,
3
),
seed
=
1
)
output
=
tf
.
contrib
.
layers
.
dropout
(
images
,
is_training
=
is_training
)
self
.
assertEquals
(
output
.
op
.
name
,
'Dropout/dropout/mul_1'
)
output
.
get_shape
().
assert_is_compatible_with
(
images
.
get_shape
())
def
testCreateDropoutWithPlaceholder
(
self
):
height
,
width
=
3
,
3
with
self
.
test_session
():
is_training
=
tf
.
placeholder
(
dtype
=
tf
.
bool
,
shape
=
[])
images
=
tf
.
random_uniform
((
5
,
height
,
width
,
3
),
seed
=
1
)
output
=
tf
.
contrib
.
layers
.
dropout
(
images
,
is_training
=
is_training
)
self
.
assertEquals
(
output
.
op
.
name
,
'Dropout/cond/Merge'
)
output
.
get_shape
().
assert_is_compatible_with
(
images
.
get_shape
())
...
...
tensorflow/contrib/layers/python/layers/utils.py
浏览文件 @
df798f75
...
...
@@ -52,6 +52,34 @@ def collect_named_outputs(collections, name, outputs):
return
outputs
def
constant_value
(
value_or_tensor
,
tensor_dtype
=
None
):
"""Returns value if value_or_tensor has a constant value.
Args:
value_or_tensor: A value or a `Tensor`.
tensor_dtype: Optional `tf.dtype`, if set it would check the tensor type.
Returns:
The constant value or None if it not constant.
Raises:
ValueError: if value_or_tensor is None or the tensor has the wrong dtype.
"""
if
value_or_tensor
is
None
:
raise
ValueError
(
'value_or_tensor cannot be None'
)
value
=
value_or_tensor
if
isinstance
(
value_or_tensor
,
ops
.
Tensor
):
if
tensor_dtype
and
value_or_tensor
.
dtype
!=
tensor_dtype
:
raise
ValueError
(
'The tensor has the wrong type %s instead of %s'
%
(
value_or_tensor
.
dtype
,
tensor_dtype
))
if
value_or_tensor
.
op
.
type
==
'Const'
:
value_or_tensor
.
graph
.
prevent_feeding
(
value_or_tensor
)
value
=
value_or_tensor
.
op
.
get_attr
(
'value'
)
else
:
value
=
None
return
value
def
get_variable_collections
(
variables_collections
,
name
):
if
isinstance
(
variables_collections
,
dict
):
variable_collections
=
variables_collections
[
name
]
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录