Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
weixin_40195168达庆意
keras
提交
25a8973d
K
keras
项目概览
weixin_40195168达庆意
/
keras
与 Fork 源项目一致
从无法访问的项目Fork
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
K
keras
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
25a8973d
编写于
5月 23, 2018
作者:
T
Taehoon Lee
提交者:
François Chollet
5月 22, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Increase test coverages by factorizing CNTK pads (#10259)
上级
abd02940
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
24 addition
and
82 deletion
+24
-82
keras/backend/cntk_backend.py
keras/backend/cntk_backend.py
+24
-82
未找到文件。
keras/backend/cntk_backend.py
浏览文件 @
25a8973d
...
...
@@ -2025,20 +2025,8 @@ def function(inputs, outputs, updates=[], **kwargs):
def
temporal_padding
(
x
,
padding
=
(
1
,
1
)):
assert
len
(
padding
)
==
2
num_dynamic_axis
=
_get_dynamic_axis_num
(
x
)
base_shape
=
x
.
shape
if
num_dynamic_axis
>
0
:
assert
len
(
base_shape
)
==
2
if
hasattr
(
C
,
'pad'
):
x
=
C
.
pad
(
x
,
pattern
=
[
padding
,
(
0
,
0
)])
else
:
x
=
_padding
(
x
,
padding
,
0
)
else
:
assert
len
(
base_shape
)
==
3
if
hasattr
(
C
,
'pad'
):
x
=
C
.
pad
(
x
,
pattern
=
[(
0
,
0
),
padding
,
(
0
,
0
)])
else
:
x
=
_padding
(
x
,
padding
,
1
)
return
x
assert
len
(
x
.
shape
)
==
3
-
(
1
if
num_dynamic_axis
>
0
else
0
)
return
pad
(
x
,
[
padding
],
'channels_last'
,
num_dynamic_axis
)
def
_padding
(
x
,
pattern
,
axis
):
...
...
@@ -2062,6 +2050,24 @@ def _padding(x, pattern, axis):
return
x
def
pad
(
x
,
pad_info
,
data_format
,
num_dynamic_axis
):
if
hasattr
(
C
,
'pad'
):
pattern
=
[
list
(
p
)
for
p
in
pad_info
]
if
data_format
==
'channels_first'
:
pattern
=
[[
0
,
0
]]
+
pattern
else
:
pattern
=
pattern
+
[[
0
,
0
]]
if
num_dynamic_axis
==
0
:
pattern
=
[[
0
,
0
]]
+
pattern
return
C
.
pad
(
x
,
pattern
=
pattern
)
else
:
for
(
a
,
p
)
in
enumerate
(
pad_info
):
x
=
_padding
(
x
,
p
,
a
+
(
1
if
num_dynamic_axis
==
0
else
0
)
+
(
1
if
data_format
==
'channels_first'
else
0
))
return
x
def
spatial_2d_padding
(
x
,
padding
=
((
1
,
1
),
(
1
,
1
)),
data_format
=
None
):
assert
len
(
padding
)
==
2
assert
len
(
padding
[
0
])
==
2
...
...
@@ -2072,38 +2078,8 @@ def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
raise
ValueError
(
'Unknown data_format '
+
str
(
data_format
))
num_dynamic_axis
=
_get_dynamic_axis_num
(
x
)
base_shape
=
x
.
shape
if
data_format
==
'channels_first'
:
if
num_dynamic_axis
>
0
:
assert
len
(
base_shape
)
==
3
if
hasattr
(
C
,
'pad'
):
x
=
C
.
pad
(
x
,
pattern
=
[[
0
,
0
],
list
(
padding
[
0
]),
list
(
padding
[
1
])])
else
:
x
=
_padding
(
x
,
padding
[
0
],
1
)
x
=
_padding
(
x
,
padding
[
1
],
2
)
else
:
assert
len
(
base_shape
)
==
4
if
hasattr
(
C
,
'pad'
):
x
=
C
.
pad
(
x
,
pattern
=
[[
0
,
0
],
[
0
,
0
],
list
(
padding
[
0
]),
list
(
padding
[
1
])])
else
:
x
=
_padding
(
x
,
padding
[
0
],
2
)
x
=
_padding
(
x
,
padding
[
1
],
3
)
else
:
if
num_dynamic_axis
>
0
:
assert
len
(
base_shape
)
==
3
if
hasattr
(
C
,
'pad'
):
x
=
C
.
pad
(
x
,
pattern
=
[
list
(
padding
[
0
]),
list
(
padding
[
1
]),
[
0
,
0
]])
else
:
x
=
_padding
(
x
,
padding
[
0
],
0
)
x
=
_padding
(
x
,
padding
[
1
],
1
)
else
:
assert
len
(
base_shape
)
==
4
if
hasattr
(
C
,
'pad'
):
x
=
C
.
pad
(
x
,
pattern
=
[[
0
,
0
],
list
(
padding
[
0
]),
list
(
padding
[
1
]),
[
0
,
0
]])
else
:
x
=
_padding
(
x
,
padding
[
0
],
1
)
x
=
_padding
(
x
,
padding
[
1
],
2
)
return
x
assert
len
(
x
.
shape
)
==
4
-
(
1
if
num_dynamic_axis
>
0
else
0
)
return
pad
(
x
,
padding
,
data_format
,
num_dynamic_axis
)
def
spatial_3d_padding
(
x
,
padding
=
((
1
,
1
),
(
1
,
1
),
(
1
,
1
)),
data_format
=
None
):
...
...
@@ -2117,42 +2093,8 @@ def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
raise
ValueError
(
'Unknown data_format '
+
str
(
data_format
))
num_dynamic_axis
=
_get_dynamic_axis_num
(
x
)
base_shape
=
x
.
shape
if
data_format
==
'channels_first'
:
if
num_dynamic_axis
>
0
:
assert
len
(
base_shape
)
==
4
if
hasattr
(
C
,
'pad'
):
x
=
C
.
pad
(
x
,
pattern
=
[[
0
,
0
],
list
(
padding
[
0
]),
list
(
padding
[
1
]),
list
(
padding
[
2
])])
else
:
x
=
_padding
(
x
,
padding
[
0
],
1
)
x
=
_padding
(
x
,
padding
[
1
],
2
)
x
=
_padding
(
x
,
padding
[
2
],
3
)
else
:
assert
len
(
base_shape
)
==
5
if
hasattr
(
C
,
'pad'
):
x
=
C
.
pad
(
x
,
pattern
=
[[
0
,
0
],
[
0
,
0
],
list
(
padding
[
0
]),
list
(
padding
[
1
]),
list
(
padding
[
2
])])
else
:
x
=
_padding
(
x
,
padding
[
0
],
2
)
x
=
_padding
(
x
,
padding
[
1
],
3
)
x
=
_padding
(
x
,
padding
[
2
],
4
)
else
:
if
num_dynamic_axis
>
0
:
assert
len
(
base_shape
)
==
4
if
hasattr
(
C
,
'pad'
):
x
=
C
.
pad
(
x
,
pattern
=
[
list
(
padding
[
0
]),
list
(
padding
[
1
]),
list
(
padding
[
2
]),
[
0
,
0
]])
else
:
x
=
_padding
(
x
,
padding
[
0
],
0
)
x
=
_padding
(
x
,
padding
[
1
],
1
)
x
=
_padding
(
x
,
padding
[
2
],
2
)
else
:
assert
len
(
base_shape
)
==
5
if
hasattr
(
C
,
'pad'
):
x
=
C
.
pad
(
x
,
pattern
=
[[
0
,
0
],
list
(
padding
[
0
]),
list
(
padding
[
1
]),
list
(
padding
[
2
]),
[
0
,
0
]])
else
:
x
=
_padding
(
x
,
padding
[
0
],
1
)
x
=
_padding
(
x
,
padding
[
1
],
2
)
x
=
_padding
(
x
,
padding
[
2
],
3
)
return
x
assert
len
(
x
.
shape
)
==
5
-
(
1
if
num_dynamic_axis
>
0
else
0
)
return
pad
(
x
,
padding
,
data_format
,
num_dynamic_axis
)
def
one_hot
(
indices
,
num_classes
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录