Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
fab9fac1
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fab9fac1
编写于
7月 28, 2020
作者:
K
kingfo
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix batchnorm under mix precision in pynative mode
上级
b75943f2
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
50 addition
and
11 deletion
+50
-11
mindspore/_extends/builtin_operations.py
mindspore/_extends/builtin_operations.py
+7
-6
mindspore/common/parameter.py
mindspore/common/parameter.py
+11
-0
mindspore/common/tensor.py
mindspore/common/tensor.py
+1
-1
mindspore/nn/cell.py
mindspore/nn/cell.py
+7
-3
mindspore/ops/functional.py
mindspore/ops/functional.py
+1
-0
mindspore/ops/operations/nn_ops.py
mindspore/ops/operations/nn_ops.py
+1
-0
mindspore/ops/primitive.py
mindspore/ops/primitive.py
+22
-1
未找到文件。
mindspore/_extends/builtin_operations.py
浏览文件 @
fab9fac1
...
...
@@ -15,6 +15,7 @@
"""builtin_operations"""
import
numpy
as
np
from
mindspore.ops
import
functional
as
F
from
mindspore.ops
import
composite
as
C
from
mindspore.common.tensor
import
Tensor
from
mindspore.common.dtype
import
dtype_to_nptype
,
get_py_obj_dtype
...
...
@@ -173,11 +174,11 @@ def stop_gradient(x):
"""Implement `stop_gradient`."""
return
x
hyper_map
=
C
.
HyperMap
()
def
mixed_precision_cast
(
dst_type
,
x
):
"""Implement `mixed_precision_cast`."""
if
isinstance
(
x
,
tuple
):
res
=
list
()
for
item
in
x
:
res
.
append
(
F
.
cast
(
item
,
dst_type
))
return
tuple
(
res
)
return
F
.
cast
(
x
,
dst_type
)
def
cast_inner
(
data
):
return
F
.
cast
(
data
,
dst_type
)
return
hyper_map
(
cast_inner
,
x
)
mindspore/common/parameter.py
浏览文件 @
fab9fac1
...
...
@@ -61,6 +61,7 @@ class Parameter:
self
.
_is_init
=
False
self
.
_sliced
=
False
self
.
is_param_ps
=
False
self
.
_cast_type
=
None
self
.
init_in_server
=
False
if
context
.
get_context
(
"mode"
)
==
context
.
PYNATIVE_MODE
:
self
.
init_data
()
...
...
@@ -103,6 +104,16 @@ class Parameter:
raise
ValueError
(
"The type of the name should be `str` or `None`."
)
self
.
_value
.
name
=
name_
@
property
def
cast_type
(
self
):
return
self
.
_cast_type
@
cast_type
.
setter
def
cast_type
(
self
,
dst_type
):
if
dst_type
not
in
(
mstype
.
float16
,
mstype
.
float32
,
None
):
raise
ValueError
(
"The type of the name should be type of [float32, float16] or `None`."
)
self
.
_cast_type
=
dst_type
@
property
def
sliced
(
self
):
"""Get slice status of the parameter."""
...
...
mindspore/common/tensor.py
浏览文件 @
fab9fac1
mindspore/nn/cell.py
浏览文件 @
fab9fac1
...
...
@@ -286,6 +286,8 @@ class Cell:
if
context
.
get_context
(
"mode"
)
==
context
.
PYNATIVE_MODE
:
if
name
in
self
.
__dict__
:
del
self
.
__dict__
[
name
]
if
name
in
params
:
del
params
[
name
]
params_list
[
name
]
=
value
else
:
object
.
__setattr__
(
self
,
name
,
value
)
...
...
@@ -499,9 +501,11 @@ class Cell:
"""
if
hasattr
(
self
,
"_mindspore_flags"
):
if
self
.
_mindspore_flags
.
get
(
'fp16'
):
return
cast
(
param
,
mstype
.
float16
)
if
self
.
_mindspore_flags
.
get
(
'fp32'
):
return
cast
(
param
,
mstype
.
float32
)
param
.
cast_type
=
mstype
.
float16
elif
self
.
_mindspore_flags
.
get
(
'fp32'
):
param
.
cast_type
=
mstype
.
float32
else
:
param
.
cast_type
=
None
return
param
def
insert_child_to_cell
(
self
,
child_name
,
child
):
...
...
mindspore/ops/functional.py
浏览文件 @
fab9fac1
...
...
@@ -183,3 +183,4 @@ tensor_operator_registry.register('__ge__', tensor_ge)
tensor_operator_registry
.
register
(
'shape'
,
shape
)
#support GE backend for no compare operators
tensor_operator_registry
.
register
(
'vm_compare'
,
BP
.
vm_compare
)
tensor_operator_registry
.
register
(
'cast'
,
cast
)
mindspore/ops/operations/nn_ops.py
浏览文件 @
fab9fac1
...
...
@@ -618,6 +618,7 @@ class FusedBatchNorm(Primitive):
self
.
mode
=
validator
.
check_integer
(
'mode'
,
mode
,
[
0
,
1
],
Rel
.
IN
,
self
.
name
)
self
.
epsilon
=
validator
.
check_number_range
(
'epsilon'
,
epsilon
,
0
,
1
,
Rel
.
INC_RIGHT
,
self
.
name
)
self
.
momentum
=
validator
.
check_number_range
(
'momentum'
,
momentum
,
0
,
1
,
Rel
.
INC_BOTH
,
self
.
name
)
self
.
_update_parameter
=
True
class
BNTrainingReduce
(
PrimitiveWithInfer
):
...
...
mindspore/ops/primitive.py
浏览文件 @
fab9fac1
...
...
@@ -18,6 +18,8 @@
import
inspect
import
copy
from
mindspore.common.api
import
_wrap_func
from
mindspore.common
import
Parameter
from
mindspore.common._register_for_tensor
import
tensor_operator_registry
from
.._c_expression
import
Primitive_
,
real_run_op
,
prim_type
from
.._c_expression
import
signature_rw
as
sig_rw
from
.._c_expression
import
signature_kind
as
sig_kind
...
...
@@ -49,6 +51,7 @@ class Primitive(Primitive_):
self
.
name
=
name
self
.
attrs
=
{}
self
.
init_attrs
=
{
"name"
:
name
}
self
.
_update_parameter
=
False
Primitive_
.
__init__
(
self
,
name
,
self
)
if
hasattr
(
self
.
__class__
,
'__mindspore_signature__'
):
sig
=
self
.
_fill_signature
(
self
.
__class__
.
__mindspore_signature__
)
...
...
@@ -189,6 +192,11 @@ class Primitive(Primitive_):
# for checking output number with kernel implementation
self
.
add_prim_attr
(
"output_names"
,
outputs
)
@
property
def
update_parameter
(
self
):
""" Whether the primitive will update the value of parameter."""
return
self
.
_update_parameter
class
PrimitiveWithInfer
(
Primitive
):
"""
...
...
@@ -359,7 +367,20 @@ def constexpr(fn=None, get_instance=True, name=None):
@
_wrap_func
def
_run_op
(
obj
,
op_name
,
args
):
"""Single op execution function supported by ge in PyNative mode."""
output
=
real_run_op
(
obj
,
op_name
,
args
)
cast
=
tensor_operator_registry
.
get
(
"cast"
)
if
op_name
==
"Cast"
or
obj
.
update_parameter
:
cast_args
=
args
else
:
cast_args
=
list
()
for
arg
in
args
:
if
isinstance
(
arg
,
Parameter
):
if
arg
.
cast_type
:
cast_args
.
append
(
cast
(
arg
,
arg
.
cast_type
))
else
:
cast_args
.
append
(
arg
)
else
:
cast_args
.
append
(
arg
)
output
=
real_run_op
(
obj
,
op_name
,
tuple
(
cast_args
))
if
not
output
:
raise
RuntimeError
(
"Pynative run op %s failed!"
%
op_name
)
if
len
(
output
)
==
1
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录