Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
2fecdede
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2fecdede
编写于
4月 09, 2020
作者:
W
Wei Luning
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
support amp when model eval, fix example of UnsortSegmentsSum
上级
c478be0f
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
59 addition
and
78 deletion
+59
-78
mindspore/ccsrc/parallel/step_auto_parallel.cc
mindspore/ccsrc/parallel/step_auto_parallel.cc
+9
-0
mindspore/ops/operations/array_ops.py
mindspore/ops/operations/array_ops.py
+4
-3
mindspore/ops/operations/nn_ops.py
mindspore/ops/operations/nn_ops.py
+15
-11
mindspore/train/amp.py
mindspore/train/amp.py
+24
-18
mindspore/train/model.py
mindspore/train/model.py
+5
-3
tests/train_step_wrap.py
tests/train_step_wrap.py
+2
-43
未找到文件。
mindspore/ccsrc/parallel/step_auto_parallel.cc
浏览文件 @
2fecdede
...
...
@@ -636,6 +636,15 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
// Dealing with the RefKey case
auto
refkeys
=
cnode_with_refkeys
.
second
;
auto
cnode
=
cnode_with_refkeys
.
first
;
auto
cnode_ptr
=
cnode
->
cast
<
CNodePtr
>
();
if
(
cnode_ptr
==
nullptr
||
!
IsValueNode
<
Primitive
>
(
cnode_ptr
->
input
(
0
)))
{
continue
;
}
if
(
!
IsAutoParallelCareNode
(
cnode_ptr
))
{
continue
;
}
if
(
refkeys
.
size
()
>
1
)
{
MS_LOG
(
EXCEPTION
)
<<
"CNode: "
<<
cnode
->
fullname_with_scope
()
<<
" 's inputs have more than 1 RefKeys."
;
}
...
...
mindspore/ops/operations/array_ops.py
浏览文件 @
2fecdede
...
...
@@ -1235,10 +1235,11 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
Examples:
>>> input_x =
[1, 2, 3, 4]
>>> segment_ids =
[0, 0, 1, 2]
>>> input_x =
Tensor([1, 2, 3, 4], mindspore.float)
>>> segment_ids =
Tensor([0, 0, 1, 2], mindspore.int32)
>>> num_segments = 4
>>> type = P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
>>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
[3, 3, 4, 0]
"""
@
prim_attr_register
...
...
mindspore/ops/operations/nn_ops.py
浏览文件 @
2fecdede
...
...
@@ -22,6 +22,8 @@ from functools import reduce
import
numpy
as
np
from
...
import
context
from
..._c_expression
import
signature_rw
as
sig_rw
from
..._c_expression
import
signature_kind
as
sig_kind
from
..._checkparam
import
ParamValidator
as
validator
from
..._checkparam
import
Rel
,
check_bool
,
check_int_positive
from
...common
import
dtype
as
mstype
...
...
@@ -1297,29 +1299,31 @@ class ApplyMomentum(PrimitiveWithInfer):
filter(lambda x: x.requires_grad, net.get_parameters()))
>>> model = Model(net, loss, opt)
"""
__mindspore_signature__
=
(
(
'variable'
,
sig_rw
.
RW_WRITE
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
),
(
'accumulation'
,
sig_rw
.
RW_WRITE
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
),
(
'learning_rate'
,
sig_rw
.
RW_READ
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
),
(
'gradient'
,
sig_rw
.
RW_READ
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
),
(
'momentum'
,
sig_rw
.
RW_READ
,
sig_kind
.
KIND_POSITIONAL_KEYWORD
)
)
@
prim_attr_register
def
__init__
(
self
,
use_nesterov
=
False
,
use_locking
=
False
,
gradient_scale
=
1.0
):
self
.
init_prim_io_names
(
inputs
=
[
'variable'
,
'accumulation'
,
'learning_rate'
,
'gradient'
,
'momentum'
],
outputs
=
[
'output'
])
def
infer_shape
(
self
,
v_shape
,
a_shape
,
l_shape
,
g_shape
,
m_shape
):
validator
.
check
(
f
'variable shape
{
v_shape
}
'
,
len
(
v_shape
),
''
,
0
,
Rel
.
GT
)
validator
.
check
(
f
'accumulation shape
{
a_shape
}
'
,
len
(
a_shape
),
''
,
0
,
Rel
.
GT
)
validator
.
check
(
f
'learning rate shape
{
l_shape
}
'
,
len
(
l_shape
),
''
,
0
,
Rel
.
GE
)
validator
.
check
(
f
'gradient shape
{
g_shape
}
'
,
len
(
g_shape
),
''
,
0
,
Rel
.
GE
)
validator
.
check
(
f
'momentum shape
{
m_shape
}
'
,
len
(
m_shape
),
''
,
0
,
Rel
.
GE
)
return
v_shape
def
infer_dtype
(
self
,
v_dtype
,
a_dtype
,
l_dtype
,
g_dtype
,
m_dtype
):
validator
.
check_subclass
(
"v_dtype"
,
v_dtype
,
mstype
.
tensor
)
validator
.
check_subclass
(
"a_dtype"
,
a_dtype
,
mstype
.
tensor
)
v_type
=
validator
.
check_typename
(
"v_dtype"
,
v_dtype
,
[
mstype
.
float16
,
mstype
.
float32
,
mstype
.
float64
])
validator
.
check_typename
(
"a_dtype"
,
a_dtype
,
[
mstype
.
float16
,
mstype
.
float32
,
mstype
.
float64
])
if
v_dtype
!=
mstype
.
type_refkey
and
a_dtype
!=
mstype
.
type_refkey
:
validator
.
check_subclass
(
"v_dtype"
,
v_dtype
,
mstype
.
tensor
)
validator
.
check_subclass
(
"a_dtype"
,
a_dtype
,
mstype
.
tensor
)
validator
.
check_typename
(
"v_dtype"
,
v_dtype
,
[
mstype
.
float16
,
mstype
.
float32
,
mstype
.
float64
])
validator
.
check_typename
(
"a_dtype"
,
a_dtype
,
[
mstype
.
float16
,
mstype
.
float32
,
mstype
.
float64
])
validator
.
check_typename
(
"l_dtype"
,
l_dtype
,
[
mstype
.
float16
,
mstype
.
float32
,
mstype
.
float64
])
validator
.
check_typename
(
"g_dtype"
,
g_dtype
,
[
mstype
.
float16
,
mstype
.
float32
,
mstype
.
float64
])
validator
.
check_typename
(
"m_dtype"
,
m_dtype
,
[
mstype
.
float16
,
mstype
.
float32
,
mstype
.
float64
])
return
v_
type
return
g_d
type
class
SmoothL1Loss
(
PrimitiveWithInfer
):
...
...
mindspore/train/amp.py
浏览文件 @
2fecdede
...
...
@@ -82,6 +82,29 @@ def _check_kwargs(key_words):
if
loss_scale_manager
:
validator
.
check_isinstance
(
'loss_scale_manager'
,
loss_scale_manager
,
LossScaleManager
)
def
_add_loss_network
(
network
,
loss_fn
,
cast_model_type
):
class
WithLossCell
(
nn
.
Cell
):
"Wrap loss for amp. Cast network output back to float32"
def
__init__
(
self
,
backbone
,
loss_fn
):
super
(
WithLossCell
,
self
).
__init__
(
auto_prefix
=
False
)
self
.
_backbone
=
backbone
self
.
_loss_fn
=
loss_fn
def
construct
(
self
,
data
,
label
):
out
=
self
.
_backbone
(
data
)
label
=
_mp_cast_helper
(
mstype
.
float32
,
label
)
return
self
.
_loss_fn
(
F
.
cast
(
out
,
mstype
.
float32
),
label
)
validator
.
check_isinstance
(
'loss_fn'
,
loss_fn
,
nn
.
Cell
)
if
cast_model_type
==
mstype
.
float16
:
network
=
WithLossCell
(
network
,
loss_fn
)
else
:
network
=
nn
.
WithLossCell
(
network
,
loss_fn
)
return
network
def
build_train_network
(
network
,
optimizer
,
loss_fn
=
None
,
level
=
'O0'
,
**
kwargs
):
"""
Build the mixed precision training cell automatically.
...
...
@@ -117,24 +140,7 @@ def build_train_network(network, optimizer, loss_fn=None, level='O0', **kwargs):
_do_keep_batchnorm_fp32
(
network
)
if
loss_fn
:
class
WithLossCell
(
nn
.
Cell
):
"Wrap loss for amp. Cast network output back to float32"
def
__init__
(
self
,
backbone
,
loss_fn
):
super
(
WithLossCell
,
self
).
__init__
(
auto_prefix
=
False
)
self
.
_backbone
=
backbone
self
.
_loss_fn
=
loss_fn
def
construct
(
self
,
data
,
label
):
out
=
self
.
_backbone
(
data
)
label
=
_mp_cast_helper
(
mstype
.
float32
,
label
)
return
self
.
_loss_fn
(
F
.
cast
(
out
,
mstype
.
float32
),
label
)
validator
.
check_isinstance
(
'loss_fn'
,
loss_fn
,
nn
.
Cell
)
if
config
.
cast_model_type
==
mstype
.
float16
:
network
=
WithLossCell
(
network
,
loss_fn
)
else
:
network
=
nn
.
WithLossCell
(
network
,
loss_fn
)
network
=
_add_loss_network
(
network
,
loss_fn
,
config
.
cast_model_type
)
if
_get_parallel_mode
()
in
(
ParallelMode
.
SEMI_AUTO_PARALLEL
,
ParallelMode
.
AUTO_PARALLEL
):
network
=
_VirtualDatasetCell
(
network
)
...
...
mindspore/train/model.py
浏览文件 @
2fecdede
...
...
@@ -24,8 +24,7 @@ from .. import context
from
..parallel._utils
import
_get_parallel_mode
,
_get_device_num
,
_get_global_rank
,
\
_get_parameter_broadcast
,
_device_number_check
,
_parameter_broadcast_check
,
_callback_wrapper
from
..nn.metrics
import
Loss
from
..nn.wrap
import
WithLossCell
,
WithEvalCell
,
\
DataWrapper
from
..nn.wrap
import
WithLossCell
,
DataWrapper
,
WithEvalCell
from
..nn.wrap.cell_wrapper
import
_VirtualDatasetCell
from
.parallel_utils
import
ParallelMode
from
..common
import
dtype
as
mstype
...
...
@@ -151,7 +150,10 @@ class Model:
else
:
if
self
.
_loss_fn
is
None
:
raise
ValueError
(
"loss_fn can not be None."
)
self
.
_eval_network
=
WithEvalCell
(
self
.
_network
,
self
.
_loss_fn
)
if
self
.
_optimizer
:
self
.
_eval_network
=
self
.
_train_network
.
network
else
:
self
.
_eval_network
=
WithEvalCell
(
self
.
_network
,
self
.
_loss_fn
)
self
.
_eval_indexes
=
[
0
,
1
,
2
]
def
_clear_metrics
(
self
):
...
...
tests/train_step_wrap.py
浏览文件 @
2fecdede
...
...
@@ -21,47 +21,6 @@ from mindspore.ops import composite as C
from
mindspore.ops
import
operations
as
P
from
mindspore
import
Parameter
,
ParameterTuple
run_opt
=
C
.
MultitypeFuncGraph
(
"run_opt"
)
# pylint: disable=unused-argument
@
run_opt
.
register
(
"Function"
,
"Int"
,
"Number"
,
"Number"
,
"Tensor"
,
"Tensor"
,
"Tensor"
)
def
tensor_run_opt
(
opt
,
iterator
,
learning_rate
,
momentum
,
gradient
,
variable
,
moment
):
success
=
True
new_weight
=
opt
(
gradient
,
moment
,
variable
,
learning_rate
,
momentum
)
success
=
F
.
depend
(
success
,
P
.
Assign
()(
variable
,
new_weight
))
return
success
class
OptimizerByMomentum
(
nn
.
Cell
):
"""
OptimizerByMomentum definition
"""
# list of tensor
def
__init__
(
self
,
weights
):
super
(
OptimizerByMomentum
,
self
).
__init__
()
self
.
learning_rate
=
Parameter
(
0.1
,
name
=
"learning_rate"
)
self
.
momentum
=
Parameter
(
0.05
,
name
=
"momentum"
)
self
.
iter
=
Parameter
(
0
,
name
=
"iter"
)
self
.
weights
=
weights
self
.
moments
=
weights
.
clone
(
prefix
=
"moments"
,
init
=
'zeros'
)
self
.
hyper_map
=
C
.
HyperMap
()
self
.
opt
=
P
.
ApplyMomentum
()
def
construct
(
self
,
grads
):
success
=
True
weights
=
self
.
weights
moments
=
self
.
moments
success
=
self
.
hyper_map
(
F
.
partial
(
run_opt
,
self
.
opt
,
self
.
iter
,
self
.
learning_rate
,
self
.
momentum
),
grads
,
weights
,
moments
)
# self.learning_rate = updata_lr(self.learning_rate, self.momentum)
return
success
class
TrainStepWrap
(
nn
.
Cell
):
"""
TrainStepWrap definition
...
...
@@ -71,7 +30,7 @@ class TrainStepWrap(nn.Cell):
self
.
network
=
network
self
.
network
.
set_train
()
self
.
weights
=
ParameterTuple
(
network
.
trainable_params
())
self
.
optimizer
=
OptimizerByMomentum
(
self
.
weights
)
self
.
optimizer
=
nn
.
Momentum
(
self
.
weights
,
0.1
,
0.9
)
self
.
hyper_map
=
C
.
HyperMap
()
self
.
grad
=
C
.
GradOperation
(
'grad'
,
get_by_list
=
True
)
...
...
@@ -107,7 +66,7 @@ class TrainStepWrap2(nn.Cell):
self
.
network
=
network
self
.
network
.
set_train
()
self
.
weights
=
ParameterTuple
(
network
.
get_parameters
())
self
.
optimizer
=
OptimizerByMomentum
(
self
.
weights
)
self
.
optimizer
=
nn
.
Momentum
(
self
.
weights
,
0.1
,
0.9
)
self
.
hyper_map
=
C
.
HyperMap
()
self
.
grad
=
C
.
GradOperation
(
'grad'
,
get_by_list
=
True
,
sens_param
=
True
)
self
.
sens
=
sens
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录