magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 4bb46606

!1149 fix pylint in tests

Merge pull request !1149 from panyifeng/fix_pylint

Authored on May 14, 2020 by mindspore-ci-bot; committed via Gitee on May 14, 2020.
Parents: 5075f0a2, 755ba75d
Changes: 24 changed files with 67 additions and 92 deletions (+67 -92).
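The diffs below are mechanical pylint cleanups rather than behavioural changes: locals that shadow builtins (sum, max, dict, input, mul) are renamed, else branches that directly follow a return are dropped, bare comparison expressions become assert statements, == True / == False comparisons and superfluous parentheses around asserts are removed, and unused imports are deleted while import order is normalised. As a condensed, hypothetical before/after sketch of these patterns (illustrative only, not code copied from any file in this commit):

# Before: the shapes pylint flags in these tests.
def pick_before(a, b):
    sum = a + b                  # W0622: redefines the builtin 'sum'
    if sum > 0:
        return sum
    else:                        # R1705: unnecessary 'else' after 'return'
        return 0

# After: rename the local, drop the redundant 'else', keep behaviour identical.
def pick_after(a, b):
    total = a + b
    if total > 0:
        return total
    return 0

# Plain asserts instead of bare expressions or '== True' / '== False' checks.
assert pick_after(1, 2) == 3
assert not pick_after(-1, -1)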
tests/ut/cpp/python_input/gtest_input/ir/manager_test.py (+1 -1)
tests/ut/cpp/python_input/gtest_input/mem_reuse/mem_reuse_test.py (+3 -5)
tests/ut/cpp/python_input/gtest_input/optimizer/ad/__init__.py (+2 -3)
tests/ut/cpp/python_input/gtest_input/optimizer/ad/ad_test.py (+5 -8)
tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py (+4 -4)
tests/ut/cpp/python_input/gtest_input/pipeline/infer/__init__.py (+1 -1)
tests/ut/cpp/python_input/gtest_input/pipeline/infer/infer_test.py (+1 -1)
tests/ut/cpp/python_input/gtest_input/pipeline/infer/primitive_test.py (+9 -9)
tests/ut/cpp/python_input/gtest_input/pipeline/parse/parser_test.py (+5 -5)
tests/ut/cpp/python_input/gtest_input/session/session_test.py (+7 -7)
tests/ut/python/pipeline/infer/test_range.py (+1 -1)
tests/ut/python/pipeline/parse/test_dtype.py (+0 -1)
tests/ut/python/pipeline/parse/test_graph_return_const_param.py (+0 -2)
tests/ut/python/pipeline/parse/test_operator.py (+0 -3)
tests/ut/python/pynative_mode/ge/ops/test_tensor_add.py (+2 -2)
tests/ut/python/pynative_mode/nn/test_cell.py (+0 -1)
tests/ut/python/pynative_mode/nn/test_dropout.py (+0 -1)
tests/ut/python/pynative_mode/nn/test_pooling.py (+0 -3)
tests/ut/python/pynative_mode/ops/test_grad.py (+3 -4)
tests/ut/python/pynative_mode/ops/test_hypermap.py (+0 -1)
tests/ut/python/pynative_mode/test_bprop.py (+1 -1)
tests/ut/python/pynative_mode/test_cell_bprop.py (+3 -3)
tests/ut/python/pynative_mode/test_context.py (+2 -2)
tests/ut/python/pynative_mode/test_framstruct.py (+17 -23)
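Every touched file sits under tests/ut. A minimal sketch of how the same lint pass could be re-run locally, assuming pylint is installed and the repository's .pylintrc is used (the exact command MindSpore CI runs is not shown on this page):

# Hypothetical local invocation; the rcfile path and target directories are
# assumptions about the checkout layout, not taken from this commit page.
from pylint.lint import Run

# Run() reports to stdout and calls sys.exit() with pylint's status code.
Run(["--rcfile=.pylintrc", "tests/ut/python", "tests/ut/cpp/python_input"])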
tests/ut/cpp/python_input/gtest_input/ir/manager_test.py
@@ -44,7 +44,7 @@ def test_calls(x):
 # pylint: disable=unused-argument
 def test_unused_param(x, y):
     return x * x


 def test_cannot_replace_return(x):
     return x * x
tests/ut/cpp/python_input/gtest_input/mem_reuse/mem_reuse_test.py
@@ -13,8 +13,6 @@
 # limitations under the License.
 # ============================================================================
 from mindspore.ops import operations as P
 from mindspore.ops import Primitive
 import mindspore as ms

 add = P.TensorAdd()
 reshape = P.Reshape()
@@ -26,6 +24,6 @@ def test_shape_add(x1, x2, y1, y2, z1, z2):
     reshape_sum1 = reshape(sum1, (2, 2, 3, 1))
     reshape_sum2 = reshape(sum2, (2, 2, 3, 1))
     reshape_sum3 = reshape(sum3, (2, 2, 3, 1))
-    sum = add(reshape_sum1, reshape_sum2)
-    sum = add(sum, reshape_sum3)
-    return sum
+    result = add(reshape_sum1, reshape_sum2)
+    result = add(result, reshape_sum3)
+    return result
tests/ut/cpp/python_input/gtest_input/optimizer/ad/__init__.py
 """
 @File : __init__.py
 @Author:
 @Date : 2019-01-23 16:36
 @Desc :
 """
 from .ad_test import *
tests/ut/cpp/python_input/gtest_input/optimizer/ad/ad_test.py
@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
+from dataclasses import dataclass
 import numpy as np
 import mindspore as ms
-from dataclasses import dataclass
 from mindspore.common.tensor import Tensor
 from mindspore.ops import Primitive
 from mindspore.model_zoo.resnet import resnet50
@@ -106,7 +106,7 @@ def test_closure(a):
     def x1(b):
         def x4(c):
-            return b
+            return c * b
         return x4
     x2 = x1(a)
     x3 = x2(1.0)
@@ -117,21 +117,18 @@ def test_if(a, b):
     # if statement, so I prefer to name the test 'test_if'
     if a > b:
         return a
-    else:
-        return b
+    return b


 def test_if2(a, b):
     if a > b:
         return a * a
-    else:
-        return b + b
+    return b + b


 def test_fact(x):
     def fact(n):
         if n <= 1:
             return 1
-        else:
-            return n * fact(n - 1)
+        return n * fact(n - 1)
     return fact(x)


 def test_while(x):
tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
@@ -13,11 +13,11 @@
 # limitations under the License.
 # ============================================================================
 """ opt_test """
+import numpy as np
 from mindspore.ops import Primitive, PrimitiveWithInfer
 from mindspore.ops import operations as P
 from mindspore.ops.operations import _grad_ops as G
 from mindspore import Tensor
-import numpy as np

 # pylint: disable=unused-variable
@@ -790,9 +790,9 @@ def test_convert_switch_ops(tag):
         return z

     @fns
     def after(cond, x, y):
         sw1 = ge_switch(x, cond)
         sw2 = ge_switch(y, cond)
         sw3 = ge_switch(y, cond)
         sw1_t = tuple_getitem(sw1, 1)
         sw2_t = tuple_getitem(sw2, 1)
         sw3_f = tuple_getitem(sw3, 0)
tests/ut/cpp/python_input/gtest_input/pipeline/infer/__init__.py
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ============================================================================
 from .primitive_test import *
-from .infer_test import *
\ No newline at end of file
+from .infer_test import *
tests/ut/cpp/python_input/gtest_input/pipeline/infer/infer_test.py
@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-import mindspore.nn as nn
 from dataclasses import dataclass
+import mindspore.nn as nn
 from mindspore.ops import Primitive
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
tests/ut/cpp/python_input/gtest_input/pipeline/infer/primitive_test.py
@@ -14,7 +14,7 @@
 # ============================================================================
 import mindspore.nn as nn
 from mindspore.common import dtype
-from mindspore.ops import Primitive, prim_attr_register, PrimitiveWithInfer
+from mindspore.ops import prim_attr_register, PrimitiveWithInfer
 from mindspore.ops import operations as P


 def get_add(a, b):
@@ -55,15 +55,15 @@ def get_tensor_to_scalar(logits, labels):
 conv2d = P.Conv2D(64,
                   (3, 3),
                   pad_mode="pad",
                   pad=1,
                   stride=2)


 def get_conv2d(x, w):
     return conv2d(x, w)


 conv2dNative = P.DepthwiseConv2dNative(3, (3, 3), pad_mode="pad", pad=1, stride=2)


 def get_conv2d_native(x, w):
     return conv2dNative(x, w)
@@ -74,8 +74,8 @@ def get_bias_add(x, b):
 def test_conv2d(out_channel, kernel_size, pad, stride, dilation):
     conv = P.Conv2D(out_channel=out_channel, kernel_size=kernel_size, pad_mode="pad", pad=pad, stride=stride, dilation=dilation)

     def get_conv(x, w):
         return conv(x, w)
     return get_conv
@@ -83,7 +83,7 @@ def test_conv2d(out_channel, kernel_size, pad, stride, dilation):
 def test_dropout():
     dropOutGenMask = P.DropoutGenMask()
     dropoutDoMask = P.DropoutDoMask()
     shape = P.Shape()

     def get_dropout(x, prob):
         mask = dropOutGenMask(shape(x), prob)
tests/ut/cpp/python_input/gtest_input/pipeline/parse/parser_test.py
@@ -154,12 +154,12 @@ def test_lambda(x, y):
     return t


 def test_funcdef(x, y):
-    def max(a, b):
+    def mymax(a, b):
         if a > b:
             return a
         else:
             return b
-    t = max(x, y)
+    t = mymax(x, y)
     return t


 def test_tuple_fn(x, y):
@@ -225,7 +225,7 @@ def test_simple_closure(a, b):
         return b + 2.0
     return f() * g()


 def test_assign_tuple(x, y):
     a = 1
     b = 2
     t = a, b
@@ -282,8 +282,8 @@ def test_subscript_setitem():
     return t


 def test_dict():
-    dict = {"a": 1, "b": 2}
-    return dict
+    ret = {"a": 1, "b": 2}
+    return ret


 def func_call(x, y, *var, a=0, b=1, **kwargs):
     return x + y + var[0] + a + b + kwargs["z"]
tests/ut/cpp/python_input/gtest_input/session/session_test.py
@@ -25,13 +25,13 @@ tuple_getitem = Primitive('tuple_getitem')
 max_pool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)


 def test_addn_cast(x, y, z):
-    sum = addn((x, y))
-    res = cast(sum, ms.float16)
+    mysum = addn((x, y))
+    res = cast(mysum, ms.float16)
     return res


 def test_addn_with_max_pool(x, y):
-    sum = addn((x, y))
-    output = max_pool(sum)
+    mysum = addn((x, y))
+    output = max_pool(mysum)
     res = tuple_getitem(output, 0)
     return res
@@ -43,6 +43,6 @@ def test_shape_add(x1, x2, y1, y2, z1, z2):
     reshape_sum1 = reshape(sum1, (2, 2, 3, 1))
     reshape_sum2 = reshape(sum2, (2, 2, 3, 1))
     reshape_sum3 = reshape(sum3, (2, 2, 3, 1))
-    sum = add(reshape_sum1, reshape_sum2)
-    sum = add(sum, reshape_sum3)
-    return sum
+    mysum = add(reshape_sum1, reshape_sum2)
+    mysum = add(mysum, reshape_sum3)
+    return mysum
tests/ut/python/pipeline/infer/test_range.py
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-from mindspore.common.api import ms_function
 import numpy as np

 from mindspore import Tensor
+from mindspore.common.api import ms_function
 from mindspore.ops import operations as P
tests/ut/python/pipeline/parse/test_dtype.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 # ============================================================================
 """ test_dtype """
 import pytest

 from mindspore._c_expression import typing
 from mindspore.common.api import ms_function
tests/ut/python/pipeline/parse/test_graph_return_const_param.py
@@ -19,8 +19,6 @@ import mindspore.nn as nn
 from mindspore import context
 import mindspore.common.dtype as mstype
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter
 from mindspore.common.api import ms_function

 context.set_context(mode=context.GRAPH_MODE)
tests/ut/python/pipeline/parse/test_operator.py
@@ -200,6 +200,3 @@ def test_in_dict():
     z = Tensor(np.random.randint(low=20, high=30, size=(2, 3, 4), dtype=np.int32))
     context.set_context(mode=context.GRAPH_MODE)
     net(x, y, z)
tests/ut/python/pynative_mode/ge/ops/test_tensor_add.py
@@ -38,5 +38,5 @@ def test_tensor_orign_ops():
     assert np.all(z.asnumpy() - (x.asnumpy() + y.asnumpy()) < 0.0001)
     z = x * y
     assert np.all(z.asnumpy() - (x.asnumpy() * y.asnumpy()) < 0.0001)
-    assert (x == y)
-    assert (x != 'zero')
+    assert x == y
+    assert x != 'zero'
tests/ut/python/pynative_mode/nn/test_cell.py
@@ -297,4 +297,3 @@ def test_net_call():
     input_x = Tensor(np.random.randint(0, 255, [1, 3, net.image_h, net.image_w]).astype(np.float32))
     output = net.construct(input_x)
tests/ut/python/pynative_mode/nn/test_dropout.py
@@ -14,7 +14,6 @@
 # ============================================================================
 """ test_dropout """
 import numpy as np
 import pytest

 from mindspore.common.api import _executor
 import mindspore.nn as nn
 from mindspore import Tensor
tests/ut/python/pynative_mode/nn/test_pooling.py
@@ -57,6 +57,3 @@ def test_maxpool2d():
     output = max_pool(input_data)
     output_np = output.asnumpy()
     assert isinstance(output_np[0][0][0][0], (np.float32, np.float64))
tests/ut/python/pynative_mode/ops/test_grad.py
@@ -20,7 +20,6 @@ from mindspore import Tensor
 from mindspore.ops import composite as C
 from mindspore.ops.composite import grad_all_with_sens
 from mindspore.common.dtype import get_py_obj_dtype
 import mindspore.nn as nn
 import mindspore.ops.operations as P
 from mindspore.ops import functional as F
 from ...ut_filter import non_graph_engine
@@ -174,7 +173,7 @@ def test_select_grad():
     assert np.all(gout[0].asnumpy() == expect_cond)
     assert np.all(gout[1].asnumpy() == expect_x)
     assert np.all(gout[2].asnumpy() == expect_y)


 def test_SubGrad():
     """ test_SubGrad """
@@ -201,10 +200,10 @@ def test_MulGrad():
     """ test_MulGrad """
     input_x = Tensor(np.array([[2, 2], [2, 2]], np.float32))
     input_y = Tensor(np.array([[3, 3], [3, 3]], np.float32))
-    mul = P.Mul()
+    mymul = P.Mul()

     def fn(x, y):
-        output = mul(x, y)
+        output = mymul(x, y)
         return output

     out = fn(input_x, input_y)
tests/ut/python/pynative_mode/ops/test_hypermap.py
@@ -17,7 +17,6 @@ import numpy as np
 from mindspore.common.api import ms_function
 from mindspore import Tensor
 from mindspore import context
 from mindspore.ops import Primitive
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
tests/ut/python/pynative_mode/test_bprop.py
@@ -19,8 +19,8 @@ from mindspore import context
 from mindspore.ops import operations as P
 from mindspore.common.parameter import Parameter
 from mindspore.common import Tensor
-from ....mindspore_test_framework.utils.bprop_util import bprop
 from mindspore.common.api import ms_function
+from ....mindspore_test_framework.utils.bprop_util import bprop


 def setup_module(module):
     context.set_context(mode=context.PYNATIVE_MODE)
tests/ut/python/pynative_mode/test_cell_bprop.py
@@ -88,7 +88,7 @@ class WithNoBprop(nn.Cell):
 def test_with_no_bprop():
     with_no_bprop = WithNoBprop()
-    C.grad_all(with_no_bprop)(1, 2) == (2, 1)
+    assert C.grad_all(with_no_bprop)(1, 2) == (2, 1)


 def test_grad_in_bprop_1():
     class GradInBprop_1(nn.Cell):
@@ -189,8 +189,8 @@ class OneInputBprop(nn.Cell):
 def test_grad_one_input_bprop():
     net = OneInputBprop()
-    input = Tensor(np.ones([2, 2]).astype(np.float32))
-    grad = C.grad_all(net)(input)
+    input1 = Tensor(np.ones([2, 2]).astype(np.float32))
+    grad = C.grad_all(net)(input1)
     assert (grad[0].asnumpy() == np.array([5, 5]).astype(np.float32)).all()
tests/ut/python/pynative_mode/test_context.py
@@ -68,9 +68,9 @@ def test_dump_target():
     with pytest.raises(TypeError):
         context.set_context(save_dump_path=1)
     context.set_context(enable_dump=False)
-    assert context.get_context("enable_dump") == False
+    assert not context.get_context("enable_dump")
     context.set_context(enable_dump=True)
-    assert context.get_context("enable_dump") == True
+    assert context.get_context("enable_dump")
     assert context.get_context("save_dump_path") == "."
tests/ut/python/pynative_mode/test_framstruct.py
@@ -15,23 +15,17 @@
 """ test_framstruct """
 import pytest
 import numpy as np
 import mindspore as ms
+import mindspore.nn as nn
 from mindspore import context
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
 from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common.initializer import initializer
 from mindspore.common import dtype as mstype
-import mindspore.nn as nn
 from mindspore.nn.wrap.cell_wrapper import WithGradCell, WithLossCell
 from ..ut_filter import non_graph_engine
 from ....mindspore_test_framework.utils.check_gradient import (
     ms_function, check_jacobian, Tensor, NNGradChecker,
     OperationGradChecker, check_gradient, ScalarGradChecker)
 from ....mindspore_test_framework.utils.bprop_util import bprop
-import mindspore.context as context
 from mindspore.ops._grad.grad_base import bprop_getters
 from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
@@ -299,22 +293,22 @@ def test_dont_unroll_while():
     assert res == 3


 class ConvNet(nn.Cell):
     def __init__(self):
         super(ConvNet, self).__init__()
         out_channel = 16
         kernel_size = 3
         self.conv = P.Conv2D(out_channel,
                              kernel_size,
                              mode=1,
                              pad_mode="pad",
                              pad=0,
                              stride=1,
                              dilation=2,
                              group=1)
         self.w = Parameter(Tensor(np.ones([16, 16, 3, 3]).astype(np.float32)), name='w')

     def construct(self, x):
         return self.conv(x, self.w)

 conv = ConvNet()
 c1 = Tensor([2], mstype.float32)
@@ -674,7 +668,7 @@ def grad_refactor_6(a, b):
 def test_grad_refactor_6():
-    C.grad_all(grad_refactor_6)(3, 2) == (3, 1)
+    assert C.grad_all(grad_refactor_6)(3, 2) == (3, 1)


 def grad_refactor_while(x):