Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
兔爷不爱我
mindspore
提交
864622bd
M
mindspore
项目概览
兔爷不爱我
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
864622bd
编写于
4月 23, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 23, 2020
浏览文件
操作
浏览文件
下载
差异文件
!574 Add parameter configuration
Merge pull request !574 from liubuyu/master
上级
a468dc09
672244e0
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
14 addition
and
8 deletion
+14
-8
mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc
...re/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc
+3
-3
mindspore/nn/optim/optimizer.py
mindspore/nn/optim/optimizer.py
+1
-1
mindspore/train/model.py
mindspore/train/model.py
+9
-3
tests/ut/cpp/python_input/gtest_input/pre_activate/mul_addn_fusion_test.py
...on_input/gtest_input/pre_activate/mul_addn_fusion_test.py
+1
-1
未找到文件。
mindspore/ccsrc/pre_activate/ascend/ir_fusion/mul_addn_fusion.cc
浏览文件 @
864622bd
...
@@ -34,7 +34,7 @@ CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &mul, const
...
@@ -34,7 +34,7 @@ CNodePtr CreateFusionNode(const FuncGraphPtr &graph, const CNodePtr &mul, const
auto
prim
=
std
::
make_shared
<
Primitive
>
(
kFusedMulAddNOpName
);
auto
prim
=
std
::
make_shared
<
Primitive
>
(
kFusedMulAddNOpName
);
std
::
vector
<
AnfNodePtr
>
inputs
=
{
NewValueNode
(
prim
)};
std
::
vector
<
AnfNodePtr
>
inputs
=
{
NewValueNode
(
prim
)};
inputs
.
push_back
(
mul
->
input
(
kMulInputNum
-
lossscale_input_index
));
inputs
.
push_back
(
mul
->
input
(
kMulInputNum
-
lossscale_input_index
));
inputs
.
push_back
(
addn
->
input
(
1
));
inputs
.
push_back
(
addn
->
input
(
2
));
// scalar input should be 3rd input
// scalar input should be 3rd input
inputs
.
push_back
(
mul
->
input
(
lossscale_input_index
));
inputs
.
push_back
(
mul
->
input
(
lossscale_input_index
));
auto
fusion_node
=
graph
->
NewCNode
(
inputs
);
auto
fusion_node
=
graph
->
NewCNode
(
inputs
);
...
@@ -51,7 +51,7 @@ const BaseRef MulAddNFusion::DefinePattern() const {
...
@@ -51,7 +51,7 @@ const BaseRef MulAddNFusion::DefinePattern() const {
VarPtr
Z
=
std
::
make_shared
<
Var
>
();
VarPtr
Z
=
std
::
make_shared
<
Var
>
();
VectorRef
mul
({
prim
::
kPrimMul
,
X
,
Z
});
VectorRef
mul
({
prim
::
kPrimMul
,
X
,
Z
});
VectorRef
addn
({
prim
::
kPrimAddN
,
Y
,
mul
});
VectorRef
addn
({
prim
::
kPrimAddN
,
mul
,
Y
});
return
addn
;
return
addn
;
}
}
...
@@ -65,7 +65,7 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode
...
@@ -65,7 +65,7 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode
if
(
addn
==
nullptr
||
addn
->
inputs
().
size
()
!=
kAddNInputNum
)
{
if
(
addn
==
nullptr
||
addn
->
inputs
().
size
()
!=
kAddNInputNum
)
{
return
nullptr
;
return
nullptr
;
}
}
auto
mul_anf
=
addn
->
input
(
2
);
auto
mul_anf
=
addn
->
input
(
1
);
if
(
mul_anf
==
nullptr
)
{
if
(
mul_anf
==
nullptr
)
{
return
nullptr
;
return
nullptr
;
}
}
...
...
mindspore/nn/optim/optimizer.py
浏览文件 @
864622bd
...
@@ -177,7 +177,7 @@ apply_decay = C.MultitypeFuncGraph("apply_decay")
...
@@ -177,7 +177,7 @@ apply_decay = C.MultitypeFuncGraph("apply_decay")
def
_tensor_apply_decay
(
weight_decay
,
if_apply
,
weight
,
gradient
):
def
_tensor_apply_decay
(
weight_decay
,
if_apply
,
weight
,
gradient
):
"""Get grad with weight_decay."""
"""Get grad with weight_decay."""
if
if_apply
:
if
if_apply
:
return
op_add
((
gradient
,
weight
*
weight_decay
))
return
op_add
((
weight
*
weight_decay
,
gradient
))
return
gradient
return
gradient
...
...
mindspore/train/model.py
浏览文件 @
864622bd
...
@@ -62,6 +62,7 @@ class Model:
...
@@ -62,6 +62,7 @@ class Model:
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
loss_scale_manager (Union[None, LossScaleManager]): If None, not scale the loss, or else
scale the loss by LossScaleManager. If it is set, overwrite the level setting. It's a eyword argument.
scale the loss by LossScaleManager. If it is set, overwrite the level setting. It's a eyword argument.
e.g. Use `loss_scale_manager=None` to set the value.
e.g. Use `loss_scale_manager=None` to set the value.
keep_batchnorm_fp32 (bool): Keep Batchnorm run in `float32`. If set, overwrite the level setting. Default: True.
Examples:
Examples:
>>> class Net(nn.Cell):
>>> class Net(nn.Cell):
...
@@ -96,7 +97,10 @@ class Model:
...
@@ -96,7 +97,10 @@ class Model:
self
.
_optimizer
=
optimizer
self
.
_optimizer
=
optimizer
self
.
_loss_scale_manager
=
None
self
.
_loss_scale_manager
=
None
self
.
_loss_scale_manager_set
=
False
self
.
_loss_scale_manager_set
=
False
self
.
_keep_bn_fp32
=
True
self
.
_check_kwargs
(
kwargs
)
self
.
_check_kwargs
(
kwargs
)
if
'keep_batchnorm_fp32'
in
kwargs
:
self
.
_keep_bn_fp32
=
kwargs
[
'keep_batchnorm_fp32'
]
if
'loss_scale_manager'
in
kwargs
:
if
'loss_scale_manager'
in
kwargs
:
self
.
_loss_scale_manager
=
kwargs
[
'loss_scale_manager'
]
self
.
_loss_scale_manager
=
kwargs
[
'loss_scale_manager'
]
self
.
_loss_scale_manager_set
=
True
self
.
_loss_scale_manager_set
=
True
...
@@ -112,7 +116,7 @@ class Model:
...
@@ -112,7 +116,7 @@ class Model:
def
_check_kwargs
(
self
,
kwargs
):
def
_check_kwargs
(
self
,
kwargs
):
for
arg
in
kwargs
:
for
arg
in
kwargs
:
if
arg
not
in
[
'loss_scale_manager'
]:
if
arg
not
in
[
'loss_scale_manager'
,
'keep_batchnorm_fp32'
]:
raise
ValueError
(
f
"Unsupport arg '
{
arg
}
'"
)
raise
ValueError
(
f
"Unsupport arg '
{
arg
}
'"
)
def
_build_train_network
(
self
):
def
_build_train_network
(
self
):
...
@@ -124,12 +128,14 @@ class Model:
...
@@ -124,12 +128,14 @@ class Model:
self
.
_optimizer
,
self
.
_optimizer
,
self
.
_loss_fn
,
self
.
_loss_fn
,
level
=
self
.
_amp_level
,
level
=
self
.
_amp_level
,
loss_scale_manager
=
self
.
_loss_scale_manager
)
loss_scale_manager
=
self
.
_loss_scale_manager
,
keep_batchnorm_fp32
=
self
.
_keep_bn_fp32
)
else
:
else
:
network
=
amp
.
build_train_network
(
network
,
network
=
amp
.
build_train_network
(
network
,
self
.
_optimizer
,
self
.
_optimizer
,
self
.
_loss_fn
,
self
.
_loss_fn
,
level
=
self
.
_amp_level
)
level
=
self
.
_amp_level
,
keep_batchnorm_fp32
=
self
.
_keep_bn_fp32
)
elif
self
.
_loss_fn
:
elif
self
.
_loss_fn
:
network
=
nn
.
WithLossCell
(
network
,
self
.
_loss_fn
)
network
=
nn
.
WithLossCell
(
network
,
self
.
_loss_fn
)
# If need to check if loss_fn is not None, but optimizer is None
# If need to check if loss_fn is not None, but optimizer is None
...
...
tests/ut/cpp/python_input/gtest_input/pre_activate/mul_addn_fusion_test.py
浏览文件 @
864622bd
...
@@ -42,7 +42,7 @@ def test_mul_addn_fusion(tag):
...
@@ -42,7 +42,7 @@ def test_mul_addn_fusion(tag):
@
fns
@
fns
def
before
(
a
,
b
):
def
before
(
a
,
b
):
res
=
mul
(
scalar
,
a
)
res
=
mul
(
scalar
,
a
)
res
=
addn
((
b
,
res
))
res
=
addn
((
res
,
b
))
return
res
return
res
@
fns
@
fns
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录