magicwindyyd / mindspore (fork of MindSpore / mindspore)
Commit 73ea9b78

Authored Jul 23, 2020 by kingfo

fix mixed precision operator issue

Parent 0a2980ca
Showing 7 changed files with 40 additions and 13 deletions (+40 −13)
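In short: this commit adds a Python fallback implementation of `mixed_precision_cast` so the primitive can execute under PyNative mode, registers it in the sets of primitives the PyNative pipeline skips value inference for and routes to the VM, caches the compile-pipeline resource per cell, has `TrainOneStepWithLossScaleCell` enable gradients on its wrapped network, and tidies the mixed-precision flag checks in `nn.Cell`. A unit test exercises the new cast path.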
mindspore/_extends/builtin_operations.py               +10  −0
mindspore/ccsrc/pipeline/pynative/base.h                +1  −1
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc   +8  −2
mindspore/ccsrc/pipeline/pynative/pynative_execute.h    +1  −0
mindspore/nn/cell.py                                   +12 −10
mindspore/nn/wrap/loss_scale.py                         +1  −0
tests/ut/python/ops/test_control_ops.py                 +7  −0
mindspore/_extends/builtin_operations.py

@@ -14,6 +14,7 @@
 # ============================================================================
 """builtin_operations"""
 import numpy as np
+from mindspore.ops import functional as F
 from mindspore.common.tensor import Tensor
 from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
@@ -171,3 +172,12 @@ def tuple_to_array(x):
 def stop_gradient(x):
     """Implement `stop_gradient`."""
     return x
+
+def mixed_precision_cast(dst_type, x):
+    """Implement `mixed_precision_cast`."""
+    if isinstance(x, tuple):
+        res = list()
+        for item in x:
+            res.append(F.cast(item, dst_type))
+        return tuple(res)
+    return F.cast(x, dst_type)
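For reference, a short usage sketch of the new fallback (a minimal sketch assuming a MindSpore build of this era; the tuple round trip mirrors the `isinstance` branch above, and whether the primitive accepts tuples outside this fallback path is an assumption here):

```python
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F

x = Tensor(np.ones([2, 3], dtype=np.float32))

# A single tensor is cast directly to the destination type.
y = F.mixed_precision_cast(mstype.float16, x)
assert y.dtype == mstype.float16

# A tuple is cast element-wise and repacked, matching the fallback's branch.
a, b = F.mixed_precision_cast(mstype.float16, (x, x))
assert a.dtype == mstype.float16 and b.dtype == mstype.float16
```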
mindspore/ccsrc/pipeline/pynative/base.h

@@ -61,7 +61,7 @@ struct OpExecInfo {
 using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
 OpExecInfoPtr GenerateOpExecInfo(const py::args &args, py::list *const out_args);
 
-const std::set<std::string> ignore_infer_prim = {"make_ref"};
+const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"};
 }  // namespace pynative
 }  // namespace mindspore
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc

@@ -57,7 +57,7 @@ using mindspore::tensor::TensorPy;
 const char SINGLE_OP_GRAPH[] = "single_op_graph";
 // primitive unable to infer value for constant input in PyNative mode
-const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient"};
+const std::set<std::string> vm_operators = {"make_ref", "HookBackward", "stop_gradient", "mixed_precision_cast"};
 
 namespace mindspore {
 namespace pynative {
@@ -815,6 +815,9 @@ PynativeExecutor::PynativeExecutor() { grad_flag_ = false; }
 void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
   auto cell_id = GetId(cell);
   if (cell_graph_map_.count(cell_id) != 0) {
+    if (cell_resource_map_.find(cell_id) != cell_resource_map_.end()) {
+      resource_ = cell_resource_map_[cell_id];
+    }
     MS_LOG(DEBUG) << "Newgraph already compiled";
     return;
   }
@@ -823,6 +826,8 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &args) {
   if (top_g_ == nullptr) {
     top_g_ = curr_g_ = g;
+    resource_ = std::make_shared<pipeline::Resource>();
+    cell_resource_map_[cell_id] = resource_;
     df_builder_ = std::make_shared<FuncGraph>();
     MS_LOG(DEBUG) << "First new graph" << top_g_.get();
     Pushp();
@@ -1124,6 +1129,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
     MS_LOG(DEBUG) << "Clear res";
     (void)graph_map_.erase(flag);
     (void)cell_graph_map_.erase(flag);
+    (void)cell_resource_map_.erase(flag);
     Clean();
     // Maybe exit in the pynative runing op, so need reset pynative flag.
     auto ms_context = MsContext::GetInstance();
@@ -1135,6 +1141,7 @@ void PynativeExecutor::Clear(const std::string &flag) {
   MS_LOG(DEBUG) << "Clear";
   top_g_ = nullptr;
+  df_builder_ = nullptr;
   curr_g_ = nullptr;
   graph_info_map_.clear();
   op_id_map_.clear();
@@ -1146,7 +1153,6 @@ void PynativeExecutor::Clean() {
   Clear();
   grad_flag_ = false;
   op_forward_map_.clear();
-  df_builder_ = nullptr;
   ad::CleanRes();
   pipeline::ReclaimOptimizer();
 }
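Taken with the base.h change, adding `mixed_precision_cast` to `ignore_infer_prim` and `vm_operators` tells the PyNative pipeline to skip value inference for the primitive and execute it through the VM, where the new Python fallback in builtin_operations.py handles it. The `cell_resource_map_` additions make `NewGraphInner` restore the cached pipeline resource when it re-encounters an already-compiled cell. A rough sketch of the kind of scenario this targets (an assumed reproduction, not taken from the commit):

```python
import numpy as np
from mindspore import Tensor, context, nn
from mindspore.common import dtype as mstype

context.set_context(mode=context.PYNATIVE_MODE)

# Calling the same cell twice: the second call finds the cell id in
# cell_graph_map_ and, with this fix, also restores the matching entry
# from cell_resource_map_ instead of leaving resource_ stale.
net = nn.Dense(3, 4).to_float(mstype.float16)
x = Tensor(np.ones([2, 3], dtype=np.float32))
out_first = net(x)   # compiles; graph and resource cached for this cell
out_second = net(x)  # cache hit; cached resource restored
```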
mindspore/ccsrc/pipeline/pynative/pynative_execute.h

@@ -119,6 +119,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   bool grad_flag_;
   std::unordered_map<std::string, FuncGraphPtr> graph_map_;
   std::unordered_map<std::string, FuncGraphPtr> cell_graph_map_;
+  std::unordered_map<std::string, ResourcePtr> cell_resource_map_;
   std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
   std::unordered_map<std::string, ValuePtr> op_forward_map_;
   std::unordered_map<std::string, size_t> op_id_map_;
mindspore/nn/cell.py

@@ -240,12 +240,13 @@ class Cell:
         else:
             _pynative_exec.set_grad_flag(False)
         cast_inputs = list()
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
-            for item in inputs:
-                cast_inputs.append(cast(item, mstype.float16))
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
-            for item in inputs:
-                cast_inputs.append(cast(item, mstype.float32))
+        if hasattr(self, "_mindspore_flags"):
+            if self._mindspore_flags.get('fp16'):
+                for item in inputs:
+                    cast_inputs.append(cast(item, mstype.float16))
+            if self._mindspore_flags.get('fp32'):
+                for item in inputs:
+                    cast_inputs.append(cast(item, mstype.float32))
         if cast_inputs:
             cast_inputs = tuple(cast_inputs)
         else:
@@ -496,10 +497,11 @@ class Cell:
         Args:
             param (Parameter): The parameter to cast.
         """
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp16'):
-            return cast(param, mstype.float16)
-        if hasattr(self, "_mindspore_flags") and self._mindspore_flags.get('fp32'):
-            return cast(param, mstype.float32)
+        if hasattr(self, "_mindspore_flags"):
+            if self._mindspore_flags.get('fp16'):
+                return cast(param, mstype.float16)
+            if self._mindspore_flags.get('fp32'):
+                return cast(param, mstype.float32)
         return param
 
     def insert_child_to_cell(self, child_name, child):
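This restructure hoists the repeated `hasattr` test and nests the fp16/fp32 branches under it; behavior is unchanged, but `_mindspore_flags` is now checked once per call. These flags are what `Cell.to_float` sets, so a hedged illustration of the path being exercised (expected dtype shown from the flag semantics above, not verified against this exact revision):

```python
import numpy as np
from mindspore import Tensor, nn
from mindspore.common import dtype as mstype

# to_float(mstype.float16) sets the fp16 flag consulted in __call__, so
# float32 inputs are cast to float16 before construct() runs.
net = nn.Dense(3, 4).to_float(mstype.float16)
x = Tensor(np.ones([2, 3], dtype=np.float32))
y = net(x)  # expected y.dtype: mstype.float16
```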
mindspore/nn/wrap/loss_scale.py

@@ -206,6 +206,7 @@ class TrainOneStepWithLossScaleCell(Cell):
     def __init__(self, network, optimizer, scale_update_cell=None):
         super(TrainOneStepWithLossScaleCell, self).__init__(auto_prefix=False)
         self.network = network
+        self.network.set_grad()
         self.network.add_flags(defer_inline=True)
         self.weights = optimizer.parameters
         self.optimizer = optimizer
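With `set_grad()` called in the constructor, the wrapped network is marked for gradient computation before the first backward pass runs under PyNative mode. A minimal usage sketch of the wrapper (names, shapes, and the fixed loss scale are illustrative, not taken from the commit):

```python
import numpy as np
from mindspore import Tensor, nn
from mindspore.common import dtype as mstype

net = nn.Dense(3, 1)
net_with_loss = nn.WithLossCell(net, nn.MSELoss())
opt = nn.Momentum(net.trainable_params(), learning_rate=0.01, momentum=0.9)

# With scale_update_cell=None, a fixed loss scale is supplied per step.
train_net = nn.TrainOneStepWithLossScaleCell(net_with_loss, opt)

data = Tensor(np.ones([2, 3], dtype=np.float32))
label = Tensor(np.zeros([2, 1], dtype=np.float32))
scaling_sens = Tensor(np.full((1,), 1024.0), mstype.float32)
output = train_net(data, label, scaling_sens)
```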
tests/ut/python/ops/test_control_ops.py

@@ -20,6 +20,7 @@ import mindspore as ms
 from mindspore import Tensor
 from mindspore import context
 from mindspore import nn
 from mindspore.common import dtype as mstype
 from mindspore.ops import composite as C
+from mindspore.ops import functional as F
 from mindspore.ops import operations as P
@@ -638,3 +639,9 @@ def test_large_for_loop_with_continue_break():
     t = Tensor(np.ones([2, 3], dtype=np.float32))
     net = Net()
     net(t)
+
+
+def test_mixed_precision_cast():
+    x = Tensor(np.ones([2, 3], dtype=np.float32))
+    z = F.mixed_precision_cast(mstype.float16, x)
+    assert z.dtype == mstype.float16