magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit 2a2dd7d3
Authored Jun 04, 2020 by mindspore-ci-bot; committed via Gitee, Jun 04, 2020

!1803 Optimized MixedPrecisionCast function

Merge pull request !1803 from Kang/master

Parents: a4e1852d, 041f628c
Showing 10 changed files with 83 additions and 41 deletions (+83 −41):
- mindspore/ccsrc/operator/ops.cc (+1 −0)
- mindspore/ccsrc/operator/ops.h (+1 −0)
- mindspore/ccsrc/pipeline/parse/parse.cc (+1 −1)
- mindspore/ccsrc/pipeline/static_analysis/prim.cc (+57 −6)
- mindspore/ccsrc/pipeline/static_analysis/prim.h (+16 −0)
- mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc (+4 −0)
- mindspore/nn/wrap/cell_wrapper.py (+1 −2)
- mindspore/ops/composite/base.py (+0 −30)
- mindspore/ops/functional.py (+1 −0)
- mindspore/train/amp.py (+1 −2)
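Before the per-file diffs, the shape of the change: the Python `MultitypeFuncGraph` helper `_mp_cast_helper` (deleted from `mindspore/ops/composite/base.py` below) is replaced by a `mixed_precision_cast` primitive whose expansion moves into C++ static analysis. At every Python call site the substitution is the same (both lines appear verbatim in the diffs below):

```python
# before: resolved at parse time to the Python multitype graph
label = _mp_cast_helper(mstype.float32, label)
# after: a single primitive, expanded by the new MixedPrecisionCastEvaluator in C++
label = F.mixed_precision_cast(mstype.float32, label)
```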
mindspore/ccsrc/operator/ops.cc

```diff
@@ -242,6 +242,7 @@ const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_");
 const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not");
 const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
 const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
+const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
 
 // Comm ops
 const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
```
mindspore/ccsrc/operator/ops.h

```diff
@@ -251,6 +251,7 @@ extern const PrimitivePtr kPrimIs_;
 extern const PrimitivePtr kPrimIsNot;
 extern const PrimitivePtr kPrimInDict;
 extern const PrimitivePtr kPrimNotInDict;
+extern const PrimitivePtr kPrimMixedPrecisionCast;
 
 // Comm ops
 extern const PrimitivePtr kPrimMirror;
```
mindspore/ccsrc/pipeline/parse/parse.cc

```diff
@@ -67,7 +67,7 @@ AnfNodePtr GetMixedPrecisionCastHelp(const FuncGraphPtr &func_graph, const AnfNo
   } else {
     return param;
   }
-  auto cast_helper = prim::GetPythonOps("_mp_cast_helper", "mindspore.ops.composite.base");
+  auto cast_helper = prim::kPrimMixedPrecisionCast;
   auto cast = func_graph->NewCNode({NewValueNode(cast_helper), NewValueNode(dst_type), param});
   return cast;
 }
```
mindspore/ccsrc/pipeline/static_analysis/prim.cc

```diff
@@ -147,9 +147,6 @@ EvalResultPtr StandardPrimEvaluator::EvalPrim(const AnalysisEnginePtr &engine, c
 EvalResultPtr DoSignatureEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
                                         AnfNodeConfigPtr out_conf) {
   AbstractBasePtrList args_spec_list;
   if (!prim_->isa<prim::DoSignaturePrimitive>()) {
     MS_LOG(EXCEPTION) << "Primitive should be DoSignature, but " << prim_->ToString();
   }
   if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
     MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
   }
@@ -221,9 +218,6 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
   if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
     MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
   }
   if (!prim_->isa<prim::UnpackGraphPrimitive>()) {
     MS_LOG(EXCEPTION) << "Primitive should be UnpackGraphPrimitive, but got " << prim_->ToString();
   }
   auto unpack_graph = prim_->cast<prim::UnpackGraphPrimitivePtr>();
   auto out_node = out_conf->node()->cast<CNodePtr>();
@@ -267,6 +261,63 @@ EvalResultPtr UnpackGraphEvaluator::Run(AnalysisEnginePtr engine, const ConfigPt
   return engine->ForwardConfig(out_conf, fn_conf);
 }
 
+AnfNodePtr MixedPrecisionCastHelper(AnfNodePtr source_node, AbstractBasePtr node_type, AnfNodePtr target_type,
+                                    FuncGraphPtr func_graph) {
+  AnfNodePtr target_node = source_node;
+  if (node_type->isa<AbstractTensor>()) {
+    auto x = node_type->cast<AbstractTensorPtr>();
+    if (x->element()->BuildType()->isa<Float>()) {
+      auto cast = prim::GetPythonOps("cast", "mindspore.ops.functional");
+      MS_EXCEPTION_IF_NULL(cast);
+      target_node = func_graph->NewCNode({NewValueNode(cast), source_node, target_type});
+    }
+  } else if (node_type->isa<AbstractTuple>()) {
+    auto x = node_type->cast<AbstractTuplePtr>();
+    auto &items = x->elements();
+    std::size_t size = items.size();
+    std::vector<AnfNodePtr> nodes;
+    nodes.emplace_back(NewValueNode(prim::kPrimMakeTuple));
+    for (int i = 0; i < SizeToInt(size); i++) {
+      AnfNodePtr tuple_node =
+        func_graph->NewCNode({NewValueNode(prim::kPrimTupleGetItem), source_node, NewValueNode(i)});
+      AnfNodePtr node = MixedPrecisionCastHelper(tuple_node, items[i], target_type, func_graph);
+      nodes.emplace_back(node);
+    }
+    target_node = func_graph->NewCNode(nodes);
+  }
+  return target_node;
+}
+
+EvalResultPtr MixedPrecisionCastEvaluator::Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
+                                               AnfNodeConfigPtr out_conf) {
+  AbstractBasePtrList args_spec_list;
+  if (out_conf->node() == nullptr || !out_conf->node()->isa<CNode>()) {
+    MS_LOG(EXCEPTION) << "Node of out_conf should be CNode";
+  }
+  auto out_node = out_conf->node()->cast<CNodePtr>();
+  const auto &out_node_inputs = out_node->inputs();
+  if (out_node->inputs().size() == 0 || (out_node_inputs.size() - 1) != args_conf_list.size()) {
+    MS_LOG(EXCEPTION) << "MixedPrecisionCast"
+                      << " args size should equal to inputs size minus 1, but args size " << args_conf_list.size()
+                      << ", inputs size " << out_node_inputs.size();
+  }
+  AnfNodePtrList args_inputs{out_node_inputs.begin() + 1, out_node_inputs.end()};
+  (void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
+                       [](const ConfigPtr &ref) -> AbstractBasePtr { return ref->GetEvaluatedValue()->abstract(); });
+
+  ScopePtr scope = kDefaultScope;
+  if (out_conf != nullptr) {
+    scope = out_conf->node()->scope();
+  }
+  ScopeGuard scope_guard(scope);
+
+  FuncGraphPtr func_graph = out_conf->node()->func_graph();
+  AnfNodePtr new_node = MixedPrecisionCastHelper(out_node_inputs[2], args_spec_list[1], out_node_inputs[1], func_graph);
+  AnfNodeConfigPtr fn_conf = engine->MakeConfig(new_node, out_conf->context());
+  return engine->ForwardConfig(out_conf, fn_conf);
+}
+
 namespace {
 py::object BuildValue(const ValuePtr &value_ptr) {
   if (value_ptr == nullptr) {
```
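`MixedPrecisionCastHelper` above rewrites the value being cast by structure: a floating-point tensor becomes a `cast` call, a tuple becomes a `make_tuple` of recursively rewritten `tuple_getitem` nodes, and anything else passes through unchanged. Note the argument layout the evaluator assumes for a call `mixed_precision_cast(dst_type, value)`: `out_node_inputs[1]` is the destination type, `out_node_inputs[2]` is the value, and `args_spec_list[1]` supplies the value's inferred abstract type. A minimal pure-Python sketch of the same recursion, with numpy standing in for tensors (the helpers here are illustrative, not MindSpore APIs):

```python
import numpy as np

def _is_float_tensor(value):
    """Stand-in for the AbstractTensor + Float element check done in C++."""
    return isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.floating)

def mixed_precision_cast_sketch(dst_type, value):
    """Mirrors MixedPrecisionCastHelper: cast float tensors, recurse into tuples."""
    if isinstance(value, tuple):
        # mirrors the make_tuple / tuple_getitem recursion over AbstractTuple
        return tuple(mixed_precision_cast_sketch(dst_type, v) for v in value)
    if _is_float_tensor(value):
        return value.astype(dst_type)  # mirrors the emitted cast node
    return value

# float arrays are cast, integer arrays pass through, tuples are handled element-wise
out = mixed_precision_cast_sketch(np.float16, (np.ones(2, np.float32), np.arange(3)))
```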
mindspore/ccsrc/pipeline/static_analysis/prim.h

```diff
@@ -102,6 +102,22 @@ class UnpackGraphEvaluator : public Evaluator {
   PrimitivePtr prim_;
 };
 
+class MixedPrecisionCastEvaluator : public Evaluator {
+ public:
+  explicit MixedPrecisionCastEvaluator(const PrimitivePtr primitive)
+      : Evaluator("MixedPrecisionCastEvaluator"), prim_(primitive) {}
+  ~MixedPrecisionCastEvaluator() override = default;
+  EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &argrefs,
+                    AnfNodeConfigPtr out_config = nullptr) override;
+
+  EvalResultPtr Eval(AnalysisEnginePtr, const AbstractBasePtrList &) override {
+    MS_LOG(EXCEPTION) << "Eval() should not be called, Run() method should be called";
+  }
+
+ private:
+  PrimitivePtr prim_;
+};
+
 bool IsInWhiteList(PrimitivePtr primitive);
 StandardPrimitiveEvalImpl GetPrimitiveInferImpl(const PrimitivePtr &primitive);
```
mindspore/ccsrc/pipeline/static_analysis/static_analysis.cc

```diff
@@ -308,6 +308,10 @@ EvaluatorPtr GetPrimEvaluator(const PrimitivePtr &prim, const AnalysisEnginePtr
     evaluator = std::make_shared<UnpackGraphEvaluator>(prim);
     return evaluator;
   }
+  if (prim->name() == prim::kPrimMixedPrecisionCast->name()) {
+    evaluator = std::make_shared<MixedPrecisionCastEvaluator>(prim);
+    return evaluator;
+  }
   if (prim->HasPyEvaluator()) {
     auto prim_py = dyn_cast<PrimitivePy>(prim);
     if (prim_py != nullptr) {
```
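The wiring follows the name-based special-casing `GetPrimEvaluator` already uses for the unpack-graph primitive: match the primitive by name, return a dedicated evaluator, otherwise fall through to the Python-evaluator path. An illustrative sketch of that dispatch (the table below is hypothetical; only `mixed_precision_cast` is confirmed by this diff):

```python
# Hypothetical rendering of the name-based dispatch in GetPrimEvaluator;
# the first key is assumed for illustration.
SPECIAL_EVALUATORS = {
    "UnpackGraph": "UnpackGraphEvaluator",
    "mixed_precision_cast": "MixedPrecisionCastEvaluator",  # branch added by this commit
}

def get_prim_evaluator(prim_name):
    # None means: fall through to HasPyEvaluator() and the remaining branches
    return SPECIAL_EVALUATORS.get(prim_name)
```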
mindspore/nn/wrap/cell_wrapper.py

```diff
@@ -21,7 +21,6 @@ from ...common.parameter import Parameter, ParameterTuple
 from ...ops import composite as C
 from ...ops import functional as F
 from ...ops import operations as P
-from ...ops.composite.base import _mp_cast_helper
 from ...ops.operations.comm_ops import _VirtualDataset
 from ..cell import Cell
 from .grad_reducer import DistributedGradReducer
@@ -345,7 +344,7 @@ class WithEvalCell(Cell):
     def construct(self, data, label):
         outputs = self._network(data)
         if self.add_cast_fp32:
-            label = _mp_cast_helper(mstype.float32, label)
+            label = F.mixed_precision_cast(mstype.float32, label)
             outputs = F.cast(outputs, mstype.float32)
         loss = self._loss_fn(outputs, label)
         return loss, outputs, label
```
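The behavior of `WithEvalCell` is unchanged: with `add_cast_fp32` set, the label is still cast to `float32` before the loss is computed; only the mechanism behind the cast moved from Python into C++. A hedged usage sketch (the `add_cast_fp32` constructor argument is assumed from the attribute used above; `network`, `loss_fn`, `data`, `label` are placeholders):

```python
# Hypothetical usage of the updated WithEvalCell
eval_net = WithEvalCell(network, loss_fn, add_cast_fp32=True)
eval_net.set_train(False)
loss, outputs, label = eval_net(data, label)  # label is cast to float32 internally
```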
mindspore/ops/composite/base.py

```diff
@@ -24,7 +24,6 @@ from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeF
 from ...common import dtype as mstype
 from ...common.api import ms_function, _pynative_exec
 from .. import functional as F
-from .. import operations as P
 from ...common.parameter import Parameter
@@ -297,32 +296,3 @@ env_get = MultitypeFuncGraph("env_get")
 def _tensor_env_get(env, parameter):
     """Used to get env."""
     return F.env_getitem(env, F.ref_to_embed(parameter), F.zeros_like(parameter))
-
-
-_mp_cast_helper = MultitypeFuncGraph('mixed_precision_cast_helper')
-
-
-@_mp_cast_helper.register("TypeType", "Number")
-@core
-def _mixed_precision_cast_helper_1(type_, x):
-    """if x is float cast to type."""
-    # type_ is place holder
-    return x
-
-
-@_mp_cast_helper.register("TypeType", "Tensor")
-@core
-def _mixed_precision_cast_helper_2(type_, x):
-    """if x is float cast to type."""
-    if F.issubclass_(F.dtype(x), mstype.float_):
-        return P.Cast()(x, type_)
-    return x
-
-
-@_mp_cast_helper.register("TypeType", "Tuple")
-@core
-def _mixed_precision_cast_helper_3(type_, x):
-    """if x is a tuple"""
-    t = ()
-    for item in x:
-        t = t + (_mp_cast_helper(type_, item),)
-    return t
```
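All three removed overloads are covered by the new C++ helper: `Number` inputs pass through (the default branch of `MixedPrecisionCastHelper`), `Tensor` inputs hit the `AbstractTensor` branch, and `Tuple` inputs hit the recursive `AbstractTuple` branch, so the item-by-item rebuild that `_mixed_precision_cast_helper_3` did in Python now happens as graph nodes. A hedged sketch of the tuple case from the user side (illustrative; assumes the call is traced as part of a cell's graph):

```python
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F

class TupleCast(nn.Cell):
    """Casts every floating-point element of the input tuple to float32."""
    def construct(self, x, y):
        return F.mixed_precision_cast(mstype.float32, (x, y))

net = TupleCast()
out = net(Tensor(np.ones((2, 2), np.float16)), Tensor(np.ones((2, 2), np.float16)))
```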
mindspore/ops/functional.py

```diff
@@ -126,6 +126,7 @@ is_ = Primitive("is_")
 is_not = Primitive("is_not")
 in_dict = Primitive("in_dict")
 not_in_dict = Primitive("not_in_dict")
+mixed_precision_cast = Primitive("mixed_precision_cast")
 broadcast_gradient_args = Primitive('BroadcastGradientArgs')
 dot = Primitive('dot')
 array_reduce = Primitive('array_reduce')
```
mindspore/train/amp.py

```diff
@@ -21,7 +21,6 @@ from .._checkparam import Rel
 from ..common import dtype as mstype
 from ..nn.wrap.cell_wrapper import _VirtualDatasetCell
 from ..ops import functional as F
-from ..ops.composite.base import _mp_cast_helper
 from ..parallel._utils import _get_parallel_mode
 from .loss_scale_manager import DynamicLossScaleManager, LossScaleManager
 from .parallel_utils import ParallelMode
@@ -98,7 +97,7 @@ def _add_loss_network(network, loss_fn, cast_model_type):
         def construct(self, data, label):
             out = self._backbone(data)
-            label = _mp_cast_helper(mstype.float32, label)
+            label = F.mixed_precision_cast(mstype.float32, label)
             return self._loss_fn(F.cast(out, mstype.float32), label)
 
     validator.check_value_type('loss_fn', loss_fn, nn.Cell, None)
```