Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
86ba9362
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
86ba9362
编写于
6月 17, 2020
作者:
W
wandongdong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
split correction_mul op
上级
445122f5
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
133 addition
and
19 deletion
+133
-19
mindspore/ops/_grad/grad_quant_ops.py
mindspore/ops/_grad/grad_quant_ops.py
+12
-2
mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py
mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py
+85
-15
mindspore/ops/operations/_quant_ops.py
mindspore/ops/operations/_quant_ops.py
+36
-2
未找到文件。
mindspore/ops/_grad/grad_quant_ops.py
浏览文件 @
86ba9362
...
...
@@ -18,6 +18,7 @@
from
..
import
operations
as
P
from
.grad_base
import
bprop_getters
from
..composite.multitype_ops.zeros_like_impl
import
zeros_like
from
...
import
context
@
bprop_getters
.
register
(
P
.
FakeQuantPerLayer
)
...
...
@@ -64,12 +65,21 @@ def get_bprop_batchnorm_fold(self):
@
bprop_getters
.
register
(
P
.
CorrectionMul
)
def
get_bprop_correction_mul
(
self
):
"""Generate bprop for CorrectionMul for Ascend and GPU"""
grad
=
P
.
CorrectionMulGrad
(
self
.
channel_axis
)
grad_dx
=
P
.
CorrectionMulGrad
(
self
.
channel_axis
)
grad_d_batch_std
=
P
.
CorrectionMulGradReduce
(
self
.
channel_axis
)
def
bprop
(
x
,
batch_std
,
running_std
,
out
,
dout
):
dx
,
d_batch_std
=
grad
(
dout
,
x
,
batch_std
,
running_std
)
dx
,
d_batch_std
=
grad
_dx
(
dout
,
x
,
batch_std
,
running_std
)
return
dx
,
d_batch_std
,
zeros_like
(
running_std
)
def
bprop_npu
(
x
,
batch_std
,
running_std
,
out
,
dout
):
dx
,
mul_dx
=
grad_dx
(
dout
,
x
,
batch_std
,
running_std
)
d_batch_std
=
grad_d_batch_std
(
mul_dx
)
return
dx
,
d_batch_std
,
zeros_like
(
running_std
)
if
context
.
get_context
(
'device_target'
)
==
"Ascend"
:
return
bprop_npu
return
bprop
...
...
mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py
浏览文件 @
86ba9362
...
...
@@ -37,7 +37,7 @@ correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \
.
input
(
2
,
"batch_std"
,
None
,
"required"
,
None
)
\
.
input
(
3
,
"running_std"
,
None
,
"required"
,
None
)
\
.
output
(
0
,
"dx"
,
True
,
"required"
,
"all"
)
\
.
output
(
1
,
"
d_batch_std
"
,
True
,
"required"
,
"all"
)
\
.
output
(
1
,
"
mul_dx
"
,
True
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
get_op_info
()
...
...
@@ -56,21 +56,14 @@ def correction_mul_grad_compute(dout, x, batch_std, running_std, channel, data_f
factor
=
te
.
lang
.
cce
.
vdiv
(
batch_std
,
running_std
)
factor_b
=
te
.
lang
.
cce
.
broadcast
(
factor
,
shape_x
)
dx
=
te
.
lang
.
cce
.
vmul
(
dout
,
factor_b
)
mul_data
=
te
.
lang
.
cce
.
vmul
(
dout
,
x
)
if
channel
==
0
:
if
data_format
==
"NCHW"
:
axis
=
[
1
,
2
,
3
]
else
:
axis
=
[
1
,
2
,
3
,
4
]
else
:
axis
=
[
2
,
3
]
red_data
=
te
.
lang
.
cce
.
sum
(
mul_data
,
axis
,
keepdims
=
True
)
d_batch_std
=
te
.
lang
.
cce
.
vdiv
(
red_data
,
running_std
)
return
[
dx
,
d_batch_std
]
mul_dx
=
te
.
lang
.
cce
.
vmul
(
dout
,
x
)
running_std_b
=
te
.
lang
.
cce
.
broadcast
(
running_std
,
shape_x
)
mul_dx
=
te
.
lang
.
cce
.
vdiv
(
mul_dx
,
running_std_b
)
return
[
dx
,
mul_dx
]
@
util
.
check_input_type
(
dict
,
dict
,
dict
,
dict
,
dict
,
dict
,
int
,
str
)
def
correction_mul_grad
(
dout
,
x
,
batch_std
,
running_std
,
dx
,
d_batch_std
,
channel
,
kernel_name
=
"correction_mul_grad"
):
def
correction_mul_grad
(
dout
,
x
,
batch_std
,
running_std
,
dx
,
mul_dx
,
channel
,
kernel_name
=
"correction_mul_grad"
):
"""CorrectionMulGrad op"""
shape_dout
=
dout
.
get
(
"shape"
)
shape_x
=
dout
.
get
(
"shape"
)
...
...
@@ -93,7 +86,7 @@ def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channe
util
.
compare_tensor_dict_key
(
dout
,
x
,
"shape"
)
util
.
compare_tensor_dict_key
(
dx
,
x
,
"shape"
)
util
.
compare_tensor_dict_key
(
batch_std
,
running_std
,
"shape"
)
util
.
compare_tensor_dict_key
(
batch_std
,
d_batch_std
,
"shape"
)
util
.
compare_tensor_dict_key
(
dx
,
mul_dx
,
"shape"
)
util
.
check_kernel_name
(
kernel_name
)
util
.
check_shape_rule
(
shape_x
)
...
...
@@ -120,7 +113,84 @@ def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channe
with
tvm
.
target
.
cce
():
sch
=
generic
.
auto_schedule
(
res_list
)
tensor_list
=
[
dout_t
,
x_t
,
batch_std_t
,
running_std_t
]
+
list
(
res_list
)
tensor_list
=
[
dout_t
,
x_t
,
batch_std_t
,
running_std_t
]
+
res_list
config
=
{
"print_ir"
:
False
,
"name"
:
kernel_name
,
"tensor_list"
:
tensor_list
}
te
.
lang
.
cce
.
cce_build_code
(
sch
,
config
)
correction_mul_grad_reduce_op_info
=
TBERegOp
(
"CorrectionMulGradReduce"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"correction_mul_grad_reduce.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"correction_mul_grad_reduce"
)
\
.
partial_flag
(
True
)
\
.
op_pattern
(
"formatAgnostic"
)
\
.
attr
(
"channel_axis"
,
"optional"
,
"int"
,
"all"
)
\
.
input
(
0
,
"dout"
,
None
,
"required"
,
None
)
\
.
output
(
0
,
"d_batch_std"
,
True
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
get_op_info
()
@
op_info_register
(
correction_mul_grad_reduce_op_info
)
def
_correction_mul_grad_reduce_tbe
():
"""CorrectionMulGradReduce TBE register"""
return
@
fusion_manager
.
register
(
"correction_mul_grad_reduce"
)
def
correction_mul_grad_reduce_compute
(
mul_dx
,
channel
,
data_format
,
kernel_name
=
"correction_mul"
):
"""CorrectionMulGradReduce compute"""
if
channel
==
0
:
if
data_format
==
"NCHW"
:
axis
=
[
1
,
2
,
3
]
else
:
axis
=
[
1
,
2
,
3
,
4
]
else
:
axis
=
[
2
,
3
]
d_batch_std
=
te
.
lang
.
cce
.
sum
(
mul_dx
,
axis
,
keepdims
=
True
)
return
d_batch_std
@
util
.
check_input_type
(
dict
,
dict
,
int
,
str
)
def
correction_mul_grad_reduce
(
mul_dx
,
d_batch_std
,
channel
,
kernel_name
=
"correction_mul_grad_reduce"
):
"""CorrectionMulGradReduce op"""
shape_dout
=
mul_dx
.
get
(
"shape"
)
shape_x
=
mul_dx
.
get
(
"shape"
)
dtype_dout
=
mul_dx
.
get
(
"dtype"
)
inp_dtype_dout
=
dtype_dout
.
lower
()
util
.
check_dtype_rule
(
inp_dtype_dout
,
(
"float16"
,
"float32"
))
util
.
check_kernel_name
(
kernel_name
)
util
.
check_shape_rule
(
shape_x
)
util
.
check_shape_size
(
shape_x
,
SHAPE_SIZE_LIMIT
)
data_format
=
mul_dx
.
get
(
"format"
)
ori_format
=
mul_dx
.
get
(
"format"
)
if
data_format
.
upper
()
not
in
(
"NC1HWC0"
,
"NCHW"
):
raise
RuntimeError
(
"Un supported data format {}"
.
format
(
data_format
))
if
data_format
.
upper
()
==
"NCHW"
and
ori_format
!=
"NCHW"
:
raise
RuntimeError
(
"data_format(NCHW) must same as ori_format"
)
shape_c
=
[
1
]
*
len
(
shape_x
)
shape_c
[
channel
]
=
d_batch_std
.
get
(
"ori_shape"
)[
0
]
if
data_format
==
"NC1HWC0"
and
channel
==
1
:
shape_c
=
d_batch_std
.
get
(
"shape"
)
dout_t
=
tvm
.
placeholder
(
shape_dout
,
name
=
"dout"
,
dtype
=
inp_dtype_dout
)
res
=
correction_mul_grad_reduce_compute
(
dout_t
,
channel
,
data_format
,
kernel_name
)
with
tvm
.
target
.
cce
():
sch
=
generic
.
auto_schedule
(
res
)
tensor_list
=
[
dout_t
,
res
]
config
=
{
"print_ir"
:
False
,
"name"
:
kernel_name
,
"tensor_list"
:
tensor_list
}
...
...
mindspore/ops/operations/_quant_ops.py
浏览文件 @
86ba9362
...
...
@@ -31,6 +31,7 @@ __all__ = ["FakeQuantPerLayer",
"BatchNormFoldGrad"
,
"CorrectionMul"
,
"CorrectionMulGrad"
,
"CorrectionMulGradReduce"
,
"BatchNormFold2"
,
"BatchNormFold2Grad"
,
"BatchNormFoldD"
,
...
...
@@ -500,7 +501,7 @@ class CorrectionMulGrad(PrimitiveWithInfer):
from
mindspore.ops._op_impl._custom_op
import
correction_mul_grad
self
.
channel_axis
=
channel_axis
self
.
init_prim_io_names
(
inputs
=
[
'dout'
,
'x'
,
'gamma'
,
'running_std'
],
outputs
=
[
'dx'
,
'
d_gamma
'
])
outputs
=
[
'dx'
,
'
mul_dx
'
])
def
infer_shape
(
self
,
dout_shape
,
x_shape
,
gamma_shape
,
running_std_shape
):
validator
.
check
(
"dout shape"
,
dout_shape
,
"x_shape x"
,
x_shape
,
Rel
.
EQ
,
self
.
name
)
...
...
@@ -508,12 +509,45 @@ class CorrectionMulGrad(PrimitiveWithInfer):
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"running_std_shape[0]"
,
running_std_shape
[
0
],
"dout channel size"
,
dout_shape
[
self
.
channel_axis
],
Rel
.
EQ
,
self
.
name
)
if
context
.
get_context
(
'device_target'
)
==
"Ascend"
:
return
x_shape
,
x_shape
return
x_shape
,
gamma_shape
def
infer_dtype
(
self
,
dout_type
,
x_type
,
gamma_type
,
running_std_type
):
args
=
{
"dout"
:
dout_type
,
"x"
:
x_type
,
"gamma"
:
gamma_type
,
"running_std"
:
running_std_type
}
validator
.
check_tensor_type_same
(
args
,
(
mstype
.
float16
,
mstype
.
float32
),
self
.
name
)
return
x_type
,
x_type
if
context
.
get_context
(
'device_target'
)
==
"Ascend"
:
return
x_type
,
x_type
return
x_type
,
gamma_type
class
CorrectionMulGradReduce
(
PrimitiveWithInfer
):
r
"""
Performs grad reduce of CorrectionMul operation.
Examples:
>>> correction_mul_grad_rd = P.CorrectionMulGradReduce()
>>> dout = Tensor(np.array([1.5, -2.2, 0.7, -3, 1.6, 2.8]).reshape(2, 1, 1, 3), mindspore.float32)
>>> input_x = Tensor(np.random.randint(0, 256, (2, 1, 1, 3)), mindspore.float32)
>>> gamma = Tensor(np.array([0.2, -0.2, 2.5, -1.]).reshape(2, 1, 2), mindspore.float32)
>>> running_std = Tensor(np.array([1.2, 0.1, 0.7, 2.3]).reshape(2, 1, 2), mindspore.float32)
>>> result = correction_mul_grad_rd(dout, input_x, gamma, running_std)
"""
@
prim_attr_register
def
__init__
(
self
,
channel_axis
=
0
):
"""init correction mul reduce layer"""
if
context
.
get_context
(
'device_target'
)
==
"Ascend"
:
from
mindspore.ops._op_impl._custom_op
import
correction_mul_grad
self
.
channel_axis
=
channel_axis
self
.
init_prim_io_names
(
inputs
=
[
'mul_dx'
],
outputs
=
[
'd_gamma'
])
def
infer_shape
(
self
,
mul_dx_shape
):
return
[
mul_dx_shape
[
self
.
channel_axis
]]
def
infer_dtype
(
self
,
mul_dx_type
):
return
mul_dx_type
class
BatchNormFold2
(
PrimitiveWithInfer
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录