Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
113c0d8c
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
113c0d8c
编写于
4月 05, 2020
作者:
W
Wei Luning
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix InsertGradientOf with class method
上级
7a367af9
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
57 addition
and
32 deletion
+57
-32
mindspore/ccsrc/pipeline/parse/resolve.cc
mindspore/ccsrc/pipeline/parse/resolve.cc
+8
-0
mindspore/ccsrc/pipeline/pass.cc
mindspore/ccsrc/pipeline/pass.cc
+8
-0
mindspore/ccsrc/pipeline/static_analysis/prim.cc
mindspore/ccsrc/pipeline/static_analysis/prim.cc
+6
-31
tests/ut/python/pynative_mode/test_insert_grad_of.py
tests/ut/python/pynative_mode/test_insert_grad_of.py
+35
-1
未找到文件。
mindspore/ccsrc/pipeline/parse/resolve.cc
浏览文件 @
113c0d8c
...
@@ -103,6 +103,14 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr& func_graph, const py::object&
...
@@ -103,6 +103,14 @@ AnfNodePtr ResolveParameterObj(const FuncGraphPtr& func_graph, const py::object&
if
(
para_node
==
nullptr
)
{
if
(
para_node
==
nullptr
)
{
ParameterPtr
node
=
top_graph
->
AddWeightParameter
(
param_name
);
ParameterPtr
node
=
top_graph
->
AddWeightParameter
(
param_name
);
node
->
set_default_param
(
obj
);
node
->
set_default_param
(
obj
);
// set_abstract for parameter
auto
to_convert
=
py
::
cast
<
py
::
object
>
(
python_adapter
::
GetPyObjAttr
(
obj
,
"default_input"
));
ValuePtr
converted
=
nullptr
;
(
void
)
ConvertData
(
to_convert
,
&
converted
);
bool
broaden
=
true
;
node
->
set_abstract
(
abstract
::
FromValue
(
converted
,
broaden
));
para_node
=
node
;
para_node
=
node
;
}
}
auto
iter
=
func_graph
->
make_ref_params
().
find
(
para_node
);
auto
iter
=
func_graph
->
make_ref_params
().
find
(
para_node
);
...
...
mindspore/ccsrc/pipeline/pass.cc
浏览文件 @
113c0d8c
...
@@ -112,6 +112,13 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) {
...
@@ -112,6 +112,13 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) {
});
});
opt
::
OptPassConfig
virtual_dataset
=
opt
::
OptPassConfig
({
irpass
.
virtual_dataset_eliminate_
});
opt
::
OptPassConfig
virtual_dataset
=
opt
::
OptPassConfig
({
irpass
.
virtual_dataset_eliminate_
});
opt
::
OptPassConfig
grad
=
opt
::
OptPassConfig
({
irpass
.
expand_jprim_
},
true
);
opt
::
OptPassConfig
grad
=
opt
::
OptPassConfig
({
irpass
.
expand_jprim_
},
true
);
opt
::
irpass
::
ResolveIRPassLib
resolve_irpass
;
opt
::
OptPassConfig
resolve_pass
=
opt
::
OptPassConfig
({
resolve_irpass
.
resolver_resolve_
,
resolve_irpass
.
resolver_getattr_
,
irpass
.
get_make_ref_eliminate_
,
});
OptPassGroupMap
map_a
({{
"a_1"
,
a_1
},
OptPassGroupMap
map_a
({{
"a_1"
,
a_1
},
{
"a_2"
,
a_2
},
{
"a_2"
,
a_2
},
...
@@ -120,6 +127,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) {
...
@@ -120,6 +127,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib& irpass) {
{
"allreduce_fusion"
,
opt
::
OptPassConfig
(
parallel
::
StepAllreduceFusion
)},
{
"allreduce_fusion"
,
opt
::
OptPassConfig
(
parallel
::
StepAllreduceFusion
)},
{
"virtual_dataset"
,
virtual_dataset
},
{
"virtual_dataset"
,
virtual_dataset
},
{
"grad"
,
grad
},
{
"grad"
,
grad
},
{
"resolve"
,
resolve_pass
},
{
"renormalize"
,
opt
::
OptPassConfig
::
Renormalize
()},
{
"renormalize"
,
opt
::
OptPassConfig
::
Renormalize
()},
{
"cse"
,
opt
::
OptPassConfig
(
opt
::
CSE
(
false
))},
{
"cse"
,
opt
::
OptPassConfig
(
opt
::
CSE
(
false
))},
{
"a_3"
,
a_3
}});
{
"a_3"
,
a_3
}});
...
...
mindspore/ccsrc/pipeline/static_analysis/prim.cc
浏览文件 @
113c0d8c
...
@@ -554,24 +554,6 @@ AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &dat
...
@@ -554,24 +554,6 @@ AbstractBasePtr StaticGetterInferred(const ValuePtr &value, const ConfigPtr &dat
return
eng
->
ForwardConfig
(
old_conf
,
fn_conf
);
return
eng
->
ForwardConfig
(
old_conf
,
fn_conf
);
}
}
AbstractBasePtr
GenerateResolveAbstract
(
const
AnfNodeConfigPtr
&
out_conf
,
const
py
::
object
&
obj
,
const
ValuePtr
&
converted_ret
)
{
if
(
py
::
hasattr
(
obj
,
PYTHON_DATACLASS_FIELDS
))
{
TypePtr
cls_ptr
=
parse
::
ParseDataClass
(
converted_ret
->
cast
<
std
::
shared_ptr
<
parse
::
PyObjectWrapper
>>
()
->
obj
());
std
::
vector
<
AnfNodePtr
>
input
=
{
NewValueNode
(
prim
::
kPrimPartial
),
NewValueNode
(
prim
::
kPrimMakeRecord
),
NewValueNode
(
cls_ptr
)};
MS_EXCEPTION_IF_NULL
(
out_conf
);
FuncGraphPtr
func_graph
=
out_conf
->
node
()
->
func_graph
();
CNodePtr
new_cnode
=
func_graph
->
NewCNode
(
input
);
AnalysisEnginePtr
eng
=
out_conf
->
engine
();
AnfNodeConfigPtr
fn_conf
=
eng
->
MakeConfig
(
new_cnode
,
out_conf
->
context
());
return
eng
->
ForwardConfig
(
out_conf
,
fn_conf
);
}
else
{
return
ToAbstract
(
converted_ret
,
AnalysisContext
::
DummyContext
(),
out_conf
);
}
}
AbstractBasePtr
GetEvaluatedValueForNameSpaceString
(
const
AnalysisEnginePtr
&
engine
,
AbstractBasePtr
GetEvaluatedValueForNameSpaceString
(
const
AnalysisEnginePtr
&
engine
,
const
AbstractBasePtrList
&
args_spec_list
,
const
AbstractBasePtrList
&
args_spec_list
,
const
AnfNodeConfigPtr
&
out_conf
)
{
const
AnfNodeConfigPtr
&
out_conf
)
{
...
@@ -602,23 +584,16 @@ AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &eng
...
@@ -602,23 +584,16 @@ AbstractBasePtr GetEvaluatedValueForNameSpaceString(const AnalysisEnginePtr &eng
// item_name to func addr from obj_map
// item_name to func addr from obj_map
parse
::
SymbolPtr
symbol
=
item_v
->
cast
<
parse
::
SymbolPtr
>
();
parse
::
SymbolPtr
symbol
=
item_v
->
cast
<
parse
::
SymbolPtr
>
();
parse
::
NameSpacePtr
name_space
=
data_v
->
cast
<
parse
::
NameSpacePtr
>
();
parse
::
NameSpacePtr
name_space
=
data_v
->
cast
<
parse
::
NameSpacePtr
>
();
FuncGraphPtr
func_graph
=
out_conf
->
node
()
->
func_graph
();
parse
::
SymbolResolverPtr
symbol_resolver
=
auto
new_node
=
parse
::
ResolveSymbol
(
func_graph
->
manager
(),
name_space
,
symbol
,
out_conf
->
node
());
std
::
make_shared
<
parse
::
SymbolResolver
>
(
name_space
,
symbol
,
out_conf
->
node
());
if
(
new_node
==
nullptr
)
{
if
(
!
symbol_resolver
->
Resolve
())
{
MS_LOG
(
EXCEPTION
)
<<
"Resolve node failed"
;
MS_LOG
(
EXCEPTION
)
<<
"Resolve node failed"
;
}
}
py
::
object
obj
=
symbol_resolver
->
result
();
AnalysisEnginePtr
eng
=
out_conf
->
engine
();
ValuePtr
converted_ret
=
nullptr
;
AnfNodeConfigPtr
fn_conf
=
eng
->
MakeConfig
(
new_node
,
out_conf
->
context
());
bool
converted
=
parse
::
ConvertData
(
obj
,
&
converted_ret
,
true
);
return
eng
->
ForwardConfig
(
out_conf
,
fn_conf
);
if
(
!
converted
)
{
MS_LOG
(
EXCEPTION
)
<<
"Convert data failed"
;
}
if
(
converted_ret
->
isa
<
FuncGraph
>
())
{
AddToManager
(
engine
,
converted_ret
->
cast
<
FuncGraphPtr
>
());
}
return
GenerateResolveAbstract
(
out_conf
,
obj
,
converted_ret
);
}
}
AbstractBasePtr
GetEvaluatedValueForClassAttrOrMethod
(
const
AnalysisEnginePtr
&
engine
,
AbstractBasePtr
GetEvaluatedValueForClassAttrOrMethod
(
const
AnalysisEnginePtr
&
engine
,
...
...
tests/ut/python/pynative_mode/test_insert_grad_of.py
浏览文件 @
113c0d8c
...
@@ -17,13 +17,14 @@ import numpy as np
...
@@ -17,13 +17,14 @@ import numpy as np
import
mindspore.nn
as
nn
import
mindspore.nn
as
nn
from
mindspore.ops
import
composite
as
C
from
mindspore.ops
import
composite
as
C
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
functional
as
F
from
mindspore.common.api
import
ms_function
from
mindspore.common.api
import
ms_function
from
....mindspore_test_framework.utils.bprop_util
import
bprop
from
....mindspore_test_framework.utils.bprop_util
import
bprop
from
....mindspore_test_framework.utils.debug_util
import
PrintShapeTypeCell
,
PrintGradShapeTypeCell
from
....mindspore_test_framework.utils.debug_util
import
PrintShapeTypeCell
,
PrintGradShapeTypeCell
from
mindspore
import
Tensor
from
mindspore
import
Tensor
from
mindspore
import
context
from
mindspore
import
context
import
mindspore
def
setup_module
(
module
):
def
setup_module
(
module
):
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
)
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
)
...
@@ -107,3 +108,36 @@ def test_print_shape_type():
...
@@ -107,3 +108,36 @@ def test_print_shape_type():
return
z
return
z
bprop
(
Mul
(),
Tensor
(
np
.
ones
([
2
,
2
]).
astype
(
np
.
float32
)),
bprop
(
Mul
(),
Tensor
(
np
.
ones
([
2
,
2
]).
astype
(
np
.
float32
)),
Tensor
(
np
.
ones
([
2
,
2
]).
astype
(
np
.
float32
)))
Tensor
(
np
.
ones
([
2
,
2
]).
astype
(
np
.
float32
)))
def
test_cell_assign
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
class
GradNetWrap
(
nn
.
Cell
):
""" GradNetWrap definition """
def
__init__
(
self
,
net
):
super
(
GradNetWrap
,
self
).
__init__
()
self
.
net
=
net
self
.
weights
=
mindspore
.
ParameterTuple
(
net
.
get_parameters
())
def
construct
(
self
,
x
,
y
):
return
C
.
grad_by_list
(
self
.
net
,
self
.
weights
)(
x
,
y
)
class
Mul
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Mul
,
self
).
__init__
()
self
.
get_g
=
P
.
InsertGradientOf
(
self
.
save_gradient
)
self
.
matrix_w
=
mindspore
.
Parameter
(
Tensor
(
np
.
ones
([
2
,
2
],
np
.
float32
)),
name
=
"matrix_w"
)
self
.
matrix_g
=
mindspore
.
Parameter
(
Tensor
(
np
.
ones
([
2
,
2
],
np
.
float32
)),
name
=
"matrix_g"
)
def
save_gradient
(
self
,
dout
):
self
.
matrix_g
=
dout
return
dout
def
construct
(
self
,
x
,
y
):
z
=
x
*
self
.
matrix_w
z
=
self
.
get_g
(
z
)
z
=
z
*
y
return
z
input_x
=
Tensor
(
np
.
ones
([
2
,
2
],
np
.
float32
))
input_y
=
Tensor
(
np
.
ones
([
2
,
2
],
np
.
float32
))
GradNetWrap
(
Mul
())(
input_x
,
input_y
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录