Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
88e763a9
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
88e763a9
编写于
4月 29, 2020
作者:
J
jinyaohui
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify conv2dtranspose
上级
697e8d30
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
30 addition
and
17 deletion
+30
-17
mindspore/nn/layer/conv.py
mindspore/nn/layer/conv.py
+11
-7
tests/ut/python/nn/test_conv.py
tests/ut/python/nn/test_conv.py
+19
-10
未找到文件。
mindspore/nn/layer/conv.py
浏览文件 @
88e763a9
...
...
@@ -37,7 +37,8 @@ class _Conv(Cell):
group
,
has_bias
,
weight_init
,
bias_init
):
bias_init
,
transposed
=
False
):
super
(
_Conv
,
self
).
__init__
()
self
.
in_channels
=
check_int_positive
(
in_channels
)
self
.
out_channels
=
check_int_positive
(
out_channels
)
...
...
@@ -65,9 +66,11 @@ class _Conv(Cell):
if
out_channels
%
group
!=
0
:
raise
ValueError
(
"Attr 'out_channels' of 'Conv2D' Op must be divisible by "
"attr 'group' of 'Conv2D' Op."
)
self
.
weight
=
Parameter
(
initializer
(
weight_init
,
[
out_channels
,
in_channels
//
group
,
*
kernel_size
]),
name
=
'weight'
)
if
transposed
:
shape
=
[
in_channels
,
out_channels
//
group
,
*
kernel_size
]
else
:
shape
=
[
out_channels
,
in_channels
//
group
,
*
kernel_size
]
self
.
weight
=
Parameter
(
initializer
(
weight_init
,
shape
),
name
=
'weight'
)
if
check_bool
(
has_bias
):
self
.
bias
=
Parameter
(
initializer
(
bias_init
,
[
out_channels
]),
name
=
'bias'
)
...
...
@@ -312,8 +315,8 @@ class Conv2dTranspose(_Conv):
# cause Conv2DBackpropInput's out_channel refers to Conv2D's out_channel,
# then Conv2dTranspose's out_channel refers to Conv2DBackpropInput's in_channel.
super
(
Conv2dTranspose
,
self
).
__init__
(
out_channels
,
in_channels
,
out_channels
,
kernel_size
,
stride
,
pad_mode
,
...
...
@@ -322,10 +325,11 @@ class Conv2dTranspose(_Conv):
group
,
has_bias
,
weight_init
,
bias_init
)
bias_init
,
transposed
=
True
)
self
.
out_channels
=
out_channels
self
.
in_channels
=
in_channels
self
.
out_channels
=
out_channels
self
.
shape
=
P
.
Shape
()
if
pad_mode
not
in
(
'valid'
,
'same'
,
'pad'
):
raise
ValueError
(
'Attr
\'
pad_mode
\'
of
\'
Conv2dTranspose
\'
Op passed '
...
...
tests/ut/python/nn/test_conv.py
浏览文件 @
88e763a9
...
...
@@ -20,7 +20,6 @@ import mindspore.nn as nn
from
mindspore
import
Tensor
from
..ut_filter
import
non_graph_engine
weight
=
Tensor
(
np
.
ones
([
2
,
2
]))
in_channels
=
3
out_channels
=
64
...
...
@@ -28,6 +27,7 @@ out_channels = 64
class
Net
(
nn
.
Cell
):
""" Net definition """
def
__init__
(
self
,
cin
,
cout
,
...
...
@@ -93,12 +93,14 @@ def test_compile_pad_pad():
input_data
=
Tensor
(
np
.
ones
([
1
,
3
,
16
,
50
],
dtype
=
np
.
float32
))
net
(
input_data
)
def
test_conv_group_error
():
with
pytest
.
raises
(
ValueError
):
nn
.
Conv2d
(
6
,
8
,
3
,
group
=
3
)
with
pytest
.
raises
(
ValueError
):
nn
.
Conv2d
(
6
,
9
,
3
,
group
=
2
)
def
test_conv_check
():
""" test_conv_check """
with
pytest
.
raises
(
ValueError
):
...
...
@@ -139,15 +141,15 @@ class NetConv2dTranspose(nn.Cell):
super
(
NetConv2dTranspose
,
self
).
__init__
()
self
.
conv
=
nn
.
Conv2dTranspose
(
cin
,
cout
,
kernel_size
,
stride
,
pad_mode
,
padding
,
dilation
,
group
,
has_bias
,
weight_init
,
bias_init
)
kernel_size
,
stride
,
pad_mode
,
padding
,
dilation
,
group
,
has_bias
,
weight_init
,
bias_init
)
def
construct
(
self
,
input_x
):
return
self
.
conv
(
input_x
)
...
...
@@ -165,6 +167,13 @@ def test_compile_transpose_bias():
net
(
input_data
)
def
test_compile_transpose_bias_init
():
bias
=
Tensor
(
np
.
random
.
randn
(
64
).
astype
(
np
.
float32
))
net
=
NetConv2dTranspose
(
3
,
64
,
4
,
has_bias
=
True
,
weight_init
=
'normal'
,
bias_init
=
bias
)
input_data
=
Tensor
(
np
.
ones
([
1
,
3
,
16
,
50
],
dtype
=
np
.
float32
))
net
(
input_data
)
def
test_compile_transpose_valid
():
net
=
NetConv2dTranspose
(
3
,
64
,
4
,
pad_mode
=
'valid'
,
weight_init
=
'normal'
)
input_data
=
Tensor
(
np
.
ones
([
1
,
3
,
16
,
50
],
dtype
=
np
.
float32
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录