Project: magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit add3778a
add grad all in pynative mode

Authored Jul 06, 2020 by kingfo
Parent: f201bd65
Showing 16 changed files with 307 additions and 194 deletions (+307 −194).
mindspore/ccsrc/pynative/pynative_execute.cc                +1   -1
mindspore/common/tensor.py                                  +3   -0
mindspore/context.py                                        +4   -1
mindspore/nn/cell.py                                        +4   -0
mindspore/ops/composite/base.py                             +34  -21
mindspore/ops/functional.py                                 +1   -0
tests/st/ops/gpu/test_dense_op.py                           +1   -0
tests/ut/python/pipeline/infer/test_net_infer.py            +1   -0
tests/ut/python/pipeline/parse/test_cell_bprop.py           +15  -5
tests/ut/python/pipeline/parse/test_parse.py                +117 -1
tests/ut/python/pynative_mode/nn/test_tensor_operation.py   +6   -0
tests/ut/python/pynative_mode/ops/test_grad.py              +24  -20
tests/ut/python/pynative_mode/test_framstruct.py            +48  -134
tests/ut/python/pynative_mode/test_hook.py                  +33  -7
tests/ut/python/pynative_mode/test_insert_grad_of.py        +2   -0
tests/ut/python/pynative_mode/test_stop_gradient.py         +13  -4
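
Taken together, the change teaches GradOperation (and helpers built on it such as C.grad_all) to run in PyNative mode: the forward pass is recorded on the fly, gradients come from the pynative executor, and inputs must be Tensors. A minimal usage sketch, modeled on the test_grad_all case this commit adds to tests/ut/python/pynative_mode/test_hook.py (assumes a MindSpore build containing this commit):

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import composite as C

context.set_context(mode=context.PYNATIVE_MODE)

class Net(nn.Cell):
    def construct(self, x, y):
        return 2 * x * x + y * y

net = Net()
x = Tensor(np.array([1, 2, 3]).astype(np.int32))
y = Tensor(np.array([2, 3, 4]).astype(np.int32))
# grad_all runs the forward pass itself in pynative mode, then asks the
# pynative executor for gradients w.r.t. every input.
print(C.grad_all(net)(x, y))
```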
mindspore/ccsrc/pynative/pynative_execute.cc

@@ -980,7 +980,7 @@ std::vector<AnfNodePtr> PynativeExecutor::GetWeightsArgs(const py::object &weigh
       }
     }
   } else {
-    MS_LOG(EXCEPTION) << "training not paramter_tuple";
+    MS_LOG(DEBUG) << "training not paramter_tuple";
   }
   return w_args;
 }
mindspore/common/tensor.py

@@ -181,6 +181,9 @@ class Tensor(Tensor_):
     def __imod__(self, other):
         return self.__mod__(other)

+    def __pow__(self, other):
+        return tensor_operator_registry.get('__pow__')(self, other)
+
     def __floordiv__(self, other):
         return tensor_operator_registry.get('__floordiv__')(self, other)
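
With __pow__ wired to the operator registry, Python's ** now dispatches to the op registered under '__pow__' (see the mindspore/ops/functional.py hunk below). A small sketch mirroring the test_tensor_pow case added in this commit:

```python
import numpy as np
from mindspore import Tensor

x = Tensor(np.ones([3, 3, 3, 3]).astype(np.float32) * 2)
y = x ** 3  # resolved via tensor_operator_registry.get('__pow__')
assert y.asnumpy()[0][0][0][0] == 8.0
```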
mindspore/context.py

@@ -176,7 +176,10 @@ class _Context:
             self._context_switches.push(True, None)
         else:
             if self.enable_debug_runtime:
-                self.set_backend_policy("ge")
+                if self.device_target == "CPU":
+                    self.set_backend_policy("vm")
+                else:
+                    self.set_backend_policy("ge")
             self._context_switches.push(False, None)

     def set_backend_policy(self, policy):
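
The effect of this hunk: when the debug runtime is enabled, the backend policy now follows the device target instead of always being "ge". A hedged, stand-alone sketch of just that branch (the function name is illustrative, not part of _Context):

```python
def debug_backend_policy(device_target):
    # Mirrors the new branch in _Context: with the debug runtime on,
    # a CPU target now selects the "vm" backend; others keep "ge".
    return "vm" if device_target == "CPU" else "ge"

assert debug_backend_policy("CPU") == "vm"
assert debug_backend_policy("Ascend") == "ge"
```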
mindspore/nn/cell.py

@@ -16,6 +16,7 @@
 import time
 import gc
 from collections import OrderedDict
+import numpy
 from mindspore import log as logger
 from .. import context
 from ..common import dtype as mstype

@@ -211,6 +212,9 @@ class Cell:
         if context.get_context("mode") == context.GRAPH_MODE:
             out = self.compile_and_run(*inputs)
             return out
+        for item in inputs:
+            if isinstance(item, numpy.ndarray):
+                raise TypeError("cell inputs should not be numpy array.")
         self.init_parameters_data()
         orign_grad = []
         if self.requires_grad is True:
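
Because the PyNative grad path needs real Tensors, Cell.__call__ now rejects raw numpy arrays before running. A sketch of the resulting behavior, mirroring test_check_input added to test_hook.py:

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context

context.set_context(mode=context.PYNATIVE_MODE)

class Net(nn.Cell):
    def construct(self, x, y):
        return 2 * x * x + y * y

net = Net()
x, y = np.array([1, 2, 3]), np.array([2, 3, 4])
try:
    net(x, y)  # raises TypeError("cell inputs should not be numpy array.")
except TypeError:
    out = net(Tensor(x), Tensor(y))  # wrap inputs in Tensor instead
```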
mindspore/ops/composite/base.py

@@ -17,6 +17,7 @@
 """Basic composite operations."""
 from functools import partial
+from types import FunctionType

 from mindspore import context
 from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \

@@ -25,6 +26,7 @@ from ...common import dtype as mstype
 from ...common.api import ms_function, _pynative_exec, _wrap_func
 from .. import functional as F
 from ...common.parameter import Parameter
+from ...common.tensor import Tensor

 __all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]

@@ -114,37 +116,48 @@ class GradOperation(GradOperation_):
         self.fn = None
         self.need_forward = False

+    def _pynative_forward_run(self, args, fn):
+        """ Pynative forward run to build grad graph. """
+        if self.sens_param:
+            args = args[:-1]
+        if isinstance(fn, FunctionType):
+            _pynative_exec.set_grad_flag(True)
+            _pynative_exec.new_graph(fn, *args)
+            output = fn(*args)
+            _pynative_exec.end_graph(fn, output, *args)
+        else:
+            if fn.is_run and not fn.requires_grad:
+                raise ValueError("obj must set_grad.")
+            if not fn.is_run:
+                self.need_forward = True
+                print("already has forward run before grad by user")
+            if self.need_forward:
+                fn.set_grad()
+                fn(*args)
+
     def __call__(self, fn, weights=None):
         grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param)
         if self.grad_fn is None or self.fn != fn:
-            if self.get_by_list:
-                if context.get_context("mode") == context.GRAPH_MODE:
+            if context.get_context("mode") == context.GRAPH_MODE:
+                if self.get_by_list:
                     @ms_function(obj=fn)
                     def after_grad(*args):
                         return grad_(fn, weights)(*args)
                 else:
-                    @_wrap_func
+                    @ms_function(obj=fn)
                     def after_grad(*args):
-                        if fn.is_run and not fn.requires_grad:
-                            raise ValueError("obj must set_grad.")
-                        if not fn.is_run:
-                            self.need_forward = True
-                            print("already has forward run before grad by user")
-                        if self.need_forward:
-                            fn.set_grad()
-                            if self.sens_param:
-                                f_args = args[:-1]
-                                fn(*f_args)
-                            else:
-                                fn(*args)
-                        _pynative_exec.grad(grad_, fn, weights, *args)
-                        out = _pynative_exec(*args)
-                        _pynative_exec.clear()
-                        return out
+                        return grad_(fn)(*args)
             else:
-                @ms_function(obj=fn)
+                @_wrap_func
                 def after_grad(*args):
-                    return grad_(fn)(*args)
+                    for arg in args:
+                        if not isinstance(arg, Tensor):
+                            raise TypeError("grad inputs should be tensor in pynative mode")
+                    self._pynative_forward_run(args, fn)
+                    _pynative_exec.grad(grad_, fn, weights, *args)
+                    out = _pynative_exec(*args)
+                    _pynative_exec.clear()
+                    return out
             self.grad_fn = after_grad
             self.fn = fn
         return self.grad_fn
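
The reworked __call__ now dispatches on the execution mode rather than on get_by_list: under GRAPH_MODE both branches compile after_grad with ms_function as before, while the PyNative branch checks that every argument is a Tensor, replays the forward pass through _pynative_forward_run to build the grad graph, and then drives _pynative_exec. A usage sketch of the PyNative path (the positional GradOperation('grad', ...) signature matches this version of the file):

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, context
from mindspore.ops import composite as C

context.set_context(mode=context.PYNATIVE_MODE)

class Square(nn.Cell):
    def construct(self, x):
        return x * x

grad_op = C.GradOperation('grad', get_all=True)
x = Tensor(np.array([2.0]).astype(np.float32))
# Passing a plain Python scalar here would now raise
# TypeError("grad inputs should be tensor in pynative mode").
print(grad_op(Square())(x))
```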
mindspore/ops/functional.py

@@ -166,6 +166,7 @@ tensor_operator_registry.register('__sub__', tensor_sub)
 tensor_operator_registry.register('__mul__', tensor_mul)
 tensor_operator_registry.register('__truediv__', tensor_div)
 tensor_operator_registry.register('__mod__', tensor_mod)
+tensor_operator_registry.register('__pow__', tensor_pow)
 tensor_operator_registry.register('__floordiv__', tensor_floordiv)
 #ms cannot support Tensor(True) compare
 tensor_operator_registry.register('__eq__', equal)
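
The registration above is the other half of the Tensor.__pow__ change: the dunder method looks the operator up by name at call time. A toy illustration of the mechanism (plain Python, not the real MindSpore registry class):

```python
class OperatorRegistry:
    """Minimal stand-in for tensor_operator_registry."""
    def __init__(self):
        self._ops = {}

    def register(self, name, op):
        self._ops[name] = op

    def get(self, name):
        return self._ops[name]

registry = OperatorRegistry()
registry.register('__pow__', pow)          # stand-in for tensor_pow
assert registry.get('__pow__')(2, 3) == 8  # what Tensor.__pow__ does
```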
tests/st/ops/gpu/test_dense_op.py

@@ -228,6 +228,7 @@ def test_biasadd_3d():
     error = np.ones(shape=[3, 4, 8]) * 1.0e-6
     context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
     net = BiasAdd()
+    net.set_grad()
     result = net(x, b)
     diff = result.asnumpy() - expect
     assert np.all(diff < error)
tests/ut/python/pipeline/infer/test_net_infer.py

@@ -45,6 +45,7 @@ def test_net_infer():
 def test_assign_in_while():
+    context.set_context(device_target="Ascend")
     context.set_context(mode=context.GRAPH_MODE)

     class Net(nn.Cell):
         def __init__(self, input_shape):
tests/ut/python/pynative_mode/test_cell_bprop.py → tests/ut/python/pipeline/parse/test_cell_bprop.py

@@ -16,6 +16,7 @@
 import numpy as np
 import pytest

+import mindspore as ms
 import mindspore.common.dtype as mstype
 import mindspore.nn as nn
 from mindspore import Parameter

@@ -24,12 +25,15 @@ from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
-from ....mindspore_test_framework.utils.bprop_util import bprop
+from .....mindspore_test_framework.utils.bprop_util import bprop


 def setup_module(module):
-    context.set_context(mode=context.PYNATIVE_MODE)
+    context.set_context(device_target="CPU")
+    context.set_context(mode=context.GRAPH_MODE)
+
+
+def teardown_module(module):
+    context.set_context(device_target="Ascend")


 class MulAdd(nn.Cell):
     def __init__(self):

@@ -45,7 +49,9 @@ class MulAdd(nn.Cell):
 def test_grad_mul_add():
     mul_add = MulAdd()
-    assert C.grad_all(mul_add)(1, 2) == (2, 4)
+    x = Tensor(1, dtype=ms.int32)
+    y = Tensor(2, dtype=ms.int32)
+    assert C.grad_all(mul_add)(x, y) == (2, 4)


 class InlineMulADD(nn.Cell):

@@ -60,7 +66,9 @@ class InlineMulADD(nn.Cell):
 def test_grad_inline_mul_add():
     inline_mul_add = InlineMulADD()
-    assert C.grad_all(inline_mul_add)(1, 2) == (3, 6)
+    x = Tensor(1, dtype=ms.int32)
+    y = Tensor(2, dtype=ms.int32)
+    assert C.grad_all(inline_mul_add)(x, y) == (3, 6)


 class WithParameter(nn.Cell):

@@ -93,7 +101,9 @@ class WithNoBprop(nn.Cell):
 def test_with_no_bprop():
     with_no_bprop = WithNoBprop()
-    assert C.grad_all(with_no_bprop)(1, 2) == (2, 1)
+    x = Tensor(1, dtype=ms.int32)
+    y = Tensor(2, dtype=ms.int32)
+    assert C.grad_all(with_no_bprop)(x, y) == (2, 1)


 def test_grad_in_bprop_1():
tests/ut/python/pipeline/parse/test_parse.py

@@ -19,21 +19,27 @@
 @Desc      :
 """
 import logging
+import pytest
 import numpy as np

 import mindspore as ms
 import mindspore.nn as nn
 from mindspore import Tensor
+from mindspore import context
+from mindspore.ops import composite as C
 from mindspore.common.api import ms_function, _executor
+from mindspore.ops._grad.grad_base import bprop_getters
+from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
 from mindspore.ops.functional import tensor_add
 from ...ut_filter import non_graph_engine

-# pylint: disable=W0613
+# pylint: disable=W0613,W0612
 # W0613: unused-argument

 log = logging.getLogger("test")
 log.setLevel(level=logging.ERROR)
+context.set_context(mode=context.GRAPH_MODE)

 # Test case: use the parse obj interface use default parameter

@@ -135,3 +141,113 @@ def test_net_with_ndarray():
     input_data = np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32')
     net(ms.Tensor(input_data))
+
+
+def test_bprop_with_wrong_output_num():
+    context.set_context(check_bprop=True)
+    class BpropWithWrongOutputNum(PrimitiveWithInfer):
+        @prim_attr_register
+        def __init__(self):
+            super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')
+
+        def __call__(self, x, y):
+            return x
+
+        def infer_shape(self, x_shape, yshape):
+            return x_shape
+
+        def infer_dtype(self, x_type, y_type):
+            return x_type
+
+    @bprop_getters.register(BpropWithWrongOutputNum)
+    def get_bprop_with_wrong_output_num(self):
+        """Generate bprop for BpropWithWrongOutputNum"""
+        def bprop(x, y, out, dout):
+            return (dout,)
+        return bprop
+
+    class BpropWithWrongOutputNumCell(nn.Cell):
+        def __init__(self):
+            super(BpropWithWrongOutputNumCell, self).__init__()
+
+        def construct(self, x, y):
+            return BpropWithWrongOutputNum()(x, y)
+
+    with pytest.raises(TypeError):
+        C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
+
+
+def test_bprop_with_wrong_output_type():
+    context.set_context(check_bprop=True)
+    class BpropWithWrongOutputType(PrimitiveWithInfer):
+        @prim_attr_register
+        def __init__(self):
+            super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')
+
+        def __call__(self, x):
+            return x
+
+        def infer_shape(self, x_shape):
+            return x_shape
+
+        def infer_dtype(self, x_type):
+            return x_type
+
+    @bprop_getters.register(BpropWithWrongOutputType)
+    def get_bprop_with_wrong_output_type(self):
+        """Generate bprop for BpropWithWrongOutputType"""
+        def bprop(x, out, dout):
+            return (1,)
+        return bprop
+
+    class BpropWithWrongOutputTypeCell(nn.Cell):
+        def __init__(self):
+            super(BpropWithWrongOutputTypeCell, self).__init__()
+
+        def construct(self, x):
+            return BpropWithWrongOutputType()(x)
+
+    with pytest.raises(TypeError):
+        C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
+
+
+def test_bprop_with_wrong_output_shape():
+    context.set_context(check_bprop=True)
+    class BpropWithWrongOutputShape(PrimitiveWithInfer):
+        @prim_attr_register
+        def __init__(self):
+            super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')
+
+        def __call__(self, x):
+            return x
+
+        def infer_shape(self, x_shape):
+            return x_shape
+
+        def infer_dtype(self, x_type):
+            return x_type
+
+    @bprop_getters.register(BpropWithWrongOutputShape)
+    def get_bprop_with_wrong_output_shape(self):
+        """Generate bprop for BpropWithWrongOutputShape"""
+        ones = Tensor(np.ones([2,]).astype(np.int32))
+
+        def bprop(x, out, dout):
+            return (ones,)
+        return bprop
+
+    class BpropWithWrongOutputShapeCell(nn.Cell):
+        def __init__(self):
+            super(BpropWithWrongOutputShapeCell, self).__init__()
+
+        def construct(self, x):
+            return BpropWithWrongOutputShape()(x)
+
+    with pytest.raises(TypeError):
+        net = BpropWithWrongOutputShapeCell()
+        net.set_grad()
+        C.grad_all(net)(Tensor(np.ones([64, 10]).astype(np.int32)))
tests/ut/python/pynative_mode/nn/test_tensor_operation.py

@@ -78,3 +78,9 @@ def test_tensor_imul():
     y = Tensor(np.ones([3, 3, 3, 3]).astype(np.float32))
     x *= y
     assert x.asnumpy()[0][0][0][0] == 1.0
+
+
+def test_tensor_pow():
+    x = Tensor(np.ones([3, 3, 3, 3]).astype(np.float32) * 2)
+    y = x ** 3
+    assert y.asnumpy()[0][0][0][0] == 8.0
tests/ut/python/pynative_mode/ops/test_grad.py

@@ -89,7 +89,11 @@ def test_scalar_cast_grad():
         output = F.scalar_cast(x, input_t)
         return output

-    gfn = C.grad(fx_cast)(input_x)
+    @ms_function
+    def grad_fx_cast(input_x):
+        return C.grad(fx_cast)(input_x)
+
+    gfn = grad_fx_cast(input_x)
     expect_dx = 1
     assert gfn == expect_dx

@@ -133,25 +137,6 @@ def test_transpose_grad():
     assert np.all(gout[0].asnumpy() == expect)


-@non_graph_engine
-def test_squeeze_grad():
-    """ test_squeeze_grad """
-    input_tensor = Tensor(np.ones(shape=[3, 2, 1]))
-    squeeze = P.Squeeze(2)
-
-    def fn(x):
-        output = squeeze(x)
-        return output
-
-    out = fn(input_tensor)
-    gfn = grad_all_with_sens(fn)
-    sens = Tensor(np.ones_like(out.asnumpy()))
-    args = [input_tensor, sens]
-    gout = gfn(*args)
-    expect = np.ones([3, 2, 1])
-    assert np.all(gout[0].asnumpy() == expect)
-
-
 def test_select_grad():
     """ test_select_grad """
     select = P.Select()

@@ -176,6 +161,25 @@ def test_select_grad():
     assert np.all(gout[2].asnumpy() == expect_y)


+@non_graph_engine
+def test_squeeze_grad():
+    """ test_squeeze_grad """
+    input_tensor = Tensor(np.ones(shape=[3, 2, 1]))
+    squeeze = P.Squeeze(2)
+
+    def fn(x):
+        output = squeeze(x)
+        return output
+
+    out = fn(input_tensor)
+    gfn = grad_all_with_sens(fn)
+    sens = Tensor(np.ones_like(out.asnumpy()))
+    args = [input_tensor, sens]
+    gout = gfn(*args)
+    expect = np.ones([3, 2, 1])
+    assert np.all(gout[0].asnumpy() == expect)
+
+
 def test_SubGrad():
     """ test_SubGrad """
     input_x = Tensor(np.array([[2, 2]]))
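
The recurring pattern in this file: a bare C.grad(...) call, which previously always compiled a graph, is wrapped in an @ms_function helper so the unit test keeps exercising the graph path instead of falling into the new pynative path. The shape of that pattern as a sketch (fx stands for any traced function):

```python
from mindspore.common.api import ms_function
from mindspore.ops import composite as C

def fx(x):
    return x * 2

@ms_function
def grad_fx(input_x):
    # ms_function forces graph-mode compilation of the grad computation,
    # preserving the behavior these tests had before the commit.
    return C.grad(fx)(input_x)
```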
tests/ut/python/pynative_mode/test_framstruct.py

@@ -16,6 +16,7 @@
 import numpy as np
 import pytest

+import mindspore as ms
 import mindspore.nn as nn
 from mindspore import context
 from mindspore.common import dtype as mstype

@@ -23,8 +24,6 @@ from mindspore.common.parameter import Parameter, ParameterTuple
 from mindspore.common.tensor import Tensor
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
-from mindspore.ops._grad.grad_base import bprop_getters
-from mindspore.ops.primitive import prim_attr_register, PrimitiveWithInfer
 from ..ut_filter import non_graph_engine
 from ....mindspore_test_framework.utils.check_gradient import (
     ms_function, check_jacobian, Tensor, NNGradChecker,

@@ -156,14 +155,14 @@ def test_if_always_true():
 @non_graph_engine
 def test_f():
     """ test_f """
-    res = mainf(3, 2)
+    res = mainf(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32))
     assert res == (2, 3)


 @non_graph_engine
 def test_grad_add_mul():
     """ test_grad_add_mul """
-    res = grad_add_mul(3, 2)
+    res = grad_add_mul(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32))
     assert res == (2, 7)

@@ -262,17 +261,19 @@ def test_if_tensor():
     assert res == Tensor(np.ones([1]).astype(np.int32) * 4)

+@ms_function
 def rec(x):
     """ rec """
     if x > 0:
         return rec(x - 1)
     return x


+@ms_function
+def grad_rec(input_x):
+    return C.grad(rec)(input_x)
+
 def test_grad_rec():
     """ test_grad_rec """
-    res = C.grad(rec)(10)
+    res = grad_rec(3)
     assert res == 1

@@ -282,7 +283,6 @@ def test_me_rec():
     assert res == 0


-@ms_function
 def t2_while(x, y):
     out = y - x
     i = 0

@@ -298,8 +298,10 @@ def test_while2():
 def test_grad_while2():
-    res = C.grad(t2_while)(2, 3)
-    assert res == 3
+    @ms_function
+    def df_t2_while(input_x, input_y):
+        return C.grad(t2_while)(input_x, input_y)
+    assert df_t2_while(2, 3) == 3


 def if_test(a, b):

@@ -316,7 +318,7 @@ def grad_if(x, y):
 def test_grad_if():
     """ test_grad_if """
-    assert grad_if(5, 4) == (3, 0)
+    assert grad_if(Tensor(5, dtype=ms.int32), Tensor(4, dtype=ms.int32)) == (3, 0)


 # While loop is not unrolled in forward and backward graphs.

@@ -421,7 +423,7 @@ def grad_while(x):
 def test_grad_while():
     """ test_grad_while """
-    assert grad_while(5) == (60,)
+    assert grad_while(Tensor(5, dtype=ms.int32)) == (60,)


 @ms_function

@@ -438,8 +440,10 @@ def test_factorial():
 def test_grad_factorial():
-    res = C.grad(factorial)(3)
-    assert res == 11
+    @ms_function
+    def df_factorial(x):
+        return C.grad(factorial)(x)
+    assert df_factorial(3) == 11


 @ms_function

@@ -513,7 +517,7 @@ def _for(x):
         ret = ret * i
     return ret

-
+@ms_function
 def grad_for(x):
     """ grad_for """
     return C.grad_all(_for)(x)

@@ -786,7 +790,10 @@ def multi_outputs(x, y):
 def test_grad_multi_outputs():
-    assert C.grad_all_with_sens(multi_outputs)(2, 3, (1, 1)) == (4, 4)
+    @ms_function
+    def df_multi_outputs(x, y):
+        return C.grad_all_with_sens(multi_outputs)(x, y, (1, 1))
+    assert df_multi_outputs(2, 3) == (4, 4)


 @ms_function

@@ -813,7 +820,7 @@ def grad_refactor_simple_1(x, y):
 def test_grad_refactor_simple_1():
-    assert C.grad_all(grad_refactor_simple_1)(2, 1) == (4, 2)
+    assert C.grad_all(grad_refactor_simple_1)(Tensor(2, dtype=ms.int32), Tensor(1, dtype=ms.int32)) == (4, 2)


 def grad_refactor_simple_2(x, y, z):

@@ -822,7 +829,10 @@ def grad_refactor_simple_2(x, y, z):
 def test_grad_refactor_simple_2():
-    assert C.grad_all(grad_refactor_simple_2)(2, 3, 0) == (7, 4, 7)
+    x = Tensor(2, dtype=ms.int32)
+    y = Tensor(3, dtype=ms.int32)
+    z = Tensor(0, dtype=ms.int32)
+    assert C.grad_all(grad_refactor_simple_2)(x, y, z) == (7, 4, 7)


 def grad_refactor_1(a, b):

@@ -835,7 +845,7 @@ def grad_refactor_1(a, b):
 def test_grad_refactor_1():
-    assert C.grad_all(grad_refactor_1)(2, 3) == (3, 2)
+    assert C.grad_all(grad_refactor_1)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (3, 2)


 def grad_refactor_2(a, b):

@@ -848,7 +858,7 @@ def grad_refactor_2(a, b):
 def test_grad_refactor_2():
-    assert C.grad_all(grad_refactor_2)(2, 3) == (27, 54)
+    assert C.grad_all(grad_refactor_2)(Tensor(2, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (27, 54)


 def grad_refactor_3(a):

@@ -859,7 +869,10 @@ def grad_refactor_3(a):
 def test_grad_refactor_3():
-    assert C.grad_all(grad_refactor_3)(3) == (3,)
+    @ms_function
+    def df_refactor_3(x):
+        return C.grad_all(grad_refactor_3)(x)
+    assert df_refactor_3(3) == (3,)


 def grad_refactor_4(a):

@@ -870,7 +883,7 @@ def grad_refactor_4(a):
 def test_grad_refactor_4():
-    assert C.grad_all(grad_refactor_4)(4) == (3,)
+    assert C.grad_all(grad_refactor_4)(Tensor(4, dtype=ms.int32)) == (3,)


 def grad_refactor_5(a):

@@ -881,7 +894,10 @@ def grad_refactor_5(a):
 def test_grad_refactor_5():
-    assert C.grad_all(grad_refactor_5)(1) == (1,)
+    @ms_function
+    def df_refactor_5(x):
+        return C.grad_all(grad_refactor_5)(x)
+    assert df_refactor_5(1) == (1,)


 def grad_refactor_6(a, b):

@@ -892,7 +908,7 @@ def grad_refactor_6(a, b):
 def test_grad_refactor_6():
-    assert C.grad_all(grad_refactor_6)(3, 2) == (3, 1)
+    assert C.grad_all(grad_refactor_6)(Tensor(3, dtype=ms.int32), Tensor(2, dtype=ms.int32)) == (3, 1)


 def grad_refactor_while(x):

@@ -904,7 +920,10 @@ def grad_refactor_while(x):
 def test_grad_refactor_9():
-    assert C.grad_all(grad_refactor_while)(3) == (6,)
+    @ms_function
+    def df_refactor_while(input_x):
+        return C.grad_all(grad_refactor_while)(input_x)
+    assert df_refactor_while(3) == (6,)


 def grad_refactor__while_1(x):

@@ -919,7 +938,7 @@ def grad_refactor__while_1(x):
 def test_grad_refactor_10():
     """ test_grad_while """
-    assert C.grad_all(grad_refactor__while_1)(5) == (60,)
+    assert C.grad_all(grad_refactor__while_1)(Tensor(5, dtype=ms.int32)) == (60,)


 def test_grad_refactor_11():

@@ -985,7 +1004,10 @@ def grad_refactor_14(a, b):
 def test_grad_refactor_14():
-    assert C.grad_all(grad_refactor_14)(2, 3) == (3, 9)
+    @ms_function
+    def df_refactor_14(x, y):
+        return C.grad_all(grad_refactor_14)(x, y)
+    assert df_refactor_14(2, 3) == (3, 9)


 # pylint: disable=using-constant-test

@@ -1009,111 +1031,3 @@ def test_grad_if_defer_inline():
     inp = Tensor(np.ones([128, 96]).astype(np.float32))
     grads = C.grad_all(network)(inp)
     assert grads == (Tensor(np.full([128, 96], 0.6, dtype=np.float32)),)
-
-
-def test_bprop_with_wrong_output_num():
-    context.set_context(check_bprop=True)
-    class BpropWithWrongOutputNum(PrimitiveWithInfer):
-        @prim_attr_register
-        def __init__(self):
-            super(BpropWithWrongOutputNum, self).__init__('BpropWithWrongOutputNum')
-
-        def __call__(self, x, y):
-            return x
-
-        def infer_shape(self, x_shape, yshape):
-            return x_shape
-
-        def infer_dtype(self, x_type, y_type):
-            return x_type
-
-    @bprop_getters.register(BpropWithWrongOutputNum)
-    def get_bprop_with_wrong_output_num(self):
-        """Generate bprop for BpropWithWrongOutputNum"""
-        def bprop(x, y, out, dout):
-            return (dout,)
-        return bprop
-
-    class BpropWithWrongOutputNumCell(nn.Cell):
-        def __init__(self):
-            super(BpropWithWrongOutputNumCell, self).__init__()
-
-        def construct(self, x, y):
-            return BpropWithWrongOutputNum()(x, y)
-
-    with pytest.raises(TypeError):
-        C.grad_all(BpropWithWrongOutputNumCell())(1, 2)
-
-
-def test_bprop_with_wrong_output_type():
-    context.set_context(check_bprop=True)
-    class BpropWithWrongOutputType(PrimitiveWithInfer):
-        @prim_attr_register
-        def __init__(self):
-            super(BpropWithWrongOutputType, self).__init__('BpropWithWrongOutputType')
-
-        def __call__(self, x):
-            return x
-
-        def infer_shape(self, x_shape):
-            return x_shape
-
-        def infer_dtype(self, x_type):
-            return x_type
-
-    @bprop_getters.register(BpropWithWrongOutputType)
-    def get_bprop_with_wrong_output_type(self):
-        """Generate bprop for BpropWithWrongOutputType"""
-        def bprop(x, out, dout):
-            return (1,)
-        return bprop
-
-    class BpropWithWrongOutputTypeCell(nn.Cell):
-        def __init__(self):
-            super(BpropWithWrongOutputTypeCell, self).__init__()
-
-        def construct(self, x):
-            return BpropWithWrongOutputType()(x)
-
-    with pytest.raises(TypeError):
-        C.grad_all(BpropWithWrongOutputTypeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
-
-
-def test_bprop_with_wrong_output_shape():
-    context.set_context(check_bprop=True)
-    class BpropWithWrongOutputShape(PrimitiveWithInfer):
-        @prim_attr_register
-        def __init__(self):
-            super(BpropWithWrongOutputShape, self).__init__('BpropWithWrongOutputShape')
-
-        def __call__(self, x):
-            return x
-
-        def infer_shape(self, x_shape):
-            return x_shape
-
-        def infer_dtype(self, x_type):
-            return x_type
-
-    @bprop_getters.register(BpropWithWrongOutputShape)
-    def get_bprop_with_wrong_output_shape(self):
-        """Generate bprop for BpropWithWrongOutputShape"""
-        ones = Tensor(np.ones([2,]).astype(np.int32))
-
-        def bprop(x, out, dout):
-            return (ones,)
-        return bprop
-
-    class BpropWithWrongOutputShapeCell(nn.Cell):
-        def __init__(self):
-            super(BpropWithWrongOutputShapeCell, self).__init__()
-
-        def construct(self, x):
-            return BpropWithWrongOutputShape()(x)
-
-    with pytest.raises(TypeError):
-        C.grad_all(BpropWithWrongOutputShapeCell())(Tensor(np.ones([64, 10]).astype(np.int32)))
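
Two migrations appear throughout this file, corresponding to the two ways a grad test can stay valid after the commit: either pass Tensor inputs so the pynative path accepts them, or wrap the grad call in @ms_function to stay on the graph path. Side by side, as a sketch:

```python
import mindspore as ms
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.ops import composite as C

def square(x):
    return x * x

# Option 1: Tensor-typed inputs satisfy the pynative type check.
dx = C.grad_all(square)(Tensor(3, dtype=ms.int32))

# Option 2: keep scalar inputs but compile the grad call as a graph.
@ms_function
def df_square(x):
    return C.grad_all(square)(x)

dx2 = df_square(3)
```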
tests/ut/python/pynative_mode/test_hook.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 import numpy as np
+import pytest

 import mindspore.nn as nn
 import mindspore.ops.operations as P

@@ -154,22 +155,47 @@ def test_hook():
     print(loss_output.asnumpy().shape)


+bprop_debug = False
+
 class MulAdd(nn.Cell):
     def __init__(self):
         super(MulAdd, self).__init__()

     def construct(self, x, y):
-        return 2 * x + y
+        return 2 * x * x + y * y

     def bprop(self, x, y, out, dout):
-        assert (x == 1)
-        assert (y == 2)
-        assert (out == 4)
-        assert (dout == 1)
-        return 3 * dout, 2 * y
+        global bprop_debug
+        bprop_debug = True
+        return dout, 2 * y


 def test_custom_bprop():
     mul_add = MulAdd()
     mul_add.bprop_debug = True
-    assert C.grad_all(mul_add)(1, 2) == (3, 4)
+    x = Tensor(np.array([1, 2, 3]).astype(np.int32))
+    y = Tensor(np.array([2, 3, 4]).astype(np.int32))
+    C.grad_all(mul_add)(x, y)
+    assert bprop_debug
+
+
+class Net(nn.Cell):
+    def __init__(self):
+        super(Net, self).__init__()
+
+    def construct(self, x, y):
+        return 2 * x * x + y * y
+
+
+def test_grad_all():
+    net = Net()
+    x = Tensor(np.array([1, 2, 3]).astype(np.int32))
+    y = Tensor(np.array([2, 3, 4]).astype(np.int32))
+    res = C.grad_all(net)(x, y)
+    print(res)
+
+
+def test_check_input():
+    net = Net()
+    x = np.array([1, 2, 3])
+    y = np.array([2, 3, 4])
+    with pytest.raises(TypeError):
+        net(x, y)
tests/ut/python/pynative_mode/test_insert_grad_of.py

@@ -46,6 +46,7 @@ def test_InsertGradientOf_1():
         c = x * y
         return c

+    @ms_function
     def f(x, y):
         return C.grad_all(stop_test)(x, y)

@@ -80,6 +81,7 @@ def test_InsertGradientOf_2():
     def f(x, y):
         return clip_test(x, y)

+    @ms_function
     def fd(x, y):
         return C.grad_all(clip_test)(x, y)
tests/ut/python/pynative_mode/test_stop_gradient.py

@@ -16,6 +16,7 @@
 import numpy as np
 import pytest

+import mindspore as ms
 import mindspore.common.dtype as mstype
 import mindspore.nn as nn
 from mindspore import Parameter, ParameterTuple

@@ -81,16 +82,24 @@ def stop_test4(x, y):
     return e


+@ms_function
 def grad_stop_test(x, y):
     """ grad_stop_test """
     return C.grad_all(stop_test2)(x, y)


+@ms_function
 def grad_stop_test1(x, y):
     """ grad_stop_test1 """
     return C.grad_all(stop_test3)(x, y)


+@ms_function
+def grad_stop_test5(x, y):
+    """ grad_stop_test5 """
+    return C.grad_all(stop_test5)(x, y)
+
+
 def test_stop():
     """ test_stop """
     print("test_stop:", grad_stop_test(1, 1))

@@ -103,7 +112,7 @@ def test_stop1():
 def test_stop5():
     """ test_stop1 """
-    print("test_stop5:", C.grad_all(stop_test5)(2, 3))
+    print("test_stop5:", grad_stop_test5(2, 3))


 class GradWrap(nn.Cell):

@@ -247,7 +256,7 @@ def test_stop_gradient_4():
     def stop_test(x):
         return stop_gradient(x)

-    assert C.grad_all(stop_test)(1) == (0,)
+    assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (0,)


 def test_stop_gradient_5():

@@ -257,7 +266,7 @@ def test_stop_gradient_5():
         ret = x + y
         return ret

-    assert C.grad_all(stop_test)(1) == (1,)
+    assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32)) == (1,)


 def test_stop_gradient_6():

@@ -266,7 +275,7 @@ def test_stop_gradient_6():
         ret = stop_gradient(ret)
         return ret

-    assert C.grad_all(stop_test)(1, 3) == (0, 0)
+    assert C.grad_all(stop_test)(Tensor(1, dtype=ms.int32), Tensor(3, dtype=ms.int32)) == (0, 0)


 class PrimWithMultiOutputs(PrimitiveWithInfer):