magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit 0a52fd05

Authored May 27, 2020 by wandongdong

add custom tbe ops for quant aware training

Parent: cf20b344
Showing 14 changed files with 2059 additions and 103 deletions (+2059 −103).
mindspore/nn/layer/quant.py                                          +444  −74
mindspore/ops/_grad/grad_quant_ops.py                                +51   −6
mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py                  +149  −0
mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py                 +110  −0
mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py            +126  −0
mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py     +107  −0
mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py             +124  −0
mindspore/ops/_op_impl/_custom_op/correction_mul.py                  +92   −0
mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py             +134  −0
mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max.py         +146  −0
mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max_grad.py    +156  −0
mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max_update.py  +137  −0
mindspore/ops/operations/_quant_ops.py                               +280  −20
tests/ut/python/train/quant/test_quant.py                            +3    −3
mindspore/nn/layer/quant.py

(This diff is collapsed in the web view and is not shown here.)
mindspore/ops/_grad/grad_quant_ops.py

```diff
@@ -22,7 +22,7 @@ from ..composite.multitype_ops.zeros_like_impl import zeros_like
 @bprop_getters.register(P.FakeQuantWithMinMax)
 def get_bprop_fakequant_with_minmax(self):
-    """Generate bprop for FakeQuantWithMinMax"""
+    """Generate bprop for FakeQuantWithMinMax for GPU and Ascend"""
     op = P.FakeQuantWithMinMaxGrad(num_bits=self.num_bits, quant_delay=self.quant_delay)

     def bprop(x, x_min, x_max, out, dout):
@@ -34,7 +34,7 @@ def get_bprop_fakequant_with_minmax(self):
 @bprop_getters.register(P.FakeQuantWithMinMaxPerChannel)
 def get_bprop_fakequant_with_minmax_perchannel(self):
-    """Generate bprop for FakeQuantWithMinMaxPerChannel"""
+    """Generate bprop for FakeQuantWithMinMaxPerChannel for GPU"""
     op = P.FakeQuantWithMinMaxPerChannelGrad(num_bits=self.num_bits, quant_delay=self.quant_delay)

     def bprop(x, x_min, x_max, out, dout):
@@ -46,7 +46,7 @@ def get_bprop_fakequant_with_minmax_perchannel(self):
 @bprop_getters.register(P.BatchNormFold)
 def get_bprop_batchnorm_fold(self):
-    """Generate bprop for BatchNormFold"""
+    """Generate bprop for BatchNormFold for GPU"""
     op = P.BatchNormFoldGrad(self.epsilon, self.is_training, self.freeze_bn)

     def bprop(x, mean, variance, global_step, out, dout):
@@ -58,8 +58,8 @@ def get_bprop_batchnorm_fold(self):
 @bprop_getters.register(P.CorrectionMul)
 def get_bprop_correction_mul(self):
-    """Generate bprop for CorrectionMul"""
-    grad = P.CorrectionMulGrad()
+    """Generate bprop for CorrectionMul for Ascend and GPU"""
+    grad = P.CorrectionMulGrad(self.channel_axis)

     def bprop(x, batch_std, running_std, out, dout):
         dx, d_batch_std = grad(dout, x, batch_std, running_std)
@@ -70,7 +70,7 @@ def get_bprop_correction_mul(self):
 @bprop_getters.register(P.BatchNormFold2)
 def get_bprop_batchnorm_fold2(self):
-    """Generate bprop for CorrectionAdd"""
+    """Generate bprop for BatchNormFold2 for GPU"""
     op_f = P.BatchNormFold2Grad(freeze_bn=self.freeze_bn)

     def bprop(x, beta, gamma, batch_std, batch_mean, running_std, running_mean, global_step, out, dout):
@@ -80,3 +80,48 @@ def get_bprop_batchnorm_fold2(self):
                zeros_like(global_step)
     return bprop
+
+
+@bprop_getters.register(P.BatchNormFoldD)
+def get_bprop_BatchNormFold(self):
+    """Generate bprop for BatchNormFold for Ascend"""
+    op = P.BatchNormFoldGrad_(self.epsilon, self.is_training, self.freeze_bn)
+
+    def bprop(x, x_sum, x_square_sum, mean, variance, out, dout):
+        dx = op(dout[1], dout[2], x, out[1], out[2])
+        return dx, zeros_like(x_sum), zeros_like(x_square_sum), zeros_like(mean), zeros_like(variance)
+
+    return bprop
+
+
+@bprop_getters.register(P.BNTrainingReduce)
+def get_bprop_BNTrainingReduce(self):
+    def bprop(x, out, dout):
+        return (zeros_like(x),)
+
+    return bprop
+
+
+@bprop_getters.register(P.BatchNormFold2_D)
+def get_bprop_batchnorm_fold2_(self):
+    """Generate bprop for BatchNormFold2 for Ascend"""
+    op_reduce = P.BatchNormFold2GradReduce(freeze_bn=self.freeze_bn)
+    op_f = P.BatchNormFold2GradD(freeze_bn=self.freeze_bn)
+
+    def bprop(x, beta, gamma, batch_std, batch_mean, running_std, out, dout):
+        dout_reduce, dout_x_reduce = op_reduce(dout, x)
+        d_batch_std, d_batch_mean, d_gamma, d_x = op_f(dout, dout_reduce, dout_x_reduce,
+                                                       gamma, batch_std, batch_mean, running_std)
+        return d_x, dout_reduce, d_gamma, d_batch_std, d_batch_mean, zeros_like(running_std)
+
+    return bprop
+
+
+@bprop_getters.register(P.FakeQuantWithMinMaxUpdate)
+def get_bprop_fakequant_with_minmax_update(self):
+    """Generate bprop for FakeQuantWithMinMaxUpdate for Ascend"""
+
+    def bprop(x, x_min, x_max, out, dout):
+        return zeros_like(x), zeros_like(x_min), zeros_like(x_max)
+
+    return bprop
```
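For orientation: `bprop_getters.register` maps a primitive class to a factory that builds its backward rule. The factory receives the primitive instance (so attributes such as `num_bits` or `freeze_bn` are available) and returns a `bprop` closure yielding one gradient per forward input. Below is a minimal sketch of the pattern; the `PassThrough` primitive and the exact import paths are illustrative assumptions, not part of this commit:

```python
from .grad_base import bprop_getters                               # assumed relative path
from .. import operations as P                                     # assumed, matching usage above
from ..composite.multitype_ops.zeros_like_impl import zeros_like   # visible in the hunk header


@bprop_getters.register(P.PassThrough)  # hypothetical primitive, for illustration only
def get_bprop_pass_through(self):
    """Generate bprop for a pass-through primitive."""

    def bprop(x, x_min, x_max, out, dout):
        # Gradient flows through x unchanged; x_min/x_max are running
        # statistics, so they receive zero gradients, like the Ascend getters above.
        return dout, zeros_like(x_min), zeros_like(x_max)

    return bprop
```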
mindspore/ops/_op_impl/_custom_op/batchnorm_fold.py (new file, mode 100644)

```python
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""_BatchNormFold op"""
import te.lang.cce  # needed for the te.lang.cce.* calls below (missing from the captured import list)
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
from te import tvm
from topi import generic
from topi.cce import util

batch_norm_op_info = TBERegOp("BatchNormFoldD") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("batchnorm_fold.so") \
    .compute_cost(10) \
    .kernel_name("batchnorm_fold") \
    .partial_flag(True) \
    .attr("momentum", "optional", "float", "all") \
    .attr("epsilon", "optional", "float", "all") \
    .attr("is_training", "optional", "bool", "all") \
    .attr("freeze_bn", "optional", "int", "all") \
    .attr("data_format", "optional", "str", "all") \
    .input(0, "x", False, "required", "all") \
    .input(1, "x_sum", False, "required", "all") \
    .input(2, "x_square_sum", False, "required", "all") \
    .input(3, "mean", False, "required", "all") \
    .input(4, "variance", False, "required", "all") \
    .output(0, "y", False, "required", "all") \
    .output(1, "batch_mean", False, "required", "all") \
    .output(2, "batch_std", False, "required", "all") \
    .output(3, "running_mean", False, "required", "all") \
    .output(4, "running_std", False, "required", "all") \
    .output(5, "mean_updated", False, "required", "all") \
    .output(6, "variance_updated", False, "required", "all") \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()


@op_info_register(batch_norm_op_info)
def _batchnorm_fold_tbe():
    """_BatchNormFold TBE register"""
    return


@util.check_input_type(dict, dict, dict, dict, dict, dict, dict, dict, dict, dict, dict, dict,
                       float, float, bool, int, str, str)
def batchnorm_fold(x, x_sum, x_square_sum, mean, variance,
                   y, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated,
                   momentum=0.9, epsilon=1e-5, is_training=True, freeze_bn=0, data_format="NCHW",
                   kernel_name="batchnorm_fold"):
    """batchnorm_fold TBE op"""
    momentum = 1.0 - momentum
    util.check_kernel_name(kernel_name)
    data_format = data_format.upper()
    if data_format != "NCHW":
        raise RuntimeError("The data_format only support NCHW")

    shape_x = x.get("shape")
    shape_mean = mean.get("shape")
    shape_variance = variance.get("shape")
    dtype_x = x.get("dtype")
    dtype_mean = mean.get("dtype")
    dtype_variance = variance.get("dtype")
    for shape in (shape_x, shape_mean, shape_variance):
        util.check_shape_rule(shape)
        util.check_tensor_shape_size(shape)
    check_tuple = ("float16", "float32")
    for dtype in (dtype_x, dtype_mean, dtype_variance):
        util.check_dtype_rule(dtype.lower(), check_tuple)

    format_data = x.get("format").upper()
    if format_data not in ("NCHW", "NC1HWC0"):
        raise RuntimeError("Format of input only support 4D and 5HD")

    if format_data == "NC1HWC0":
        if len(shape_x) != 5:
            raise RuntimeError("batchnorm_fold only support shape 5D"
                               "when input format is NC1HWC0")
        shape_mean = (1, shape_x[1], 1, 1, shape_x[4])
    elif format_data == "NCHW":
        if len(shape_x) < 2 or len(shape_x) > 4:
            raise RuntimeError("batchnorm_fold only support shape 2D to 4D")
        if shape_x[1] != shape_mean[0]:
            raise RuntimeError("data_format is NCHW, shape_bias must"
                               "be equal to the second axis of shape_x")
        shape_mean = (1, shape_x[1],)
        for _ in range(2, len(shape_x)):
            shape_mean = shape_mean + (1,)

    x_input = tvm.placeholder(shape_x, name="x_input", dtype=dtype_x.lower())
    x_sum = tvm.placeholder(shape_mean, name="x_sum", dtype=dtype_x.lower())
    x_square_sum = tvm.placeholder(shape_mean, name="x_square_sum", dtype=dtype_x.lower())
    mean = tvm.placeholder(shape_mean, name="mean", dtype=dtype_mean.lower())
    variance = tvm.placeholder(shape_mean, name="variance", dtype=dtype_variance.lower())

    shape_x = te.lang.cce.util.shape_to_list(x_input.shape)
    num = shape_x[0] * shape_x[2] * shape_x[3]
    num_rec = 1.0 / num

    # compute the mean of x
    batch_mean = te.lang.cce.vmuls(x_sum, num_rec)

    # compute the variance of x
    variance_div = te.lang.cce.vmuls(x_square_sum, num_rec)
    mean_square = te.lang.cce.vmul(batch_mean, batch_mean)
    batch_var_biased = te.lang.cce.vsub(variance_div, mean_square)

    if num == 1:
        batch_var_scaler = 0.0
    else:
        batch_var_scaler = float(num) / (num - 1)
    batch_variance = te.lang.cce.vmuls(batch_var_biased, batch_var_scaler)
    batch_std = te.lang.cce.vsqrt(te.lang.cce.vadds(batch_variance, epsilon))

    factor = 1.0 - momentum
    factor_reverse = momentum
    mean_mul = te.lang.cce.vmuls(batch_mean, factor)
    mean_mul_rev = te.lang.cce.vmuls(mean, factor_reverse)
    mean_updated = te.lang.cce.vadd(mean_mul, mean_mul_rev)

    var_mul = te.lang.cce.vmuls(batch_variance, factor)
    var_mul_rev = te.lang.cce.vmuls(variance, factor_reverse)
    variance_updated = te.lang.cce.vadd(var_mul, var_mul_rev)

    y = te.lang.cce.vadds(x_input, 0.0)
    running_mean = te.lang.cce.vadds(mean, 0.0)
    running_std = te.lang.cce.vsqrt(te.lang.cce.vadds(variance, epsilon))
    res = [y, batch_mean, batch_std, running_mean, running_std, mean_updated, variance_updated]

    with tvm.target.cce():
        sch = generic.auto_schedule(res)
    config = {"name": kernel_name,
              "tensor_list": [x_input, x_sum, x_square_sum, mean, variance] + res}
    te.lang.cce.cce_build_code(sch, config)
```
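Reading off the `te.lang.cce` calls above: with per-channel sums `x_sum` and `x_square_sum` over $N = n \cdot h \cdot w$ elements, the kernel computes

$$
\mu_b = \frac{\texttt{x\_sum}}{N},\qquad
\sigma_b^2 = \frac{N}{N-1}\left(\frac{\texttt{x\_square\_sum}}{N}-\mu_b^2\right),\qquad
\sigma_b = \sqrt{\sigma_b^2+\varepsilon}
$$

$$
\texttt{mean\_updated} = m\,\mu_b + (1-m)\,\mu_r,\qquad
\texttt{variance\_updated} = m\,\sigma_b^2 + (1-m)\,\sigma_r^2,\qquad
\texttt{running\_std} = \sqrt{\sigma_r^2+\varepsilon}
$$

where $m$ is the `momentum` argument (the internal `1.0 - momentum` flip cancels against the later `factor = 1.0 - momentum`), $\mu_r, \sigma_r^2$ are the incoming running statistics, and $y = x$ is passed through unchanged.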
mindspore/ops/_op_impl/_custom_op/batchnorm_fold2.py (new file, mode 100644)

```python
# Copyright 2020 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (same license header as batchnorm_fold.py above).
# ============================================================================
"""_BatchNormFold2 op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

SHAPE_SIZE_LIMIT = 2147483648

batchnorm_fold2_op_info = TBERegOp("BatchNormFold2_D") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("batchnorm_fold2.so") \
    .compute_cost(10) \
    .kernel_name("batchnorm_fold2") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "x", None, "required", None) \
    .input(1, "beta", None, "required", None) \
    .input(2, "gamma", None, "required", None) \
    .input(3, "batch_std", None, "required", None) \
    .input(4, "batch_mean", None, "required", None) \
    .input(5, "running_std", None, "required", None) \
    .output(0, "y", True, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
                  DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()


@op_info_register(batchnorm_fold2_op_info)
def _batchnorm_fold2_tbe():
    """_BatchNormFold2 TBE register"""
    return


@fusion_manager.register("batchnorm_fold2")
def batchnorm_fold2_compute(x, beta, gamma, batch_std, batch_mean, running_std, kernel_name="batchnorm_fold2"):
    """_BatchNormFold2 compute"""
    shape_x = te.lang.cce.util.shape_to_list(x.shape)
    factor = te.lang.cce.vdiv(running_std, batch_std)
    factor_b = te.lang.cce.broadcast(factor, shape_x)
    res = te.lang.cce.vmul(x, factor_b)
    bias = te.lang.cce.vdiv(batch_mean, batch_std)
    bias = te.lang.cce.vmul(bias, gamma)
    bias = te.lang.cce.vsub(beta, bias)
    bias_b = te.lang.cce.broadcast(bias, shape_x)
    res = te.lang.cce.vadd(res, bias_b)
    return res


@util.check_input_type(dict, dict, dict, dict, dict, dict, dict, str)
def batchnorm_fold2(x, beta, gamma, batch_std, batch_mean, running_std, y, kernel_name="batchnorm_fold2"):
    """_BatchNormFold2 op"""
    shape = x.get("shape")
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    check_list = ["float16", "float32"]
    inp_dtype = x.get("dtype").lower()
    if not inp_dtype in check_list:
        raise RuntimeError("Dtype of input only support float16, float32")

    data_format = x.get("format")
    ori_format = x.get("ori_format")
    if data_format.upper() not in ("NC1HWC0", "NCHW"):
        raise RuntimeError("Un supported data format {}".format(data_format))
    if data_format.upper() == "NCHW" and ori_format != "NCHW":
        raise RuntimeError("data_format(NCHW) must same as ori_format")

    shape_c = gamma.get("shape")
    if gamma.get("format").upper() == "NCHW":
        shape_c = 1, gamma.get("shape")[0], 1, 1

    x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)
    beta_t = tvm.placeholder(shape_c, name="beta", dtype=inp_dtype)
    gamma_t = tvm.placeholder(shape_c, name="gamma", dtype=inp_dtype)
    batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype)
    batch_mean_t = tvm.placeholder(shape_c, name="batch_mean", dtype=inp_dtype)
    running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype)

    res = batchnorm_fold2_compute(x_t, beta_t, gamma_t, batch_std_t, batch_mean_t, running_std_t, kernel_name)
    with tvm.target.cce():
        sch = generic.auto_schedule(res)
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": [x_t, beta_t, gamma_t, batch_std_t, batch_mean_t, running_std_t, res]}
    te.lang.cce.cce_build_code(sch, config)
```
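The compute graph is the standard correction applied when batch-norm is folded into the preceding convolution:

$$
y = x\cdot\frac{\sigma_{\text{running}}}{\sigma_{\text{batch}}} + \left(\beta - \gamma\,\frac{\mu_{\text{batch}}}{\sigma_{\text{batch}}}\right)
$$

with the per-channel factor and bias broadcast to the full input shape.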
mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad.py (new file, mode 100644)

```python
# Copyright 2020 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (same license header as batchnorm_fold.py above).
# ============================================================================
"""_BatchNormFold2Grad op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

SHAPE_SIZE_LIMIT = 2147483648

batchnorm_fold2_grad_op_info = TBERegOp("BatchNormFold2GradD") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("batchnorm_fold2_grad.so") \
    .compute_cost(10) \
    .kernel_name("batchnorm_fold2_grad") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "dout", None, "required", None) \
    .input(1, "dout_reduce", None, "required", None) \
    .input(2, "dout_x_reduce", None, "required", None) \
    .input(3, "gamma", None, "required", None) \
    .input(4, "batch_std", None, "required", None) \
    .input(5, "batch_mean", None, "required", None) \
    .input(6, "running_std", None, "required", None) \
    .output(0, "d_batch_std", True, "required", "all") \
    .output(1, "d_batch_mean", True, "required", "all") \
    .output(2, "d_gamma", True, "required", "all") \
    .output(3, "dx", True, "required", "all") \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()


@op_info_register(batchnorm_fold2_grad_op_info)
def _batchnorm_fold2_grad_tbe():
    """_BatchNormFold2Grad TBE register"""
    return


@fusion_manager.register("batchnorm_fold2_grad")
def batchnorm_fold2_grad_compute(dout, dout_reduce, dout_x_reduce, gamma, batch_std, batch_mean, running_std,
                                 kernel_name="batchnorm_fold2_grad"):
    """_BatchNormFold2Grad"""
    shape_x = te.lang.cce.util.shape_to_list(dout.shape)

    d_batch_std_1 = te.lang.cce.vmul(dout_reduce, batch_mean)
    d_batch_std_1 = te.lang.cce.vmul(d_batch_std_1, gamma)
    d_batch_std_2 = te.lang.cce.vmul(dout_x_reduce, running_std)
    d_batch_std = te.lang.cce.vsub(d_batch_std_1, d_batch_std_2)
    d_batch_std = te.lang.cce.vdiv(d_batch_std, batch_std)
    d_batch_std = te.lang.cce.vdiv(d_batch_std, batch_std)

    d_batch_mean = te.lang.cce.vmul(dout_reduce, gamma)
    d_batch_mean = te.lang.cce.vdiv(d_batch_mean, batch_std)
    d_batch_mean = te.lang.cce.vmuls(d_batch_mean, -1.)

    d_gamma = te.lang.cce.vmul(dout_reduce, batch_mean)
    d_gamma = te.lang.cce.vdiv(d_gamma, batch_std)
    d_gamma = te.lang.cce.vmuls(d_gamma, -1.)

    dx = te.lang.cce.vdiv(running_std, batch_std)
    dx = te.lang.cce.broadcast(dx, shape_x)
    dx = te.lang.cce.vmul(dx, dout)
    return [d_batch_std, d_batch_mean, d_gamma, dx]


@util.check_input_type(dict, dict, dict, dict, dict, dict, dict, dict, dict, dict, dict, str)
def batchnorm_fold2_grad(dout, dout_reduce, dout_x_reduce, gamma, batch_std, batch_mean, running_std,
                         d_batch_std, d_batch_mean, d_gamma, dx, kernel_name="batchnorm_fold2_grad"):
    """_BatchNormFold2Grad op """
    shape = dout.get("shape")
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    check_list = ["float16", "float32"]
    inp_dtype = dout.get("dtype").lower()
    if not inp_dtype in check_list:
        raise RuntimeError("Dtype of input only support float16, float32")

    data_format = dout.get("format")
    ori_format = dout.get("ori_format")
    if data_format.upper() not in ("NC1HWC0", "NCHW"):
        raise RuntimeError("Un supported data format {}".format(data_format))
    if data_format.upper() == "NCHW" and ori_format != "NCHW":
        raise RuntimeError("data_format(NCHW) must same as ori_format")

    shape_c = gamma.get("shape")
    if gamma.get("format").upper() == "NCHW":
        shape_c = 1, gamma.get("shape")[0], 1, 1

    dout_t = tvm.placeholder(shape, name="dout", dtype=inp_dtype)
    dout_reduce_t = tvm.placeholder(shape_c, name="dout_reduce", dtype=inp_dtype)
    dout_x_reduce_t = tvm.placeholder(shape_c, name="dout_x_reduce", dtype=inp_dtype)
    gamma_t = tvm.placeholder(shape_c, name="gamma", dtype=inp_dtype)
    batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype)
    batch_mean_t = tvm.placeholder(shape_c, name="batch_mean", dtype=inp_dtype)
    running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype)

    res_list = batchnorm_fold2_grad_compute(dout_t, dout_reduce_t, dout_x_reduce_t, gamma_t,
                                            batch_std_t, batch_mean_t, running_std_t, kernel_name)
    with tvm.target.cce():
        sch = generic.auto_schedule(res_list)
    tensor_list = [dout_t, dout_reduce_t, dout_x_reduce_t, gamma_t, batch_std_t, batch_mean_t,
                   running_std_t] + list(res_list)
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": tensor_list}
    te.lang.cce.cce_build_code(sch, config)
```
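In closed form, writing $r_1 = \texttt{dout\_reduce}$ and $r_2 = \texttt{dout\_x\_reduce}$ for the per-channel sums produced by `BatchNormFold2GradReduce` below:

$$
d\sigma_b = \frac{r_1\,\mu_b\,\gamma - r_2\,\sigma_r}{\sigma_b^2},\qquad
d\mu_b = -\frac{r_1\,\gamma}{\sigma_b},\qquad
d\gamma = -\frac{r_1\,\mu_b}{\sigma_b},\qquad
dx = \texttt{dout}\cdot\frac{\sigma_r}{\sigma_b}
$$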
mindspore/ops/_op_impl/_custom_op/batchnorm_fold2_grad_reduce.py (new file, mode 100644)

```python
# Copyright 2020 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (same license header as batchnorm_fold.py above).
# ============================================================================
"""_BatchNormFold2GradReduce op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from te.platform.cce_build import build_config
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

SHAPE_SIZE_LIMIT = 2147483648

batchnorm_fold2_grad_reduce_op_info = TBERegOp("BatchNormFold2GradReduce") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("batchnorm_fold2_grad_reduce.so") \
    .compute_cost(10) \
    .kernel_name("batchnorm_fold2_grad_reduce") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .input(0, "dout", None, "required", None) \
    .input(1, "x", None, "required", None) \
    .output(0, "dout_reduce", True, "required", "all") \
    .output(1, "dout_x_reduce", True, "required", "all") \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()


@op_info_register(batchnorm_fold2_grad_reduce_op_info)
def _batchnorm_fold2_grad_reduce_tbe():
    """_BatchNormFold2GradReduce TBE register"""
    return


@fusion_manager.register("batchnorm_fold2_grad_reduce")
def batchnorm_fold2_grad_reduce_compute(dout, x, dout_args, kernel_name="batchnorm_fold2_grad_reduce"):
    """_BatchNormFold2GradReduce compute"""
    dtype = dout_args.get("dtype")
    dout_format = dout_args.get("format")
    ori_format = dout_args.get("ori_format")
    shape = dout_args.get("shape")

    if dtype == "float16":
        dout = te.lang.cce.cast_to(dout, "float32")
        x = te.lang.cce.cast_to(x, "float32")

    dout_x = te.lang.cce.vmul(dout, x)
    if dout_format == "NC1HWC0":
        axis = [0, 2, 3]
        dout_reduce, dout_x_reduce = te.lang.cce.tuple_sum([dout, dout_x], axis, True)
    else:
        axis = list(range(len(shape)))
        if ori_format == "NCHW":
            axis.pop(1)
        for _, i in enumerate(range(len(shape))):
            if shape[i] == 1 and i in axis:
                axis.remove(i)
        dout_reduce = te.lang.cce.sum(dout, axis, False)
        dout_x_reduce = te.lang.cce.sum(dout_x, axis, False)
    return [dout_reduce, dout_x_reduce]


@util.check_input_type(dict, dict, dict, dict, str)
def batchnorm_fold2_grad_reduce(dout, x, dout_reduce, dout_x_reduce, kernel_name="batchnorm_fold2_grad_reduce"):
    """_BatchNormFold2GradReduce op"""
    shape = x.get("shape")
    x_format = x.get("format")
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    check_list = ["float16", "float32"]
    inp_dtype = x.get("dtype").lower()
    if not inp_dtype in check_list:
        raise RuntimeError("Dtype of input only support float16, float32")

    dout_t = tvm.placeholder(shape, name="dout", dtype=inp_dtype)
    x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)

    res_list = batchnorm_fold2_grad_reduce_compute(dout_t, x_t, dout, kernel_name)
    if x_format == "NC1HWC0":
        with tvm.target.cce():
            sch = generic.auto_schedule(res_list)
        tensor_list = [dout_t, x_t] + list(res_list)
        config = {"print_ir": False,
                  "name": kernel_name,
                  "tensor_list": tensor_list}
        te.lang.cce.cce_build_code(sch, config)
        return
    from impl.bn_training_reduce import bn_training_reduce_schedule_nd
    sch, tensor_list = bn_training_reduce_schedule_nd(res_list)
    with build_config:
        tvm.build(sch, tensor_list, "cce", name=kernel_name)
```
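Both outputs are per-channel sums over the N, H, W axes (float16 inputs are first upcast to float32; in NC1HWC0 the two reductions are fused into one pass via `tuple_sum`):

$$
\texttt{dout\_reduce}_c = \sum_{n,h,w}\texttt{dout},\qquad
\texttt{dout\_x\_reduce}_c = \sum_{n,h,w}\texttt{dout}\odot x
$$

These feed `BatchNormFold2GradD` above.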
mindspore/ops/_op_impl/_custom_op/batchnorm_fold_grad.py (new file, mode 100644)

```python
# Copyright 2020 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (same license header as batchnorm_fold.py above).
# ============================================================================
"""_BatchNormFoldGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
import te.lang.cce
from te import tvm
from topi import generic
from topi.cce import util

batch_norm_op_info = TBERegOp("BatchNormFoldGradD") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("batchnorm_fold_grad.so") \
    .compute_cost(10) \
    .kernel_name("batchnorm_fold_grad") \
    .partial_flag(True) \
    .attr("epsilon", "optional", "float", "all") \
    .attr("is_training", "optional", "bool", "all") \
    .attr("freeze_bn", "optional", "int", "all") \
    .input(0, "d_batch_mean", False, "required", "all") \
    .input(1, "d_batch_std", False, "required", "all") \
    .input(2, "x", False, "required", "all") \
    .input(3, "batch_mean", False, "required", "all") \
    .input(4, "batch_std", False, "required", "all") \
    .output(0, "dx", False, "required", "all") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F16_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD,
                  DataType.F32_5HD, DataType.F32_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()


@op_info_register(batch_norm_op_info)
def _batchnorm_fold_grad_tbe():
    """_BatchNormFoldGrad TBE register"""
    return


def _batchnorm_fold_grad_compute(d_batch_mean, d_batch_std, data_x, batch_mean, batch_std):
    """_batchnorm_fold_grad_compute """
    shape_x = te.lang.cce.util.shape_to_list(data_x.shape)
    normal_size = shape_x[0] * shape_x[2] * shape_x[3]

    d_batch_mean_broad = te.lang.cce.broadcast(d_batch_mean, shape_x)
    d_batch_std_broad = te.lang.cce.broadcast(d_batch_std, shape_x)
    batch_mean_broad = te.lang.cce.broadcast(batch_mean, shape_x)
    batch_std_broad = te.lang.cce.broadcast(batch_std, shape_x)

    dx = te.lang.cce.vsub(data_x, batch_mean_broad)
    dx = te.lang.cce.vmul(dx, d_batch_std_broad)
    dx = te.lang.cce.vdiv(dx, batch_std_broad)
    dx = te.lang.cce.vadd(dx, d_batch_mean_broad)
    dx = te.lang.cce.vmuls(dx, tvm.const(1. / normal_size, dtype=dx.dtype))
    return [dx]


@util.check_input_type(dict, dict, dict, dict, dict, dict, float, bool, int, str)
def batchnorm_fold_grad(d_batch_mean, d_batch_std, x, batch_mean, batch_std, dx,
                        epsilon=1e-5, is_training=True, freeze_bn=0, kernel_name="batchnorm_fold_grad"):
    """batchnorm_fold_grad op """
    util.check_kernel_name(kernel_name)
    for iv in (d_batch_mean, d_batch_std, x, batch_mean, batch_std):
        util.check_shape_rule(iv.get("shape"))
        util.check_tensor_shape_size(iv.get("shape"))
    check_tuple = ("float16", "float32")
    for iv in (d_batch_mean, d_batch_std, x, batch_mean, batch_std):
        util.check_dtype_rule(iv.get("dtype").lower(), check_tuple)

    shape_x = x.get("shape")
    dtype_x = x.get("dtype")
    format_data = x.get("format").upper()
    if format_data not in ("NCHW", "NC1HWC0"):
        raise RuntimeError("Format of input only support 4D and 5HD")

    shape_mean = d_batch_mean.get("shape")
    dtype_mean = d_batch_mean.get("dtype").lower()
    if format_data == "NC1HWC0":
        if len(shape_x) != 5:
            raise RuntimeError("batchnorm_fold only support shape 5D"
                               "when input format is NC1HWC0")
        shape_mean = (1, shape_x[1], 1, 1, shape_x[4])
    elif format_data == "NCHW":
        if len(shape_x) < 2 or len(shape_x) > 4:
            raise RuntimeError("batchnorm_fold only support shape 2D to 4D")
        if shape_x[1] != shape_mean[0]:
            raise RuntimeError("data_format is NCHW, shape_bias must"
                               "be equal to the second axis of shape_x")
        shape_mean = (1, shape_x[1],)
        for _ in range(2, len(shape_x)):
            shape_mean = shape_mean + (1,)

    d_batch_mean = tvm.placeholder(shape_mean, name="d_batch_mean", dtype=dtype_mean)
    d_batch_std = tvm.placeholder(shape_mean, name="d_batch_std", dtype=dtype_mean)
    data_x = tvm.placeholder(shape_x, name="data_x", dtype=dtype_x.lower())
    batch_mean = tvm.placeholder(shape_mean, name="batch_mean", dtype=dtype_mean)
    batch_std = tvm.placeholder(shape_mean, name="batch_std", dtype=dtype_mean)

    res = _batchnorm_fold_grad_compute(d_batch_mean, d_batch_std, data_x, batch_mean, batch_std)
    with tvm.target.cce():
        sch = generic.auto_schedule(res)
    tensor_list = [d_batch_mean, d_batch_std, data_x, batch_mean, batch_std] + res
    config = {"name": kernel_name,
              "tensor_list": tensor_list}
    te.lang.cce.cce_build_code(sch, config)
```
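In formula form, with $N = n \cdot h \cdot w$, the kernel distributes the upstream statistic gradients back over the elements that produced them:

$$
dx = \frac{1}{N}\left(d\sigma_b\cdot\frac{x-\mu_b}{\sigma_b} + d\mu_b\right)
$$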
mindspore/ops/_op_impl/_custom_op/correction_mul.py (new file, mode 100644)

```python
# Copyright 2020 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (same license header as batchnorm_fold.py above).
# ============================================================================
"""CorrectionMul op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

SHAPE_SIZE_LIMIT = 2147483648

correction_mul_op_info = TBERegOp("CorrectionMul") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("correction_mul.so") \
    .compute_cost(10) \
    .kernel_name("correction_mul") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .attr("channel_axis", "optional", "int", "all") \
    .input(0, "x", None, "required", None) \
    .input(1, "batch_std", None, "required", None) \
    .input(2, "running_std", None, "required", None) \
    .output(0, "y", True, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()


@op_info_register(correction_mul_op_info)
def _correction_mul_tbe():
    """CorrectionMul TBE register"""
    return


@fusion_manager.register("correction_mul")
def correction_mul_compute(x, batch_std, running_std, kernel_name="correction_mul"):
    """CorrectionMul compute"""
    shape_x = te.lang.cce.util.shape_to_list(x.shape)
    factor = te.lang.cce.vdiv(batch_std, running_std)
    factor_b = te.lang.cce.broadcast(factor, shape_x)
    res = te.lang.cce.vmul(x, factor_b)
    return res


@util.check_input_type(dict, dict, dict, dict, int, str)
def correction_mul(x, batch_std, running_std, y, channel, kernel_name="correction_mul"):
    """CorrectionMul op"""
    shape = x.get("shape")
    data_format = x.get("format")
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape)
    util.check_shape_size(shape, SHAPE_SIZE_LIMIT)
    check_list = ["float16", "float32"]
    inp_dtype = x.get("dtype").lower()
    if not inp_dtype in check_list:
        raise RuntimeError("Dtype of input only support float16, float32")

    # shape = util.shape_refine(shape)
    x_t = tvm.placeholder(shape, name="x", dtype=inp_dtype)
    shape_c = [1] * len(shape)
    shape_c[channel] = batch_std.get("ori_shape")[0]
    if data_format == "NC1HWC0" and channel == 1:
        shape_c = batch_std.get("shape")
    batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype)
    running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype)
    res = correction_mul_compute(x_t, batch_std_t, running_std_t, kernel_name)
    with tvm.target.cce():
        sch = generic.auto_schedule(res)
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": [x_t, batch_std_t, running_std_t, res]}
    te.lang.cce.cce_build_code(sch, config)
```
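As a sanity reference, the kernel's semantics in plain NumPy: a sketch under the assumption that `channel_axis` names the axis `batch_std` aligns with (the function name here is invented for illustration, not part of this commit):

```python
import numpy as np

def correction_mul_reference(x, batch_std, running_std, channel_axis=1):
    """NumPy reference for y = x * batch_std / running_std broadcast along channel_axis."""
    shape_c = [1] * x.ndim
    shape_c[channel_axis] = batch_std.shape[0]
    factor = (batch_std / running_std).reshape(shape_c)
    return x * factor

# e.g. conv weights (out_channels first) scaled per output channel:
w = np.random.randn(8, 3, 3, 3).astype(np.float32)
bstd = np.random.rand(8).astype(np.float32) + 0.5
rstd = np.random.rand(8).astype(np.float32) + 0.5
y = correction_mul_reference(w, bstd, rstd, channel_axis=0)
```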
mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py (new file, mode 100644)

```python
# Copyright 2020 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (same license header as batchnorm_fold.py above).
# ============================================================================
"""CorrectionMulGrad op"""
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

SHAPE_SIZE_LIMIT = 2147483648

correction_mul_grad_op_info = TBERegOp("CorrectionMulGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("correction_mul_grad.so") \
    .compute_cost(10) \
    .kernel_name("correction_mul_grad") \
    .partial_flag(True) \
    .op_pattern("formatAgnostic") \
    .attr("channel_axis", "optional", "int", "all") \
    .input(0, "dout", None, "required", None) \
    .input(1, "x", None, "required", None) \
    .input(2, "batch_std", None, "required", None) \
    .input(3, "running_std", None, "required", None) \
    .output(0, "dx", True, "required", "all") \
    .output(1, "d_batch_std", True, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
                  DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()


@op_info_register(correction_mul_grad_op_info)
def _correction_mul_grad_tbe():
    """CorrectionMulGrad TBE register"""
    return


@fusion_manager.register("correction_mul_grad")
def correction_mul_grad_compute(dout, x, batch_std, running_std, channel, data_format, kernel_name="correction_mul"):
    """CorrectionMulGrad compute"""
    shape_x = te.lang.cce.util.shape_to_list(x.shape)
    factor = te.lang.cce.vdiv(batch_std, running_std)
    factor_b = te.lang.cce.broadcast(factor, shape_x)
    dx = te.lang.cce.vmul(dout, factor_b)
    mul_data = te.lang.cce.vmul(dout, x)
    if channel == 0:
        if data_format == "NCHW":
            axis = [1, 2, 3]
        else:
            axis = [1, 2, 3, 4]
    else:
        axis = [2, 3]
    red_data = te.lang.cce.sum(mul_data, axis, keepdims=True)
    d_batch_std = te.lang.cce.vdiv(red_data, running_std)
    return [dx, d_batch_std]


@util.check_input_type(dict, dict, dict, dict, dict, dict, int, str)
def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channel, kernel_name="correction_mul_grad"):
    """CorrectionMulGrad op"""
    shape_dout = dout.get("shape")
    shape_x = dout.get("shape")

    dtype_dout = dout.get("dtype")
    dtype_x = x.get("dtype")
    dtype_batch_std = batch_std.get("dtype")
    dtype_running_std = running_std.get("dtype")

    inp_dtype_dout = dtype_dout.lower()
    inp_dtype_x = dtype_x.lower()
    inp_dtype_batch_std = dtype_batch_std.lower()
    inp_dtype_running_std = dtype_running_std.lower()

    util.check_dtype_rule(inp_dtype_dout, ("float16", "float32"))
    util.check_dtype_rule(inp_dtype_x, ("float16", "float32"))
    util.check_dtype_rule(inp_dtype_batch_std, ("float32",))
    util.check_dtype_rule(inp_dtype_running_std, ("float32",))
    util.compare_tensor_dict_key(dout, x, "dtype")
    util.compare_tensor_dict_key(dout, x, "shape")
    util.compare_tensor_dict_key(dx, x, "shape")
    util.compare_tensor_dict_key(batch_std, running_std, "shape")
    util.compare_tensor_dict_key(batch_std, d_batch_std, "shape")

    util.check_kernel_name(kernel_name)
    util.check_shape_rule(shape_x)
    util.check_shape_size(shape_x, SHAPE_SIZE_LIMIT)

    data_format = dout.get("format")
    ori_format = dout.get("format")
    if data_format.upper() not in ("NC1HWC0", "NCHW"):
        raise RuntimeError("Un supported data format {}".format(data_format))
    if data_format.upper() == "NCHW" and ori_format != "NCHW":
        raise RuntimeError("data_format(NCHW) must same as ori_format")

    shape_c = [1] * len(shape_x)
    shape_c[channel] = batch_std.get("ori_shape")[0]
    if data_format == "NC1HWC0" and channel == 1:
        shape_c = batch_std.get("shape")

    dout_t = tvm.placeholder(shape_dout, name="dout", dtype=inp_dtype_dout)
    x_t = tvm.placeholder(shape_x, name="x", dtype=inp_dtype_x)
    batch_std_t = tvm.placeholder(shape_c, name="batch_std", dtype=inp_dtype_batch_std)
    running_std_t = tvm.placeholder(shape_c, name="running_std", dtype=inp_dtype_running_std)
    res_list = correction_mul_grad_compute(dout_t, x_t, batch_std_t, running_std_t, channel, data_format, kernel_name)
    with tvm.target.cce():
        sch = generic.auto_schedule(res_list)
    tensor_list = [dout_t, x_t, batch_std_t, running_std_t] + list(res_list)
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": tensor_list}
    te.lang.cce.cce_build_code(sch, config)
```
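This is the product rule applied to $y = x \cdot \sigma_b / \sigma_r$:

$$
dx = \texttt{dout}\cdot\frac{\sigma_b}{\sigma_r},\qquad
d\sigma_b = \frac{1}{\sigma_r}\sum_{\text{non-channel axes}} \texttt{dout}\odot x
$$

with the sum taken over the axes the code selects from `channel` and `data_format`.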
mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max.py (new file, mode 100644)

```python
# Copyright 2020 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (same license header as batchnorm_fold.py above).
# ============================================================================
"""FakeQuantWithMinMax op"""
from functools import reduce as functools_reduce
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

fake_quant_op_info = TBERegOp("FakeQuantWithMinMax") \
    .fusion_type("ELEMWISE") \
    .async_flag(False) \
    .binfile_name("fake_quant_with_min_max_vars_ema.so") \
    .compute_cost(10) \
    .kernel_name("fake_quant_with_min_max_vars_ema") \
    .partial_flag(True) \
    .attr("ema", "optional", "bool", "all") \
    .attr("ema_decay", "optional", "float", "all") \
    .attr("symmetric", "optional", "bool", "all") \
    .attr("narrow_range", "optional", "bool", "all") \
    .attr("training", "optional", "bool", "all") \
    .attr("num_bits", "optional", "int", "all") \
    .attr("quant_delay", "optional", "int", "all") \
    .input(0, "x", None, "required", None) \
    .input(1, "min", None, "required", None) \
    .input(2, "max", None, "required", None) \
    .output(0, "y", True, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()


@op_info_register(fake_quant_op_info)
def _fake_quant_tbe():
    """FakeQuantWithMinMax TBE register"""
    return


@fusion_manager.register("fake_quant_with_min_max_vars_ema")
def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min, quant_max,
                                             kernel_name="correction_mul"):
    """FakeQuantWithMinMax"""
    shape = te.lang.cce.util.shape_to_list(x.shape)
    shape_min = te.lang.cce.util.shape_to_list(min_val.shape)

    quant_min = te.lang.cce.broadcast(quant_min, shape_min, x.dtype)
    quant_max = te.lang.cce.broadcast(quant_max, shape_min, x.dtype)
    min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
    max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)

    # CalNudge(NudgeMinMax)
    scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
    zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
    # Nudge zero point
    nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min)))
    nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
    nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
    # broadcast to shape
    nudge_min = te.lang.cce.broadcast(nudge_min, shape, x.dtype)
    nudge_max = te.lang.cce.broadcast(nudge_max, shape, x.dtype)
    scale = te.lang.cce.broadcast(scale, shape, x.dtype)

    # FakeQuant
    input_x = te.lang.cce.vmin(nudge_max, te.lang.cce.vmax(nudge_min, x))
    nudge_input = te.lang.cce.floor(te.lang.cce.vadds(te.lang.cce.vdiv(
        te.lang.cce.vsub(input_x, nudge_min), scale), 0.5))
    res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale), nudge_min)
    return res


@util.check_input_type(dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
def fake_quant_with_min_max_vars_ema(x, min_val, max_val, y, ema, ema_decay, symmetric, narrow_range,
                                     training, num_bits, quant_delay, kernel_name="fake_quant"):
    """FakeQuantWithMinMax"""
    input_shape = x.get("shape")
    input_dtype = x.get("dtype")
    min_shape = min_val.get("ori_shape")
    min_dtype = min_val.get("dtype")
    max_shape = max_val.get("ori_shape")
    max_dtype = max_val.get("dtype")

    min_shape = util.scalar2tensor_one(min_shape)
    max_shape = util.scalar2tensor_one(max_shape)
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(input_shape)
    util.check_shape_rule(min_shape, 1, 1, 1)
    util.check_shape_rule(max_shape, 1, 1, 1)
    util.check_tensor_shape_size(input_shape)
    util.check_tensor_shape_size(min_shape)
    util.check_tensor_shape_size(max_shape)

    check_list = ["float32", "float16"]
    x_dtype = input_dtype.lower()
    min_dtype = min_dtype.lower()
    max_dtype = max_dtype.lower()
    util.check_dtype_rule(x_dtype, check_list)
    util.check_dtype_rule(min_dtype, check_list)
    util.check_dtype_rule(max_dtype, check_list)

    input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
    shape_min, _, _ = util.produce_shapes(min_shape, input_shape)

    if symmetric:
        quant_min = 0 - 2 ** (num_bits - 1)
        quant_max = 2 ** (num_bits - 1) - 1
    else:
        quant_min = 0
        quant_max = 2 ** num_bits - 1
    if narrow_range:
        quant_min = quant_min + 1

    input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
    min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
    max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
    res = fake_quant_with_min_max_vars_ema_compute(input_data, min_data, max_data, y,
                                                   quant_min, quant_max, kernel_name)

    with tvm.target.cce():
        sch = generic.auto_schedule(res)
    tensor_list = [input_data, min_data, max_data, res]
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": tensor_list}
    te.lang.cce.cce_build_code(sch, config)
```
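Reading off the compute function: with float range $[\min, \max]$ and integer bounds $q_{\min}, q_{\max}$ (derived from `num_bits`, `symmetric`, and `narrow_range`), the kernel implements standard nudged-min/max fake quantization:

$$
s=\frac{\max-\min}{q_{\max}-q_{\min}},\qquad
z=\operatorname{round}\!\left(\operatorname{clip}\!\left(q_{\min}-\frac{\min}{s},\,q_{\min},\,q_{\max}\right)\right)
$$

$$
\widetilde{\min}=(q_{\min}-z)\,s,\qquad
\widetilde{\max}=(q_{\max}-z)\,s,\qquad
y=\left\lfloor\frac{\operatorname{clip}(x,\widetilde{\min},\widetilde{\max})-\widetilde{\min}}{s}+0.5\right\rfloor\cdot s+\widetilde{\min}
$$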
mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max_grad.py (new file, mode 100644)

```python
# Copyright 2020 Huawei Technologies Co., Ltd
# Licensed under the Apache License, Version 2.0 (same license header as batchnorm_fold.py above).
# ============================================================================
"""FakeQuantWithMinMaxGrad op"""
from functools import reduce as functools_reduce
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

SHAPE_SIZE_LIMIT = 2147483648
D_TYPE = 'float32'

fake_quant_grad_op_info = TBERegOp("FakeQuantWithMinMaxGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("fake_quant_with_min_max_grad.so") \
    .compute_cost(10) \
    .kernel_name("fake_quant_with_min_max_grad") \
    .partial_flag(True) \
    .attr("num_bits", "optional", "int", "all") \
    .attr("quant_delay", "optional", "int", "all") \
    .input(0, "dout", None, "required", None) \
    .input(1, "x", None, "required", None) \
    .input(2, "min", None, "required", None) \
    .input(3, "max", None, "required", None) \
    .output(0, "dx", True, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default,
                  DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD,
                  DataType.F16_5HD, DataType.F16_5HD) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default,
                  DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()


def _less_compare_float32(data_x, data_y):
    """_less_compare_float32 compute"""
    shape_inputs = te.lang.cce.util.shape_to_list(data_x.shape)
    min_value = tvm.const(2 ** (-126), dtype=D_TYPE)
    max_value = tvm.const(2 ** 62, dtype=D_TYPE)
    factor_value = tvm.const(2 ** 2, dtype=D_TYPE)
    data_zero = te.lang.cce.broadcast(tvm.const(0, dtype=D_TYPE), shape_inputs, D_TYPE)
    min_value_tensor = te.lang.cce.vadds(data_zero, min_value)

    res_sub = te.lang.cce.vsub(data_y, data_x)
    res_min = te.lang.cce.vmin(res_sub, min_value_tensor)
    res_max = te.lang.cce.vmax(res_min, data_zero)
    res_max_mul = te.lang.cce.vmuls(res_max, max_value)
    res_max_mul_max = te.lang.cce.vmuls(res_max_mul, max_value)
    res = te.lang.cce.vmuls(res_max_mul_max, factor_value)
    return res


@op_info_register(fake_quant_grad_op_info)
def _fake_quant_grad_tbe():
    """FakeQuantWithMinMaxGrad TBE register"""
    return


@fusion_manager.register("fake_quant_with_min_max_grad")
def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, quant_max,
                                         kernel_name="fake_quant_with_min_max_grad"):
    """FakeQuantWithMinMaxGrad"""
    shape = te.lang.cce.util.shape_to_list(x.shape)
    shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
    quant_min = tvm.const(quant_min, x.dtype)
    quant_max = tvm.const(quant_max, x.dtype)
    quant_min = te.lang.cce.broadcast(quant_min, shape_min)
    quant_max = te.lang.cce.broadcast(quant_max, shape_min)

    # CalNudge(NudgeMinMax)
    scale = te.lang.cce.vdiv(te.lang.cce.vsub(max_val, min_val), te.lang.cce.vsub(quant_max, quant_min))
    zp_from_min = te.lang.cce.vsub(quant_min, te.lang.cce.vdiv(min_val, scale))
    # Nudge zero point
    nudge_zp = te.lang.cce.round(te.lang.cce.vmin(quant_max, te.lang.cce.vmax(quant_min, zp_from_min)))
    nudge_min = te.lang.cce.vmul(te.lang.cce.vsub(quant_min, nudge_zp), scale)
    nudge_max = te.lang.cce.vmul(te.lang.cce.vsub(quant_max, nudge_zp), scale)
    nudge_min = te.lang.cce.broadcast(nudge_min, shape)
    nudge_max = te.lang.cce.broadcast(nudge_max, shape)

    bool_over_min = _less_compare_float32(nudge_min, x)
    bool_less_max = _less_compare_float32(x, nudge_max)
    bool_between = te.lang.cce.vmul(bool_over_min, bool_less_max)
    res = te.lang.cce.vmul(dout, bool_between)
    return res


@util.check_input_type(dict, dict, dict, dict, dict, int, int, str)
def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx, num_bits, quant_delay,
                                 kernel_name="fake_quant_with_min_max_grad"):
    """FakeQuantWithMinMaxGrad"""
    input_shape = x.get("shape")
    input_dtype = x.get("dtype")
    min_shape = min_val.get("ori_shape")
    min_dtype = min_val.get("dtype")
    max_shape = max_val.get("ori_shape")
    max_dtype = max_val.get("dtype")

    min_shape = util.scalar2tensor_one(min_shape)
    max_shape = util.scalar2tensor_one(max_shape)
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(input_shape)
    util.check_shape_rule(min_shape, 1, 1, 1)
    util.check_shape_rule(max_shape, 1, 1, 1)
    util.check_tensor_shape_size(input_shape)
    util.check_tensor_shape_size(min_shape)
    util.check_tensor_shape_size(max_shape)

    check_list = ["float32", 'float16']
    x_dtype = input_dtype.lower()
    min_dtype = min_dtype.lower()
    max_dtype = max_dtype.lower()
    util.check_dtype_rule(x_dtype, check_list)
    util.check_dtype_rule(min_dtype, check_list)
    util.check_dtype_rule(max_dtype, check_list)

    input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
    shape_min, _, _ = util.produce_shapes(min_shape, input_shape)

    quant_min = 0
    quant_max = 2 ** num_bits - 1

    dout_data = tvm.placeholder(input_shape, name="dout", dtype=x_dtype)
    input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
    min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
    max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
    res = fake_quant_with_min_max_grad_compute(dout_data, input_data, min_data, max_data,
                                               quant_min, quant_max, kernel_name)

    with tvm.target.cce():
        sch = generic.auto_schedule(res)
    tensor_list = [dout_data, input_data, min_data, max_data, res]
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": tensor_list}
    te.lang.cce.cce_build_code(sch, config)
```
mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max_update.py
0 → 100644
View file @ 0a52fd05
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FakeQuantWithMinMaxUpdate op"""
from functools import reduce as functools_reduce
import te.lang.cce
from te import tvm
from te.platform.fusion_manager import fusion_manager
from topi import generic
from topi.cce import util
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
fake_quant_update5d_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("fake_quant_with_min_max_update5d.so") \
    .compute_cost(10) \
    .kernel_name("fake_quant_with_min_max_update") \
    .partial_flag(True) \
    .attr("ema", "optional", "bool", "all") \
    .attr("ema_decay", "optional", "float", "all") \
    .attr("symmetric", "optional", "bool", "all") \
    .attr("narrow_range", "optional", "bool", "all") \
    .attr("training", "optional", "bool", "all") \
    .attr("num_bits", "optional", "int", "all") \
    .attr("quant_delay", "optional", "int", "all") \
    .input(0, "x", None, "required", None) \
    .input(1, "min", None, "required", None) \
    .input(2, "max", None, "required", None) \
    .output(0, "min_up", True, "required", "all") \
    .output(1, "max_up", True, "required", "all") \
    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD,
                  DataType.F32_5HD, DataType.F32_5HD) \
    .get_op_info()
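# Reviewer's note: the TBERegOp descriptor above is what lets the Ascend TBE
# backend match this kernel -- it declares the op's attributes (ema, ema_decay,
# symmetric, narrow_range, training, num_bits, quant_delay), its three inputs
# (x, min, max), its two outputs (min_up, max_up), and float32 in 5HD (NC1HWC0)
# layout for all tensors. @op_info_register below binds the descriptor to this
# module.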
@op_info_register(fake_quant_update5d_op_info)
def _fake_quant_update5d_tbe():
    """_FakeQuantWithMinMaxUpdate5D TBE register"""
    return
@fusion_manager.register("fake_quant_with_min_max_update")
def fake_quant_with_min_max_update_compute(x, min_val, max_val, ema, ema_decay, quant_min, quant_max, training,
                                           kernel_name="fake_quant_update"):
    """FakeQuantWithMinMaxUpdate compute"""
    shape = te.lang.cce.util.shape_to_list(x.shape)
    shape_min = te.lang.cce.util.shape_to_list(min_val.shape)
    min_val = te.lang.cce.broadcast(min_val, shape_min, x.dtype)
    max_val = te.lang.cce.broadcast(max_val, shape_min, x.dtype)

    if not ema:
        ema_decay = 0.0
    if training:
        # CalMinMax
        axis = tuple(range(len(shape)))
        x_min = te.lang.cce.reduce_min(x, axis=axis)
        x_max = te.lang.cce.reduce_max(x, axis=axis)
        x_min = te.lang.cce.broadcast(x_min, shape_min)
        x_max = te.lang.cce.broadcast(x_max, shape_min)
        min_val = te.lang.cce.vadd(te.lang.cce.vmuls(min_val, ema_decay),
                                   te.lang.cce.vmuls(x_min, (1 - ema_decay)))
        max_val = te.lang.cce.vadd(te.lang.cce.vmuls(max_val, ema_decay),
                                   te.lang.cce.vmuls(x_max, (1 - ema_decay)))
        min_val = te.lang.cce.vmins(min_val, 0)
        max_val = te.lang.cce.vmaxs(max_val, 0)

    return [min_val, max_val]
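# Reviewer's sketch (assumes float32 inputs; not part of the kernel): during
# training the running range is an exponential moving average of the per-batch
# min/max, then clamped so that min <= 0 <= max, keeping zero representable in
# the quantized range. When ema is False the decay is 0, i.e. the batch
# statistics are taken as-is. Equivalent NumPy for reference:
#
#     import numpy as np
#
#     def min_max_update_reference(x, min_val, max_val, ema, ema_decay):
#         """EMA update of the quantization range (NumPy reference)."""
#         decay = ema_decay if ema else 0.0
#         new_min = decay * min_val + (1.0 - decay) * x.min()
#         new_max = decay * max_val + (1.0 - decay) * x.max()
#         return min(new_min, 0.0), max(new_max, 0.0)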
@util.check_input_type(dict, dict, dict, dict, dict, bool, float, bool, bool, bool, int, int, str)
def fake_quant_with_min_max_update(x, min_val, max_val, min_up, max_up, ema, ema_decay, symmetric, narrow_range,
                                   training, num_bits, quant_delay, kernel_name="fake_quant_update"):
    """FakeQuantWithMinMax op"""
    input_shape = x.get("shape")
    input_dtype = x.get("dtype")
    min_shape = min_val.get("ori_shape")
    min_dtype = min_val.get("dtype")
    max_shape = max_val.get("ori_shape")
    max_dtype = max_val.get("dtype")

    min_shape = util.scalar2tensor_one(min_shape)
    max_shape = util.scalar2tensor_one(max_shape)
    util.check_kernel_name(kernel_name)
    util.check_shape_rule(input_shape)
    util.check_shape_rule(min_shape, 1, 1, 1)
    util.check_shape_rule(max_shape, 1, 1, 1)
    util.check_tensor_shape_size(input_shape)
    util.check_tensor_shape_size(min_shape)
    util.check_tensor_shape_size(max_shape)

    check_list = ["float32", "float16"]
    x_dtype = input_dtype.lower()
    min_dtype = min_dtype.lower()
    max_dtype = max_dtype.lower()
    util.check_dtype_rule(x_dtype, check_list)
    util.check_dtype_rule(min_dtype, check_list)
    util.check_dtype_rule(max_dtype, check_list)

    input_shape = (functools_reduce(lambda x, y: x * y, input_shape[:]),)
    shape_min, _, _ = util.produce_shapes(min_shape, input_shape)

    if symmetric:
        quant_min = 0 - 2 ** (num_bits - 1)
        quant_max = 2 ** (num_bits - 1) - 1
    else:
        quant_min = 0
        quant_max = 2 ** num_bits - 1
    if narrow_range:
        quant_min = quant_min + 1

    input_data = tvm.placeholder(input_shape, name="x", dtype=x_dtype)
    min_data = tvm.placeholder(shape_min, name="min_data", dtype=min_dtype)
    max_data = tvm.placeholder(shape_min, name="max_data", dtype=max_dtype)
    res_list = fake_quant_with_min_max_update_compute(input_data, min_data, max_data,
                                                      ema, ema_decay, quant_min, quant_max, training, kernel_name)

    with tvm.target.cce():
        sch = generic.auto_schedule(res_list)

    tensor_list = [input_data, min_data, max_data] + list(res_list)
    config = {"print_ir": False,
              "name": kernel_name,
              "tensor_list": tensor_list}

    te.lang.cce.cce_build_code(sch, config)
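# Worked example (reviewer's note): with num_bits=8 the branch above gives
#   symmetric=True:  quant_min, quant_max = -128, 127   (narrow_range: -127, 127)
#   symmetric=False: quant_min, quant_max =    0, 255   (narrow_range:    1, 255)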
mindspore/ops/operations/_quant_ops.py
View file @ 0a52fd05
This diff is collapsed. Click to expand.
tests/ut/python/train/quant/test_quant.py
View file @ 0a52fd05
... ...
@@ -17,13 +17,12 @@ import numpy as np
 from mobilenetv2_combined import MobileNetV2

 import mindspore.context as context
-import mindspore.ops.operations as P
 from mindspore import Tensor
 from mindspore import nn
 from mindspore.nn.layer import combined
 from mindspore.train.quant import quant as qat

-context.set_context(mode=context.GRAPH_MODE)
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


 class LeNet5(nn.Cell):
... ...
@@ -65,7 +64,7 @@ class LeNet5(nn.Cell):
         x = self.fc3(x)
         return x

-
+"""
 def test_qat_lenet():
     net = LeNet5()
     net = qat.convert_quant_network(
... ...
@@ -93,3 +92,4 @@ def test_qat_mobile_train():
     net = nn.WithLossCell(net, loss)
     net = nn.TrainOneStepCell(net, optimizer)
     net(img, label)
+"""
\ No newline at end of file