Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
6a365818
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
6a365818
编写于
6月 24, 2016
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
6月 24, 2016
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add repeat meta-layer.
Change: 125794105
上级
7083e2af
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
71 addition
and
4 deletion
+71
-4
tensorflow/contrib/layers/python/layers/layers.py
tensorflow/contrib/layers/python/layers/layers.py
+49
-4
tensorflow/contrib/layers/python/layers/layers_test.py
tensorflow/contrib/layers/python/layers/layers_test.py
+22
-0
未找到文件。
tensorflow/contrib/layers/python/layers/layers.py
浏览文件 @
6a365818
...
...
@@ -53,6 +53,7 @@ __all__ = ['avg_pool2d',
'one_hot_encoding'
,
'relu'
,
'relu6'
,
'repeat'
,
'stack'
,
'legacy_fully_connected'
,
'legacy_linear'
,
...
...
@@ -630,6 +631,51 @@ def _apply_activation(y, activation_fn, output_collections):
return
y
def
repeat
(
inputs
,
repetitions
,
layer
,
*
args
,
**
kwargs
):
"""Applies the same layer with the same arguments repeatedly.
```python
y = repeat(x, 3, conv2d, 64, [3, 3], scope='conv1')
# It is equivalent to:
x = conv2d(x, 64, [3, 3], scope='conv1/conv1_1')
x = conv2d(x, 64, [3, 3], scope='conv1/conv1_2')
y = conv2d(x, 64, [3, 3], scope='conv1/conv1_3')
```
If the `scope` argument is not given in `kwargs`, it is set to
`layer.__name__`, or `layer.func.__name__` (for `functools.partial`
objects). If neither `__name__` nor `func.__name__` is available, the
layers are called with `scope='stack'`.
Args:
inputs: A `Tensor` suitable for layer.
repetitions: Int, number of repetitions.
layer: A layer with arguments `(inputs, *args, **kwargs)`
*args: Extra args for the layer.
**kwargs: Extra kwargs for the layer.
Returns:
a tensor result of applying the layer, repetitions times.
Raises:
ValueError: if the op is unknown or wrong.
"""
scope
=
kwargs
.
pop
(
'scope'
,
None
)
with
variable_scope
.
variable_op_scope
([
inputs
],
scope
,
'Repeat'
):
outputs
=
inputs
if
scope
is
None
:
if
hasattr
(
layer
,
'__name__'
):
scope
=
layer
.
__name__
elif
hasattr
(
layer
,
'func'
)
and
hasattr
(
layer
.
func
,
'__name__'
):
scope
=
layer
.
func
.
__name__
# In case layer is a functools.partial.
else
:
scope
=
'repeat'
for
i
in
range
(
repetitions
):
kwargs
[
'scope'
]
=
scope
+
'_'
+
str
(
i
+
1
)
outputs
=
layer
(
outputs
,
*
args
,
**
kwargs
)
return
outputs
def
stack
(
inputs
,
layer
,
stack_args
,
**
kwargs
):
"""Builds a stack of layers by applying layer repeatedly using stack_args.
...
...
@@ -638,15 +684,15 @@ def stack(inputs, layer, stack_args, **kwargs):
a new scope appended with an increasing number. For example:
```python
stack(x, fully_connected, [32, 64, 128], scope='fc')
y =
stack(x, fully_connected, [32, 64, 128], scope='fc')
# It is equivalent to:
x = fully_connected(x, 32, scope='fc/fc_1')
x = fully_connected(x, 64, scope='fc/fc_2')
x
= fully_connected(x, 128, scope='fc/fc_3')
y
= fully_connected(x, 128, scope='fc/fc_3')
```
If the `scope` argument is not given in `
stack_
args`, it is set to
If the `scope` argument is not given in `
kw
args`, it is set to
`layer.__name__`, or `layer.func.__name__` (for `functools.partial`
objects). If neither `__name__` nor `func.__name__` is available, the
layers are called with `scope='stack'`.
...
...
@@ -668,7 +714,6 @@ def stack(inputs, layer, stack_args, **kwargs):
raise
ValueError
(
'stack_args need to be a list or tuple'
)
with
variable_scope
.
variable_op_scope
([
inputs
],
scope
,
'Stack'
):
outputs
=
inputs
scope
=
scope
if
scope
is
None
:
if
hasattr
(
layer
,
'__name__'
):
scope
=
layer
.
__name__
...
...
tensorflow/contrib/layers/python/layers/layers_test.py
浏览文件 @
6a365818
...
...
@@ -897,6 +897,28 @@ class OneHotEncodingTest(tf.test.TestCase):
self
.
assertAllClose
(
output
.
eval
(),
one_hot_labels
.
eval
())
class
RepeatTests
(
tf
.
test
.
TestCase
):
def
testRepeat
(
self
):
height
,
width
=
3
,
3
with
self
.
test_session
():
images
=
tf
.
random_uniform
((
5
,
height
,
width
,
3
),
seed
=
1
,
name
=
'images'
)
output
=
tf
.
contrib
.
layers
.
repeat
(
images
,
3
,
tf
.
contrib
.
layers
.
conv2d
,
32
,
[
3
,
3
])
self
.
assertEquals
(
output
.
op
.
name
,
'Repeat/convolution2d_3/Relu'
)
self
.
assertListEqual
(
output
.
get_shape
().
as_list
(),
[
5
,
3
,
3
,
32
])
def
testRepeatWithScope
(
self
):
height
,
width
=
3
,
3
with
self
.
test_session
():
images
=
tf
.
random_uniform
((
5
,
height
,
width
,
3
),
seed
=
1
,
name
=
'images'
)
output
=
tf
.
contrib
.
layers
.
repeat
(
images
,
3
,
tf
.
contrib
.
layers
.
conv2d
,
32
,
[
3
,
3
],
scope
=
'conv1'
)
self
.
assertEquals
(
output
.
op
.
name
,
'conv1/conv1_3/Relu'
)
self
.
assertListEqual
(
output
.
get_shape
().
as_list
(),
[
5
,
3
,
3
,
32
])
class
StackTests
(
tf
.
test
.
TestCase
):
def
testStackFullyConnected
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录