Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
9cb665e6
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9cb665e6
编写于
4月 17, 2020
作者:
H
huangdongrun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add suport for parameter of const value pass as mixed precision args
fix pylint
上级
549bfb97
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
51 addition
and
4 deletion
+51
-4
mindspore/ccsrc/pipeline/parse/parse.cc
mindspore/ccsrc/pipeline/parse/parse.cc
+1
-3
mindspore/ops/composite/base.py
mindspore/ops/composite/base.py
+9
-0
tests/ut/python/parameter_feature/test_parameter.py
tests/ut/python/parameter_feature/test_parameter.py
+41
-1
未找到文件。
mindspore/ccsrc/pipeline/parse/parse.cc
浏览文件 @
9cb665e6
...
@@ -68,9 +68,7 @@ AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNo
...
@@ -68,9 +68,7 @@ AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNo
return
param
;
return
param
;
}
}
auto
cast_helper
=
prim
::
GetPythonOps
(
"_mp_cast_helper"
,
"mindspore.ops.composite.base"
);
auto
cast_helper
=
prim
::
GetPythonOps
(
"_mp_cast_helper"
,
"mindspore.ops.composite.base"
);
auto
partial
=
auto
cast
=
func_graph
->
NewCNode
({
NewValueNode
(
cast_helper
),
NewValueNode
(
dst_type
),
param
});
func_graph
->
NewCNode
({
NewValueNode
(
prim
::
kPrimPartial
),
NewValueNode
(
cast_helper
),
NewValueNode
(
dst_type
)});
auto
cast
=
func_graph
->
NewCNode
({
NewValueNode
(
prim
::
kCompositeHyperMap
),
partial
,
param
});
return
cast
;
return
cast
;
}
}
...
...
mindspore/ops/composite/base.py
浏览文件 @
9cb665e6
...
@@ -307,3 +307,12 @@ def _mixed_precision_cast_helper_2(type_, x):
...
@@ -307,3 +307,12 @@ def _mixed_precision_cast_helper_2(type_, x):
if
F
.
issubclass_
(
F
.
dtype
(
x
),
mstype
.
float_
):
if
F
.
issubclass_
(
F
.
dtype
(
x
),
mstype
.
float_
):
return
P
.
Cast
()(
x
,
type_
)
return
P
.
Cast
()(
x
,
type_
)
return
x
return
x
@
_mp_cast_helper
.
register
(
"TypeType"
,
"Tuple"
)
@
core
def
_mixed_precision_cast_helper_3
(
type_
,
x
):
"""if x is a tuple"""
t
=
()
for
item
in
x
:
t
=
t
+
(
_mp_cast_helper
(
type_
,
item
),)
return
t
tests/ut/python/parameter_feature/test_parameter.py
浏览文件 @
9cb665e6
...
@@ -19,7 +19,7 @@ from mindspore.nn import Cell
...
@@ -19,7 +19,7 @@ from mindspore.nn import Cell
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
operations
as
P
import
mindspore.ops.composite
as
C
import
mindspore.ops.composite
as
C
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
)
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
save_graphs
=
True
)
def
test_parser_three_default_mixed_args_subnet
():
def
test_parser_three_default_mixed_args_subnet
():
...
@@ -227,3 +227,43 @@ def test_net_vargs_expand():
...
@@ -227,3 +227,43 @@ def test_net_vargs_expand():
net
.
set_train
()
net
.
set_train
()
net
(
x
,
y
,
sens
)
net
(
x
,
y
,
sens
)
def
test_mixed_precision_const_parameter
():
class
NetLoss
(
Cell
):
def
__init__
(
self
):
super
(
NetLoss
,
self
).
__init__
()
self
.
shape
=
P
.
Shape
()
self
.
up_sample1
=
P
.
ResizeBilinear
((
14
,
14
))
self
.
up_sample2
=
P
.
ResizeBilinear
((
28
,
28
))
self
.
up_sample3
=
P
.
ResizeBilinear
((
36
,
36
))
def
construct
(
self
,
x
,
y
,
z
,
*
args
):
ret
=
0
if
args
[
0
]
==
self
.
shape
(
z
)[
2
]:
if
args
[
0
]
==
14
:
ret
=
self
.
up_sample1
(
y
)
+
x
elif
args
[
0
]
==
28
:
ret
=
self
.
up_sample2
(
y
)
-
x
else
:
ret
=
x
/
y
else
:
ret
=
x
*
y
ret
=
ret
*
z
return
ret
class
NetMain
(
Cell
):
def
__init__
(
self
,
loss_fn
):
super
(
NetMain
,
self
).
__init__
()
self
.
loss_fn
=
loss_fn
self
.
shape
=
P
.
Shape
()
def
construct
(
self
,
x
,
y
,
z
):
size_x
=
self
.
shape
(
x
)[
2
]
size_y
=
self
.
shape
(
y
)[
2
]
ret
=
self
.
loss_fn
(
x
,
y
,
z
,
size_x
,
size_y
)
return
ret
loss_fn
=
NetLoss
()
net
=
NetMain
(
loss_fn
)
net
.
add_flags_recursive
(
fp32
=
True
)
x
=
Tensor
(
np
.
ones
((
1
,
3
,
28
,
28
),
np
.
float32
))
y
=
Tensor
(
np
.
ones
((
1
,
3
,
14
,
14
),
np
.
float32
))
z
=
Tensor
(
np
.
ones
((
1
,
3
,
28
,
28
),
np
.
float32
))
out
=
net
(
x
,
y
,
z
)
\ No newline at end of file
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录