Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
37c94a5f
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
37c94a5f
编写于
4月 13, 2020
作者:
W
Wei Luning
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add pass replace_old_param_
上级
a44d7347
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
68 addition
and
6 deletion
+68
-6
mindspore/ccsrc/optimizer/irpass.cc
mindspore/ccsrc/optimizer/irpass.cc
+2
-0
mindspore/ccsrc/optimizer/irpass.h
mindspore/ccsrc/optimizer/irpass.h
+1
-0
mindspore/ccsrc/optimizer/irpass/param_replace.h
mindspore/ccsrc/optimizer/irpass/param_replace.h
+60
-0
mindspore/ccsrc/pipeline/action.cc
mindspore/ccsrc/pipeline/action.cc
+1
-0
mindspore/ccsrc/pipeline/pass.cc
mindspore/ccsrc/pipeline/pass.cc
+3
-5
tests/ut/python/pynative_mode/test_insert_grad_of.py
tests/ut/python/pynative_mode/test_insert_grad_of.py
+1
-1
未找到文件。
mindspore/ccsrc/optimizer/irpass.cc
浏览文件 @
37c94a5f
...
...
@@ -40,6 +40,7 @@
#include "optimizer/irpass/incorporate_getitem.h"
#include "optimizer/irpass/incorporate_call.h"
#include "optimizer/irpass/grad_var_prepare.h"
#include "optimizer/irpass/param_replace.h"
namespace
mindspore
{
namespace
opt
{
...
...
@@ -81,6 +82,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
get_make_ref_eliminate_
=
MakeSubstitution
(
GetMakeRefEliminater
(),
"get_make_ref_eliminate"
,
{
prim
::
kPrimGetRefKey
,
prim
::
kPrimGetRefValue
});
replace_refkey_by_param_
=
MakeSubstitution
(
ReplaceRefkeyByParam
(),
"replace_refkey_by_param"
,
IsValueNode
<
RefKey
>
);
replace_old_param_
=
MakeSubstitution
(
ReplaceOldParam
(),
"replace_old_param"
,
IsParam
);
// Gradient transforms
expand_jprim_
=
MakeSubstitution
(
ExpandJPrim
(),
"expand_jprim"
,
prim
::
kPrimJ
);
...
...
mindspore/ccsrc/optimizer/irpass.h
浏览文件 @
37c94a5f
...
...
@@ -58,6 +58,7 @@ class OptimizeIRPassLib {
SubstitutionPtr
make_ref_eliminate_
;
SubstitutionPtr
get_make_ref_eliminate_
;
SubstitutionPtr
replace_refkey_by_param_
;
SubstitutionPtr
replace_old_param_
;
// Branch culling
SubstitutionPtr
switch_simplify_
;
...
...
mindspore/ccsrc/optimizer/irpass/param_replace.h
0 → 100644
浏览文件 @
37c94a5f
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_
#define MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_
#include <memory>
#include "optimizer/optimizer.h"
#include "optimizer/irpass.h"
#include "ir/visitor.h"
#include "operator/ops.h"
#include "pipeline/parse/parse.h"
namespace
mindspore
{
namespace
opt
{
namespace
irpass
{
class
ReplaceOldParam
:
public
AnfVisitor
{
public:
AnfNodePtr
operator
()(
const
OptimizerPtr
&
optimizer
,
const
AnfNodePtr
&
node
)
override
{
if
(
!
IsParam
(
node
))
{
return
nullptr
;
}
auto
resource
=
std
::
dynamic_pointer_cast
<
pipeline
::
Resource
>
(
optimizer
->
resource
());
MS_EXCEPTION_IF_NULL
(
resource
);
auto
top_graph
=
resource
->
func_graph
();
// parse::Parser::GetTopFuncGraph();
MS_EXCEPTION_IF_NULL
(
top_graph
);
auto
param_node
=
node
->
cast
<
ParameterPtr
>
();
if
(
!
param_node
->
has_default
()
||
node
->
func_graph
()
==
top_graph
)
{
return
nullptr
;
}
auto
para_name
=
param_node
->
name
();
for
(
const
auto
&
tnode
:
top_graph
->
parameters
())
{
auto
para
=
tnode
->
cast
<
ParameterPtr
>
();
if
(
para
!=
nullptr
&&
para
->
name
()
==
para_name
)
{
return
para
;
}
}
return
nullptr
;
}
};
}
// namespace irpass
}
// namespace opt
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_OPTIMIZER_IRPASS_PARAM_REPLACE_H_
mindspore/ccsrc/pipeline/action.cc
浏览文件 @
37c94a5f
...
...
@@ -88,6 +88,7 @@ FuncGraphPtr Renormalize(const ResourcePtr& res, const FuncGraphPtr& func_graph,
double
t2
=
GetTime
();
#endif
auto
ret
=
ProgramSpecialize
(
res
,
func_graph
,
result
.
context
);
res
->
set_func_graph
(
ret
);
#ifdef ENABLE_PROFILE
double
t3
=
GetTime
();
MsProfile
::
StatTime
(
"renormalize.infer"
,
t2
-
t1
);
...
...
mindspore/ccsrc/pipeline/pass.cc
浏览文件 @
37c94a5f
...
...
@@ -114,11 +114,9 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) {
opt
::
OptPassConfig
grad
=
opt
::
OptPassConfig
({
irpass
.
expand_jprim_
},
true
);
opt
::
irpass
::
ResolveIRPassLib
resolve_irpass
;
opt
::
OptPassConfig
resolve_pass
=
opt
::
OptPassConfig
({
resolve_irpass
.
resolver_resolve_
,
resolve_irpass
.
resolver_getattr_
,
irpass
.
get_make_ref_eliminate_
,
});
opt
::
OptPassConfig
resolve_pass
=
opt
::
OptPassConfig
({
resolve_irpass
.
resolver_resolve_
,
resolve_irpass
.
resolver_getattr_
,
irpass
.
get_make_ref_eliminate_
,
irpass
.
replace_old_param_
});
OptPassGroupMap
map_a
({{
"a_1"
,
a_1
},
{
"a_2"
,
a_2
},
...
...
tests/ut/python/pynative_mode/test_insert_grad_of.py
浏览文件 @
37c94a5f
...
...
@@ -129,7 +129,7 @@ def test_cell_assign():
self
.
matrix_g
=
mindspore
.
Parameter
(
Tensor
(
np
.
ones
([
2
,
2
],
np
.
float32
)),
name
=
"matrix_g"
)
def
save_gradient
(
self
,
dout
):
self
.
matrix_g
=
dout
self
.
matrix_g
=
dout
+
self
.
matrix_g
return
dout
def
construct
(
self
,
x
,
y
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录