Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
b7db3e9a
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
b7db3e9a
编写于
6月 05, 2020
作者:
C
chenzomi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add fake quant per channel and bug fix
上级
bd3e8da6
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
888 addition
and
291 deletion
+888
-291
mindspore/ccsrc/kernel/gpu/quant/fake_quant_gpu_kernel.cc
mindspore/ccsrc/kernel/gpu/quant/fake_quant_gpu_kernel.cc
+1
-1
mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.cc
...pore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.cc
+1
-1
mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc
...src/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc
+1
-1
mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.cc
...ernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.cc
+1
-1
mindspore/nn/layer/quant.py
mindspore/nn/layer/quant.py
+93
-119
mindspore/ops/_grad/grad_quant_ops.py
mindspore/ops/_grad/grad_quant_ops.py
+25
-10
mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py
...op_impl/_custom_op/fake_quant_minmax_perchannel_update.py
+135
-0
mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perlayer_update.py
.../_op_impl/_custom_op/fake_quant_minmax_perlayer_update.py
+22
-21
mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py
mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py
+145
-0
mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py
...ore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py
+171
-0
mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py
mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer.py
+25
-25
mindspore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py
...spore/ops/_op_impl/_custom_op/fake_quant_perlayer_grad.py
+25
-22
mindspore/ops/operations/_quant_ops.py
mindspore/ops/operations/_quant_ops.py
+243
-90
未找到文件。
mindspore/ccsrc/kernel/gpu/quant/fake_quant_gpu_kernel.cc
浏览文件 @
b7db3e9a
...
@@ -171,6 +171,6 @@ bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std
...
@@ -171,6 +171,6 @@ bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std
return
true
;
return
true
;
}
}
MS_REG_GPU_KERNEL
(
FakeQuant
WithMinMax
,
FakeQuantGpuKernel
)
MS_REG_GPU_KERNEL
(
FakeQuant
PerLayer
,
FakeQuantGpuKernel
)
}
// namespace kernel
}
// namespace kernel
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.cc
浏览文件 @
b7db3e9a
...
@@ -153,6 +153,6 @@ bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const
...
@@ -153,6 +153,6 @@ bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const
return
true
;
return
true
;
}
}
MS_REG_GPU_KERNEL
(
FakeQuant
WithMinMax
Grad
,
FakeQuantGradGpuKernel
)
MS_REG_GPU_KERNEL
(
FakeQuant
PerLayer
Grad
,
FakeQuantGradGpuKernel
)
}
// namespace kernel
}
// namespace kernel
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_gpu_kernel.cc
浏览文件 @
b7db3e9a
...
@@ -175,6 +175,6 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
...
@@ -175,6 +175,6 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
return
true
;
return
true
;
}
}
MS_REG_GPU_KERNEL
(
FakeQuant
WithMinMax
PerChannel
,
FakeQuantPerChannelGpuKernel
)
MS_REG_GPU_KERNEL
(
FakeQuantPerChannel
,
FakeQuantPerChannelGpuKernel
)
}
// namespace kernel
}
// namespace kernel
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/kernel/gpu/quant/fake_quant_per_channel_grad_gpu_kernel.cc
浏览文件 @
b7db3e9a
...
@@ -143,6 +143,6 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
...
@@ -143,6 +143,6 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
return
true
;
return
true
;
}
}
MS_REG_GPU_KERNEL
(
FakeQuant
WithMinMax
PerChannelGrad
,
FakeQuantPerChannelGradGpuKernel
)
MS_REG_GPU_KERNEL
(
FakeQuantPerChannelGrad
,
FakeQuantPerChannelGradGpuKernel
)
}
// namespace kernel
}
// namespace kernel
}
// namespace mindspore
}
// namespace mindspore
mindspore/nn/layer/quant.py
浏览文件 @
b7db3e9a
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
# ============================================================================
# ============================================================================
"""Aware quantization."""
"""Aware quantization."""
from
functools
import
partial
import
numpy
as
np
import
numpy
as
np
import
mindspore.common.dtype
as
mstype
import
mindspore.common.dtype
as
mstype
from
mindspore.ops
import
operations
as
P
from
mindspore.ops
import
operations
as
P
...
@@ -101,10 +102,9 @@ class BatchNormFoldCell(Cell):
...
@@ -101,10 +102,9 @@ class BatchNormFoldCell(Cell):
return
batch_mean
,
batch_std
,
running_mean
,
running_std
return
batch_mean
,
batch_std
,
running_mean
,
running_std
class
FakeQuantWithMinMax
D
(
Cell
):
class
FakeQuantWithMinMax
Ascend
(
Cell
):
r
"""
r
"""
Aware Quantization training op of ascend. This OP provide Fake quantization observer
Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
function on data with min and max.
Args:
Args:
min_init (int, list): The dimension of channel or 1(layer). Default: -6.
min_init (int, list): The dimension of channel or 1(layer). Default: -6.
...
@@ -125,7 +125,7 @@ class FakeQuantWithMinMaxD(Cell):
...
@@ -125,7 +125,7 @@ class FakeQuantWithMinMaxD(Cell):
Tensor, with the same type and shape as the `x`.
Tensor, with the same type and shape as the `x`.
Examples:
Examples:
>>> fake_quant =
nn.FakeQuantWithMinMaxD
()
>>> fake_quant =
FakeQuantWithMinMax
()
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> input_x = Tensor(np.array([[1, 2, 1], [-2, 0, -1]]), mindspore.float32)
>>> result = fake_quant(input_x)
>>> result = fake_quant(input_x)
"""
"""
...
@@ -137,75 +137,77 @@ class FakeQuantWithMinMaxD(Cell):
...
@@ -137,75 +137,77 @@ class FakeQuantWithMinMaxD(Cell):
ema
=
False
,
ema
=
False
,
ema_decay
=
0.999
,
ema_decay
=
0.999
,
per_channel
=
False
,
per_channel
=
False
,
channel_size
=
1
,
channel_axis
=
1
,
out_channels
=
1
,
quant_delay
=
0
,
quant_delay
=
0
,
symmetric
=
False
,
symmetric
=
False
,
narrow_range
=
False
,
narrow_range
=
False
,
training
=
True
):
training
=
True
):
"""init FakeQuantWithMinMax ascend layer"""
"""init FakeQuantWithMinMaxAscend layer"""
super
(
FakeQuantWithMinMaxD
,
self
).
__init__
()
super
(
FakeQuantWithMinMaxAscend
,
self
).
__init__
()
self
.
min_init
=
min_init
self
.
min_init
=
min_init
self
.
num_bits
=
num_bits
self
.
max_init
=
max_init
self
.
max_init
=
max_init
self
.
num_bits
=
num_bits
self
.
ema
=
ema
self
.
ema
=
ema
self
.
ema_decay
=
ema_decay
self
.
ema_decay
=
ema_decay
self
.
per_channel
=
per_channel
self
.
per_channel
=
per_channel
self
.
channel_
size
=
channel_size
self
.
channel_
axis
=
channel_axis
self
.
quant_delay
=
quant_delay
self
.
quant_delay
=
quant_delay
self
.
symmetric
=
symmetric
self
.
symmetric
=
symmetric
self
.
narrow_range
=
narrow_range
self
.
narrow_range
=
narrow_range
self
.
training
=
training
self
.
training
=
training
if
not
per_channel
:
# init tensor min and max for fake quant op
self
.
fake_quant
=
P
.
FakeQuantWithMinMax
(
num_bits
=
self
.
num_bits
,
if
isinstance
(
min_init
,
int
):
ema
=
self
.
ema
,
min_array
=
np
.
array
([
min_init
]).
reshape
(
1
).
astype
(
np
.
float32
)
ema_decay
=
self
.
ema_decay
,
max_array
=
np
.
array
([
max_init
]).
reshape
(
1
).
astype
(
np
.
float32
)
quant_delay
=
self
.
quant_delay
,
elif
isinstance
(
min_init
,
list
):
symmetric
=
self
.
symmetric
,
min_array
=
np
.
array
([
self
.
min_init
for
i
in
range
(
narrow_range
=
self
.
narrow_range
,
0
,
self
.
out_channels
)]).
astype
(
np
.
float32
)
training
=
training
)
max_array
=
np
.
array
([
self
.
max_init
for
i
in
range
(
self
.
ema_update
=
P
.
FakeQuantWithMinMaxUpdate
(
num_bits
=
self
.
num_bits
,
0
,
self
.
out_channels
)]).
astype
(
np
.
float32
)
ema
=
self
.
ema
,
self
.
minq
=
Parameter
(
Tensor
(
min_array
),
name
=
'quant_min'
,
requires_grad
=
False
)
ema_decay
=
self
.
ema_decay
,
self
.
maxq
=
Parameter
(
Tensor
(
max_array
),
name
=
'quant_max'
,
requires_grad
=
False
)
quant_delay
=
self
.
quant_delay
,
symmetric
=
self
.
symmetric
,
narrow_range
=
self
.
narrow_range
,
training
=
training
)
else
:
raise
RuntimeError
(
"not support per channel"
)
if
isinstance
(
min_init
,
Parameter
)
:
if
per_channel
:
self
.
minq
=
min_init
quant_fun
=
partial
(
P
.
FakeQuantPerChannel
,
channel_axis
=
self
.
channel_axis
)
self
.
maxq
=
max_init
ema_fun
=
partial
(
P
.
FakeQuantMinMaxPerChannelUpdate
,
channel_axis
=
self
.
channel_axis
)
else
:
else
:
self
.
minq
=
Parameter
(
Tensor
(
np
.
array
([
min_init
]).
astype
(
np
.
float32
)),
quant_fun
=
P
.
FakeQuantPerLayer
name
=
'quant_min'
,
ema_fun
=
P
.
FakeQuantMinMaxPerLayerUpdate
requires_grad
=
False
)
self
.
maxq
=
Parameter
(
Tensor
(
np
.
array
([
max_init
]).
astype
(
np
.
float32
)),
self
.
fake_quant
=
quant_fun
(
num_bits
=
self
.
num_bits
,
name
=
'quant_max'
,
ema
=
self
.
ema
,
requires_grad
=
False
)
ema_decay
=
self
.
ema_decay
,
self
.
reduce_min
=
P
.
ReduceMin
()
quant_delay
=
self
.
quant_delay
,
self
.
reduce_max
=
P
.
ReduceMax
()
symmetric
=
self
.
symmetric
,
narrow_range
=
self
.
narrow_range
,
training
=
self
.
training
)
self
.
ema_update
=
ema_fun
(
num_bits
=
self
.
num_bits
,
ema
=
self
.
ema
,
ema_decay
=
self
.
ema_decay
,
symmetric
=
self
.
symmetric
,
narrow_range
=
self
.
narrow_range
,
training
=
self
.
training
)
def
extend_repr
(
self
):
def
extend_repr
(
self
):
s
=
'
min_init={}, max_init={}, ema={}, ema_decay={}, per_channel={}, channel_size={}, quant_delay
={}'
.
format
(
s
=
'
ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max
={}'
.
format
(
self
.
min_init
,
self
.
max_init
,
self
.
ema
,
self
.
ema_decay
,
self
.
per_channel
,
self
.
channel_size
,
self
.
min_init
,
self
.
max_init
,
self
.
ema
,
self
.
ema_decay
,
self
.
quant_delay
)
self
.
per_channel
,
self
.
quant_delay
,
self
.
channel_axis
)
return
s
return
s
def
construct
(
self
,
x
,
minq
,
maxq
):
def
construct
(
self
,
x
):
if
self
.
training
:
if
self
.
update
:
min_up
,
max_up
=
self
.
ema_update
(
x
,
minq
,
maxq
)
min_up
,
max_up
=
self
.
ema_update
(
x
,
self
.
minq
,
self
.
maxq
)
out
=
self
.
fake_quant
(
x
,
min_up
,
max_up
)
out
=
self
.
fake_quant
(
x
,
min_up
,
max_up
)
P
.
Assign
()(
self
.
minq
,
min_up
)
P
.
Assign
()(
self
.
minq
,
min_up
)
P
.
Assign
()(
self
.
maxq
,
max_up
)
P
.
Assign
()(
self
.
maxq
,
max_up
)
else
:
else
:
out
=
self
.
fake_quant
(
x
,
minq
,
maxq
)
out
=
self
.
fake_quant
(
x
,
self
.
minq
,
self
.
maxq
)
return
out
return
out
class
FakeQuantWithMinMax
(
Cell
):
class
FakeQuantWithMinMax
GPU
(
Cell
):
r
"""
r
"""
Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.
...
@@ -240,98 +242,69 @@ class FakeQuantWithMinMax(Cell):
...
@@ -240,98 +242,69 @@ class FakeQuantWithMinMax(Cell):
ema
=
False
,
ema
=
False
,
ema_decay
=
0.999
,
ema_decay
=
0.999
,
per_channel
=
False
,
per_channel
=
False
,
channel_axis
=
1
,
out_channels
=
1
,
out_channels
=
1
,
quant_delay
=
0
,
quant_delay
=
0
,
symmetric
=
False
,
symmetric
=
False
,
narrow_range
=
False
):
narrow_range
=
False
,
"""init FakeQuantWithMinMax layer"""
training
=
True
):
super
(
FakeQuantWithMinMax
,
self
).
__init__
()
super
(
FakeQuantWithMinMaxGPU
,
self
).
__init__
()
self
.
min_init
=
min_init
self
.
min_init
=
min_init
self
.
num_bits
=
num_bits
self
.
max_init
=
max_init
self
.
max_init
=
max_init
self
.
num_bits
=
num_bits
self
.
ema
=
ema
self
.
ema
=
ema
self
.
ema_decay
=
ema_decay
self
.
ema_decay
=
ema_decay
self
.
per_channel
=
per_channel
self
.
per_channel
=
per_channel
self
.
out_channels
=
out_channel
s
self
.
channel_axis
=
channel_axi
s
self
.
quant_delay
=
quant_delay
self
.
quant_delay
=
quant_delay
self
.
symmetric
=
symmetric
self
.
symmetric
=
symmetric
self
.
narrow_range
=
narrow_range
self
.
narrow_range
=
narrow_range
self
.
training
=
training
if
per_channel
:
# init tensor min and max for fake quant op
min_array
=
np
.
array
([
self
.
min_init
for
i
in
range
(
0
,
self
.
out_channels
)]).
astype
(
np
.
float32
)
if
isinstance
(
min_init
,
int
):
max_array
=
np
.
array
([
self
.
max_init
for
i
in
range
(
0
,
self
.
channel_size
)]).
astype
(
np
.
float32
)
self
.
minq
=
Parameter
(
Tensor
(
min_array
),
name
=
'quant_min'
,
requires_grad
=
False
)
self
.
maxq
=
Parameter
(
Tensor
(
max_array
),
name
=
'quant_max'
,
requires_grad
=
False
)
self
.
fake_quant_train
=
P
.
FakeQuantWithMinMaxPerChannel
(
num_bits
=
self
.
num_bits
,
ema
=
self
.
ema
,
ema_decay
=
self
.
ema_decay
,
quant_delay
=
self
.
quant_delay
,
symmetric
=
self
.
symmetric
,
narrow_range
=
self
.
narrow_range
,
training
=
True
)
self
.
fake_quant_infer
=
P
.
FakeQuantWithMinMaxPerChannel
(
num_bits
=
self
.
num_bits
,
ema
=
self
.
ema
,
ema_decay
=
self
.
ema_decay
,
quant_delay
=
self
.
quant_delay
,
symmetric
=
self
.
symmetric
,
narrow_range
=
self
.
narrow_range
,
training
=
False
)
else
:
min_array
=
np
.
array
([
min_init
]).
reshape
(
1
).
astype
(
np
.
float32
)
min_array
=
np
.
array
([
min_init
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_array
=
np
.
array
([
max_init
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_array
=
np
.
array
([
max_init
]).
reshape
(
1
).
astype
(
np
.
float32
)
self
.
minq
=
Parameter
(
Tensor
(
min_array
),
name
=
'quant_min'
,
requires_grad
=
False
)
elif
isinstance
(
min_init
,
list
):
self
.
maxq
=
Parameter
(
Tensor
(
max_array
),
name
=
'quant_max'
,
requires_grad
=
False
)
min_array
=
np
.
array
([
self
.
min_init
for
i
in
range
(
if
context
.
get_context
(
'device_target'
)
==
"Ascend"
:
0
,
self
.
out_channels
)]).
astype
(
np
.
float32
)
self
.
fake_quant_train
=
FakeQuantWithMinMaxD
(
num_bits
=
self
.
num_bits
,
max_array
=
np
.
array
([
self
.
max_init
for
i
in
range
(
ema
=
self
.
ema
,
0
,
self
.
out_channels
)]).
astype
(
np
.
float32
)
ema_decay
=
self
.
ema_decay
,
self
.
minq
=
Parameter
(
Tensor
(
min_array
),
name
=
'quant_min'
,
requires_grad
=
False
)
quant_delay
=
self
.
quant_delay
,
self
.
maxq
=
Parameter
(
Tensor
(
max_array
),
name
=
'quant_max'
,
requires_grad
=
False
)
symmetric
=
self
.
symmetric
,
narrow_range
=
self
.
narrow_range
,
if
per_channel
:
training
=
True
,
quant_fun
=
partial
(
P
.
FakeQuantPerChannel
,
channel_axis
=
self
.
channel_axis
)
min_init
=
self
.
minq
,
else
:
max_init
=
self
.
maxq
)
quant_fun
=
P
.
FakeQuantPerLayer
self
.
fake_quant_infer
=
FakeQuantWithMinMaxD
(
num_bits
=
self
.
num_bits
,
self
.
fake_quant
=
quant_fun
(
num_bits
=
self
.
num_bits
,
ema
=
self
.
ema
,
ema
=
self
.
ema
,
ema_decay
=
self
.
ema_decay
,
ema_decay
=
ema_decay
,
quant_delay
=
self
.
quant_delay
,
quant_delay
=
quant_delay
,
symmetric
=
self
.
symmetric
,
symmetric
=
self
.
symmetric
,
narrow_range
=
self
.
narrow_range
,
narrow_range
=
self
.
narrow_range
,
training
=
False
,
training
=
self
.
training
)
min_init
=
self
.
minq
,
max_init
=
self
.
maxq
)
elif
context
.
get_context
(
'device_target'
)
==
"GPU"
:
self
.
fake_quant_train
=
P
.
FakeQuantWithMinMax
(
num_bits
=
self
.
num_bits
,
ema
=
self
.
ema
,
ema_decay
=
self
.
ema_decay
,
quant_delay
=
self
.
quant_delay
,
symmetric
=
self
.
symmetric
,
narrow_range
=
self
.
narrow_range
,
training
=
True
)
self
.
fake_quant_infer
=
P
.
FakeQuantWithMinMax
(
num_bits
=
self
.
num_bits
,
ema
=
self
.
ema
,
ema_decay
=
ema_decay
,
quant_delay
=
quant_delay
,
symmetric
=
self
.
symmetric
,
narrow_range
=
self
.
narrow_range
,
training
=
False
)
else
:
raise
ValueError
(
"Not support platform."
)
def
extend_repr
(
self
):
def
extend_repr
(
self
):
s
=
'min={}, max={}, ema={}, ema_decay={}, per_channel={}, quant_delay={}'
.
format
(
s
=
'ema={}, ema_decay={}, per_channel={}, quant_delay={}, channel_axis={}, min={}, max={}'
.
format
(
self
.
min_init
,
self
.
max_init
,
self
.
ema
,
self
.
ema_decay
,
self
.
per_channel
,
self
.
quant_delay
)
self
.
min_init
,
self
.
max_init
,
self
.
ema
,
self
.
ema_decay
,
self
.
per_channel
,
self
.
quant_delay
,
self
.
channel_axis
)
return
s
return
s
def
construct
(
self
,
x
):
def
construct
(
self
,
x
):
if
self
.
training
:
out
=
self
.
fake_quant
(
x
,
self
.
minq
,
self
.
maxq
)
out
=
self
.
fake_quant_train
(
x
,
self
.
minq
,
self
.
maxq
)
else
:
out
=
self
.
fake_quant_infer
(
x
,
self
.
minq
,
self
.
maxq
)
return
out
return
out
def
FakeQuantWithMinMax
(
**
kwargs
):
if
context
.
get_context
(
'device_target'
)
==
"Ascend"
:
out
=
FakeQuantWithMinMaxAscend
(
**
kwargs
)
if
context
.
get_context
(
'device_target'
)
==
"GPU"
:
out
=
FakeQuantWithMinMaxGPU
(
**
kwargs
)
else
:
raise
ValueError
(
"Not support platform or channel mode."
)
return
out
class
Conv2dBatchNormQuant
(
Cell
):
class
Conv2dBatchNormQuant
(
Cell
):
r
"""
r
"""
2D convolution with BatchNormal op folded layer.
2D convolution with BatchNormal op folded layer.
...
@@ -420,7 +393,6 @@ class Conv2dBatchNormQuant(Cell):
...
@@ -420,7 +393,6 @@ class Conv2dBatchNormQuant(Cell):
self
.
per_channel
=
per_channel
self
.
per_channel
=
per_channel
self
.
symmetric
=
symmetric
self
.
symmetric
=
symmetric
self
.
narrow_range
=
narrow_range
self
.
narrow_range
=
narrow_range
self
.
channel_axis
=
int
(
group
>
1
)
self
.
is_gpu
=
context
.
get_context
(
'device_target'
)
==
"GPU"
self
.
is_gpu
=
context
.
get_context
(
'device_target'
)
==
"GPU"
# initialize convolution op and Parameter
# initialize convolution op and Parameter
...
@@ -435,6 +407,7 @@ class Conv2dBatchNormQuant(Cell):
...
@@ -435,6 +407,7 @@ class Conv2dBatchNormQuant(Cell):
dilation
=
self
.
dilation
)
dilation
=
self
.
dilation
)
if
weight_init
is
None
:
if
weight_init
is
None
:
weight_init
=
initializer
(
'normal'
,
[
1
,
in_channels
,
*
self
.
kernel_size
])
weight_init
=
initializer
(
'normal'
,
[
1
,
in_channels
,
*
self
.
kernel_size
])
channel_axis
=
1
else
:
else
:
self
.
conv
=
P
.
Conv2D
(
out_channel
=
out_channels
,
self
.
conv
=
P
.
Conv2D
(
out_channel
=
out_channels
,
kernel_size
=
self
.
kernel_size
,
kernel_size
=
self
.
kernel_size
,
...
@@ -445,6 +418,7 @@ class Conv2dBatchNormQuant(Cell):
...
@@ -445,6 +418,7 @@ class Conv2dBatchNormQuant(Cell):
group
=
group
)
group
=
group
)
if
weight_init
is
None
:
if
weight_init
is
None
:
weight_init
=
initializer
(
'normal'
,
[
out_channels
,
in_channels
//
group
,
*
self
.
kernel_size
])
weight_init
=
initializer
(
'normal'
,
[
out_channels
,
in_channels
//
group
,
*
self
.
kernel_size
])
channel_axis
=
0
self
.
weight
=
Parameter
(
weight_init
,
name
=
'weight'
)
self
.
weight
=
Parameter
(
weight_init
,
name
=
'weight'
)
# initialize batchnorm Parameter
# initialize batchnorm Parameter
...
@@ -472,7 +446,7 @@ class Conv2dBatchNormQuant(Cell):
...
@@ -472,7 +446,7 @@ class Conv2dBatchNormQuant(Cell):
symmetric
=
symmetric
,
symmetric
=
symmetric
,
narrow_range
=
narrow_range
)
narrow_range
=
narrow_range
)
self
.
batchnorm_fold
=
BatchNormFoldCell
(
epsilon
=
eps
,
momentum
=
momentum
,
freeze_bn
=
freeze_bn
)
self
.
batchnorm_fold
=
BatchNormFoldCell
(
epsilon
=
eps
,
momentum
=
momentum
,
freeze_bn
=
freeze_bn
)
self
.
correct_mul
=
P
.
CorrectionMul
(
self
.
channel_axis
)
self
.
correct_mul
=
P
.
CorrectionMul
(
channel_axis
)
if
context
.
get_context
(
'device_target'
)
==
"Ascend"
:
if
context
.
get_context
(
'device_target'
)
==
"Ascend"
:
self
.
batchnorm_fold2_train
=
P
.
BatchNormFold2_D
(
freeze_bn
=
freeze_bn
)
self
.
batchnorm_fold2_train
=
P
.
BatchNormFold2_D
(
freeze_bn
=
freeze_bn
)
self
.
batchnorm_fold2_infer
=
P
.
BatchNormFold2_D
(
freeze_bn
=
0
)
self
.
batchnorm_fold2_infer
=
P
.
BatchNormFold2_D
(
freeze_bn
=
0
)
...
@@ -520,7 +494,7 @@ class Conv2dBatchNormQuant(Cell):
...
@@ -520,7 +494,7 @@ class Conv2dBatchNormQuant(Cell):
out
=
self
.
batchnorm_fold2_train
(
out
,
self
.
beta
,
self
.
gamma
,
batch_std
,
batch_mean
,
running_std
)
out
=
self
.
batchnorm_fold2_train
(
out
,
self
.
beta
,
self
.
gamma
,
batch_std
,
batch_mean
,
running_std
)
F
.
control_depend
(
out
,
self
.
assignadd
(
self
.
step
,
self
.
one
))
F
.
control_depend
(
out
,
self
.
assignadd
(
self
.
step
,
self
.
one
))
else
:
else
:
out
=
self
.
batchnorm_fold2_infer
(
out
,
self
.
beta
,
self
.
gamma
,
batch_std
,
batch
_mean
,
running_std
)
out
=
self
.
batchnorm_fold2_infer
(
out
,
self
.
beta
,
self
.
gamma
,
running_std
,
running
_mean
,
running_std
)
return
out
return
out
...
...
mindspore/ops/_grad/grad_quant_ops.py
浏览文件 @
b7db3e9a
...
@@ -20,10 +20,11 @@ from .grad_base import bprop_getters
...
@@ -20,10 +20,11 @@ from .grad_base import bprop_getters
from
..composite.multitype_ops.zeros_like_impl
import
zeros_like
from
..composite.multitype_ops.zeros_like_impl
import
zeros_like
@
bprop_getters
.
register
(
P
.
FakeQuant
WithMinMax
)
@
bprop_getters
.
register
(
P
.
FakeQuant
PerLayer
)
def
get_bprop_fakequant_with_minmax
(
self
):
def
get_bprop_fakequant_with_minmax
(
self
):
"""Generate bprop for FakeQuantWithMinMax for GPU and Ascend"""
"""Generate bprop for FakeQuantPerLayer for GPU and Ascend"""
op
=
P
.
FakeQuantWithMinMaxGrad
(
num_bits
=
self
.
num_bits
,
quant_delay
=
self
.
quant_delay
)
op
=
P
.
FakeQuantPerLayerGrad
(
num_bits
=
self
.
num_bits
,
quant_delay
=
self
.
quant_delay
)
def
bprop
(
x
,
x_min
,
x_max
,
out
,
dout
):
def
bprop
(
x
,
x_min
,
x_max
,
out
,
dout
):
dx
=
op
(
dout
,
x
,
x_min
,
x_max
)
dx
=
op
(
dout
,
x
,
x_min
,
x_max
)
...
@@ -32,10 +33,14 @@ def get_bprop_fakequant_with_minmax(self):
...
@@ -32,10 +33,14 @@ def get_bprop_fakequant_with_minmax(self):
return
bprop
return
bprop
@
bprop_getters
.
register
(
P
.
FakeQuant
WithMinMax
PerChannel
)
@
bprop_getters
.
register
(
P
.
FakeQuantPerChannel
)
def
get_bprop_fakequant_with_minmax_perchannel
(
self
):
def
get_bprop_fakequant_with_minmax_perchannel
(
self
):
"""Generate bprop for FakeQuantWithMinMaxPerChannel for GPU"""
"""Generate bprop for FakeQuantPerChannel"""
op
=
P
.
FakeQuantWithMinMaxPerChannelGrad
(
num_bits
=
self
.
num_bits
,
quant_delay
=
self
.
quant_delay
)
op
=
P
.
FakeQuantPerChannelGrad
(
num_bits
=
self
.
num_bits
,
quant_delay
=
self
.
quant_delay
,
symmetric
=
self
.
symmetric
,
narrow_range
=
self
.
symmetric
,
channel_axis
=
self
.
channel_axis
)
def
bprop
(
x
,
x_min
,
x_max
,
out
,
dout
):
def
bprop
(
x
,
x_min
,
x_max
,
out
,
dout
):
dx
=
op
(
dout
,
x
,
x_min
,
x_max
)
dx
=
op
(
dout
,
x
,
x_min
,
x_max
)
...
@@ -77,7 +82,7 @@ def get_bprop_batchnorm_fold2(self):
...
@@ -77,7 +82,7 @@ def get_bprop_batchnorm_fold2(self):
d_batch_std
,
d_batch_mean
,
d_beta
,
d_gamma
,
d_x
=
op_f
(
dout
,
x
,
gamma
,
batch_std
,
batch_mean
,
running_std
,
d_batch_std
,
d_batch_mean
,
d_beta
,
d_gamma
,
d_x
=
op_f
(
dout
,
x
,
gamma
,
batch_std
,
batch_mean
,
running_std
,
running_mean
,
global_step
)
running_mean
,
global_step
)
return
d_x
,
d_beta
,
d_gamma
,
d_batch_std
,
d_batch_mean
,
zeros_like
(
running_std
),
zeros_like
(
running_mean
),
\
return
d_x
,
d_beta
,
d_gamma
,
d_batch_std
,
d_batch_mean
,
zeros_like
(
running_std
),
zeros_like
(
running_mean
),
\
zeros_like
(
global_step
)
zeros_like
(
global_step
)
return
bprop
return
bprop
...
@@ -117,9 +122,19 @@ def get_bprop_batchnorm_fold2_(self):
...
@@ -117,9 +122,19 @@ def get_bprop_batchnorm_fold2_(self):
return
bprop
return
bprop
@
bprop_getters
.
register
(
P
.
FakeQuantWithMinMaxUpdate
)
@
bprop_getters
.
register
(
P
.
FakeQuantMinMaxPerLayerUpdate
)
def
get_bprop_fakequant_with_minmax_update
(
self
):
def
get_bprop_fakequant_with_minmax_per_layer_update
(
self
):
"""Generate bprop for FakeQuantWithMinMaxUpdate for Ascend"""
"""Generate bprop for FakeQuantMinMaxPerLayerUpdate for Ascend"""
def
bprop
(
x
,
x_min
,
x_max
,
out
,
dout
):
return
zeros_like
(
x
),
zeros_like
(
x_min
),
zeros_like
(
x_max
)
return
bprop
@
bprop_getters
.
register
(
P
.
FakeQuantMinMaxPerChannelUpdate
)
def
get_bprop_fakequant_with_minmax_per_channel_update
(
self
):
"""Generate bprop for FakeQuantMinMaxPerChannelUpdate for Ascend"""
def
bprop
(
x
,
x_min
,
x_max
,
out
,
dout
):
def
bprop
(
x
,
x_min
,
x_max
,
out
,
dout
):
return
zeros_like
(
x
),
zeros_like
(
x_min
),
zeros_like
(
x_max
)
return
zeros_like
(
x
),
zeros_like
(
x_min
),
zeros_like
(
x_max
)
...
...
mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py
0 → 100644
浏览文件 @
b7db3e9a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FakeQuantMinMaxPerChannelUpdate op"""
import
te.lang.cce
from
te
import
tvm
from
te.platform.fusion_manager
import
fusion_manager
from
topi
import
generic
from
topi.cce
import
util
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
fake_quant_min_max_per_channel_update_op_info
=
TBERegOp
(
"FakeQuantMinMaxPerChannelUpdate"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"fake_quant_min_max_per_channel_update.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"fake_quant_min_max_per_channel_update"
)
\
.
partial_flag
(
True
)
\
.
attr
(
"ema"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"ema_decay"
,
"optional"
,
"float"
,
"all"
)
\
.
attr
(
"symmetric"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"narrow_range"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"training"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"num_bits"
,
"optional"
,
"int"
,
"all"
)
\
.
attr
(
"channel_axis"
,
"optional"
,
"int"
,
"all"
)
\
.
input
(
0
,
"x"
,
None
,
"required"
,
None
)
\
.
input
(
1
,
"min"
,
None
,
"required"
,
None
)
\
.
input
(
2
,
"max"
,
None
,
"required"
,
None
)
\
.
output
(
0
,
"min_up"
,
True
,
"required"
,
"all"
)
\
.
output
(
1
,
"max_up"
,
True
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
get_op_info
()
@
op_info_register
(
fake_quant_min_max_per_channel_update_op_info
)
def
_fake_quant_min_max_per_channel_update_tbe
():
"""FakeQuantPerChannelUpdate TBE register"""
return
@
fusion_manager
.
register
(
"fake_quant_min_max_per_channel_update"
)
def
fake_quant_min_max_per_channel_update_compute
(
x
,
min_val
,
max_val
,
ema
,
ema_decay
,
quant_min
,
quant_max
,
training
,
channel_axis
,
kernel_name
=
"fake_quant_min_max_per_channel_update"
):
"""FakeQuantPerChannelUpdate compute"""
shape_min
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
min_val
.
shape
)
if
not
ema
:
ema_decay
=
0.0
if
training
:
# CalMinMax
axis
=
[
0
,
2
,
3
]
x_min
=
te
.
lang
.
cce
.
reduce_min
(
x
,
axis
=
axis
)
x_max
=
te
.
lang
.
cce
.
reduce_max
(
x
,
axis
=
axis
)
x_min
=
te
.
lang
.
cce
.
broadcast
(
x_min
,
shape_min
)
x_max
=
te
.
lang
.
cce
.
broadcast
(
x_max
,
shape_min
)
min_val
=
te
.
lang
.
cce
.
vadd
(
te
.
lang
.
cce
.
vmuls
(
min_val
,
ema_decay
),
te
.
lang
.
cce
.
vmuls
(
x_min
,
(
1
-
ema_decay
)))
max_val
=
te
.
lang
.
cce
.
vadd
(
te
.
lang
.
cce
.
vmuls
(
max_val
,
ema_decay
),
te
.
lang
.
cce
.
vmuls
(
x_max
,
(
1
-
ema_decay
)))
min_val
=
te
.
lang
.
cce
.
vmins
(
min_val
,
0
)
max_val
=
te
.
lang
.
cce
.
vmaxs
(
max_val
,
0
)
return
[
min_val
,
max_val
]
@
util
.
check_input_type
(
dict
,
dict
,
dict
,
dict
,
dict
,
bool
,
float
,
bool
,
bool
,
bool
,
int
,
int
,
str
)
def
fake_quant_min_max_per_channel_update
(
x
,
min_val
,
max_val
,
min_up
,
max_up
,
ema
,
ema_decay
,
symmetric
,
narrow_range
,
training
,
num_bits
,
channel_axis
,
kernel_name
=
"fake_quant_min_max_per_channel_update"
):
"""FakeQuantPerLayer op"""
x_shape
=
x
.
get
(
"ori_shape"
)
x_format
=
x
.
get
(
"format"
)
x_dtype
=
x
.
get
(
"dtype"
)
min_shape
=
min_val
.
get
(
"ori_shape"
)
min_dtype
=
min_val
.
get
(
"dtype"
)
max_shape
=
max_val
.
get
(
"ori_shape"
)
max_dtype
=
max_val
.
get
(
"dtype"
)
util
.
check_kernel_name
(
kernel_name
)
util
.
check_shape_rule
(
x_shape
)
util
.
check_shape_rule
(
min_shape
,
1
,
1
,
x_shape
[
channel_axis
])
util
.
check_shape_rule
(
max_shape
,
1
,
1
,
x_shape
[
channel_axis
])
util
.
check_tensor_shape_size
(
x_shape
)
util
.
check_tensor_shape_size
(
min_shape
)
util
.
check_tensor_shape_size
(
max_shape
)
check_list
=
[
"float32"
,
"float16"
]
x_dtype
=
x_dtype
.
lower
()
min_dtype
=
min_dtype
.
lower
()
max_dtype
=
max_dtype
.
lower
()
util
.
check_dtype_rule
(
x_dtype
,
check_list
)
util
.
check_dtype_rule
(
min_dtype
,
check_list
)
util
.
check_dtype_rule
(
max_dtype
,
check_list
)
if
symmetric
:
quant_min
=
0
-
2
**
(
num_bits
-
1
)
quant_max
=
2
**
(
num_bits
-
1
)
-
1
else
:
quant_min
=
0
quant_max
=
2
**
num_bits
-
1
if
narrow_range
:
quant_min
=
quant_min
+
1
shape_c
=
[
min_val
.
get
(
"shape"
)[
1
],
min_val
.
get
(
"shape"
)[
-
1
]]
input_data
=
tvm
.
placeholder
(
x
.
get
(
"shape"
),
name
=
"x"
,
dtype
=
x_dtype
)
min_data
=
tvm
.
placeholder
(
shape_c
,
name
=
"min_val"
,
dtype
=
x_dtype
)
max_data
=
tvm
.
placeholder
(
shape_c
,
name
=
"max_val"
,
dtype
=
x_dtype
)
res_list
=
fake_quant_min_max_per_channel_update_compute
(
input_data
,
min_data
,
max_data
,
ema
,
ema_decay
,
quant_min
,
quant_max
,
training
,
channel_axis
,
kernel_name
)
with
tvm
.
target
.
cce
():
sch
=
generic
.
auto_schedule
(
res_list
)
tensor_list
=
[
input_data
,
min_data
,
max_data
]
+
list
(
res_list
)
config
=
{
"print_ir"
:
False
,
"name"
:
kernel_name
,
"tensor_list"
:
tensor_list
}
te
.
lang
.
cce
.
cce_build_code
(
sch
,
config
)
mindspore/ops/_op_impl/_custom_op/fake_quant_
with_min_max
_update.py
→
mindspore/ops/_op_impl/_custom_op/fake_quant_
minmax_perlayer
_update.py
浏览文件 @
b7db3e9a
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
# ============================================================================
# ============================================================================
"""FakeQuant
WithMinMax
Update op"""
"""FakeQuant
MinMaxPerLayer
Update op"""
from
functools
import
reduce
as
functools_reduce
from
functools
import
reduce
as
functools_reduce
import
te.lang.cce
import
te.lang.cce
from
te
import
tvm
from
te
import
tvm
...
@@ -23,12 +23,12 @@ from topi.cce import util
...
@@ -23,12 +23,12 @@ from topi.cce import util
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
fake_quant_
update_op_info
=
TBERegOp
(
"FakeQuantWithMinMax
Update"
)
\
fake_quant_
minmax_update_op_info
=
TBERegOp
(
"FakeQuantMinMaxPerLayer
Update"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
async_flag
(
False
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"fake_quant_
with_min_
max_update.so"
)
\
.
binfile_name
(
"fake_quant_
min
max_update.so"
)
\
.
compute_cost
(
10
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"fake_quant_
with_min_
max_update"
)
\
.
kernel_name
(
"fake_quant_
min
max_update"
)
\
.
partial_flag
(
True
)
\
.
partial_flag
(
True
)
\
.
attr
(
"ema"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"ema"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"ema_decay"
,
"optional"
,
"float"
,
"all"
)
\
.
attr
(
"ema_decay"
,
"optional"
,
"float"
,
"all"
)
\
...
@@ -36,7 +36,6 @@ fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
...
@@ -36,7 +36,6 @@ fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
.
attr
(
"narrow_range"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"narrow_range"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"training"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"training"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"num_bits"
,
"optional"
,
"int"
,
"all"
)
\
.
attr
(
"num_bits"
,
"optional"
,
"int"
,
"all"
)
\
.
attr
(
"quant_delay"
,
"optional"
,
"int"
,
"all"
)
\
.
input
(
0
,
"x"
,
None
,
"required"
,
None
)
\
.
input
(
0
,
"x"
,
None
,
"required"
,
None
)
\
.
input
(
1
,
"min"
,
None
,
"required"
,
None
)
\
.
input
(
1
,
"min"
,
None
,
"required"
,
None
)
\
.
input
(
2
,
"max"
,
None
,
"required"
,
None
)
\
.
input
(
2
,
"max"
,
None
,
"required"
,
None
)
\
...
@@ -47,16 +46,16 @@ fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
...
@@ -47,16 +46,16 @@ fake_quant_update_op_info = TBERegOp("FakeQuantWithMinMaxUpdate") \
.
get_op_info
()
.
get_op_info
()
@
op_info_register
(
fake_quant_update_op_info
)
@
op_info_register
(
fake_quant_
minmax_
update_op_info
)
def
_fake_quant_update_tbe
():
def
_fake_quant_
minmax_
update_tbe
():
"""
_FakeQuantWithMinMax
Update TBE register"""
"""
FakeQuantMinMaxPerLayer
Update TBE register"""
return
return
@
fusion_manager
.
register
(
"fake_quant_
with_min_
max_update"
)
@
fusion_manager
.
register
(
"fake_quant_
min
max_update"
)
def
fake_quant_
with_min_
max_update_compute
(
x
,
min_val
,
max_val
,
ema
,
ema_decay
,
quant_min
,
quant_max
,
training
,
def
fake_quant_
min
max_update_compute
(
x
,
min_val
,
max_val
,
ema
,
ema_decay
,
quant_min
,
quant_max
,
training
,
kernel_name
=
"fake_quant
_update"
):
kernel_name
=
"fake_quant_minmax
_update"
):
"""FakeQuant
WithMinMax
Update compute"""
"""FakeQuant
MinMaxPerLayer
Update compute"""
shape
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
x
.
shape
)
shape
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
x
.
shape
)
shape_min
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
min_val
.
shape
)
shape_min
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
min_val
.
shape
)
min_val
=
te
.
lang
.
cce
.
broadcast
(
min_val
,
shape_min
,
x
.
dtype
)
min_val
=
te
.
lang
.
cce
.
broadcast
(
min_val
,
shape_min
,
x
.
dtype
)
...
@@ -70,19 +69,21 @@ def fake_quant_with_min_max_update_compute(x, min_val, max_val, ema, ema_decay,
...
@@ -70,19 +69,21 @@ def fake_quant_with_min_max_update_compute(x, min_val, max_val, ema, ema_decay,
x_max
=
te
.
lang
.
cce
.
reduce_max
(
x
,
axis
=
axis
)
x_max
=
te
.
lang
.
cce
.
reduce_max
(
x
,
axis
=
axis
)
x_min
=
te
.
lang
.
cce
.
broadcast
(
x_min
,
shape_min
)
x_min
=
te
.
lang
.
cce
.
broadcast
(
x_min
,
shape_min
)
x_max
=
te
.
lang
.
cce
.
broadcast
(
x_max
,
shape_min
)
x_max
=
te
.
lang
.
cce
.
broadcast
(
x_max
,
shape_min
)
min_val
=
te
.
lang
.
cce
.
vadd
(
te
.
lang
.
cce
.
vmuls
(
min_val
,
ema_decay
),
te
.
lang
.
cce
.
vmuls
(
x_min
,
(
1
-
ema_decay
)))
min_val
=
te
.
lang
.
cce
.
vadd
(
te
.
lang
.
cce
.
vmuls
(
max_val
=
te
.
lang
.
cce
.
vadd
(
te
.
lang
.
cce
.
vmuls
(
max_val
,
ema_decay
),
te
.
lang
.
cce
.
vmuls
(
x_max
,
(
1
-
ema_decay
)))
min_val
,
ema_decay
),
te
.
lang
.
cce
.
vmuls
(
x_min
,
(
1
-
ema_decay
)))
max_val
=
te
.
lang
.
cce
.
vadd
(
te
.
lang
.
cce
.
vmuls
(
max_val
,
ema_decay
),
te
.
lang
.
cce
.
vmuls
(
x_max
,
(
1
-
ema_decay
)))
min_val
=
te
.
lang
.
cce
.
vmins
(
min_val
,
0
)
min_val
=
te
.
lang
.
cce
.
vmins
(
min_val
,
0
)
max_val
=
te
.
lang
.
cce
.
vmaxs
(
max_val
,
0
)
max_val
=
te
.
lang
.
cce
.
vmaxs
(
max_val
,
0
)
return
[
min_val
,
max_val
]
return
[
min_val
,
max_val
]
@
util
.
check_input_type
(
dict
,
dict
,
dict
,
dict
,
dict
,
bool
,
float
,
bool
,
bool
,
bool
,
int
,
int
,
str
)
@
util
.
check_input_type
(
dict
,
dict
,
dict
,
dict
,
dict
,
bool
,
float
,
bool
,
bool
,
bool
,
int
,
str
)
def
fake_quant_
with_min_
max_update
(
x
,
min_val
,
max_val
,
min_up
,
max_up
,
def
fake_quant_
min
max_update
(
x
,
min_val
,
max_val
,
min_up
,
max_up
,
ema
,
ema_decay
,
symmetric
,
narrow_range
,
training
,
num_bits
,
quant_delay
,
ema
,
ema_decay
,
symmetric
,
narrow_range
,
training
,
num_bits
,
kernel_name
=
"fake_quant
_update"
):
kernel_name
=
"fake_quant_minmax
_update"
):
"""FakeQuant
WithMinMax
op"""
"""FakeQuant
PerLayer
op"""
input_shape
=
x
.
get
(
"shape"
)
input_shape
=
x
.
get
(
"shape"
)
input_dtype
=
x
.
get
(
"dtype"
)
input_dtype
=
x
.
get
(
"dtype"
)
min_shape
=
min_val
.
get
(
"ori_shape"
)
min_shape
=
min_val
.
get
(
"ori_shape"
)
...
@@ -123,8 +124,8 @@ def fake_quant_with_min_max_update(x, min_val, max_val, min_up, max_up,
...
@@ -123,8 +124,8 @@ def fake_quant_with_min_max_update(x, min_val, max_val, min_up, max_up,
input_data
=
tvm
.
placeholder
(
input_shape
,
name
=
"x"
,
dtype
=
x_dtype
)
input_data
=
tvm
.
placeholder
(
input_shape
,
name
=
"x"
,
dtype
=
x_dtype
)
min_data
=
tvm
.
placeholder
(
shape_min
,
name
=
"min_data"
,
dtype
=
min_dtype
)
min_data
=
tvm
.
placeholder
(
shape_min
,
name
=
"min_data"
,
dtype
=
min_dtype
)
max_data
=
tvm
.
placeholder
(
shape_min
,
name
=
"max_data"
,
dtype
=
max_dtype
)
max_data
=
tvm
.
placeholder
(
shape_min
,
name
=
"max_data"
,
dtype
=
max_dtype
)
res_list
=
fake_quant_
with_min_
max_update_compute
(
input_data
,
min_data
,
max_data
,
res_list
=
fake_quant_
min
max_update_compute
(
input_data
,
min_data
,
max_data
,
ema
,
ema_decay
,
quant_min
,
quant_max
,
training
,
kernel_name
)
ema
,
ema_decay
,
quant_min
,
quant_max
,
training
,
kernel_name
)
with
tvm
.
target
.
cce
():
with
tvm
.
target
.
cce
():
sch
=
generic
.
auto_schedule
(
res_list
)
sch
=
generic
.
auto_schedule
(
res_list
)
...
...
mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py
0 → 100644
浏览文件 @
b7db3e9a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FakeQuantPerChannel op"""
import
te.lang.cce
from
te
import
tvm
from
te.platform.fusion_manager
import
fusion_manager
from
topi
import
generic
from
topi.cce
import
util
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
fake_quant_perchannel_op_info
=
TBERegOp
(
"FakeQuantPerChannel"
)
\
.
fusion_type
(
"ELEMWISE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"fake_quant_perchannel.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"fake_quant_perchannel"
)
\
.
partial_flag
(
True
)
\
.
attr
(
"symmetric"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"narrow_range"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"num_bits"
,
"optional"
,
"int"
,
"all"
)
\
.
attr
(
"channel_axis"
,
"optional"
,
"int"
,
"all"
)
\
.
input
(
0
,
"x"
,
None
,
"required"
,
None
)
\
.
input
(
1
,
"min"
,
None
,
"required"
,
None
)
\
.
input
(
2
,
"max"
,
None
,
"required"
,
None
)
\
.
output
(
0
,
"y"
,
True
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
get_op_info
()
@
op_info_register
(
fake_quant_perchannel_op_info
)
def
_fake_quant_perchannel_tbe
():
"""FakeQuantPerChannel TBE register"""
return
@
fusion_manager
.
register
(
"fake_quant_perchannel"
)
def
fake_quant_perchannel_compute
(
x
,
min_val
,
max_val
,
y
,
quant_min
,
quant_max
,
kernel_name
=
"fake_quant_perchannel"
):
"""FakeQuantPerChannel"""
x_shape
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
x
.
shape
)
minmax_shape
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
min_val
.
shape
)
quant_min
=
tvm
.
const
(
quant_min
,
x
.
dtype
)
quant_max
=
tvm
.
const
(
quant_max
,
x
.
dtype
)
quant_min
=
te
.
lang
.
cce
.
broadcast
(
quant_min
,
minmax_shape
,
x
.
dtype
)
quant_max
=
te
.
lang
.
cce
.
broadcast
(
quant_max
,
minmax_shape
,
x
.
dtype
)
# CalNudge(NudgeMinMax)
scale
=
te
.
lang
.
cce
.
vdiv
(
te
.
lang
.
cce
.
vsub
(
max_val
,
min_val
),
te
.
lang
.
cce
.
vsub
(
quant_max
,
quant_min
))
zp_from_min
=
te
.
lang
.
cce
.
vsub
(
quant_min
,
te
.
lang
.
cce
.
vdiv
(
min_val
,
scale
))
# Nudge zero point
nudge_zp_
=
te
.
lang
.
cce
.
vmin
(
quant_max
,
te
.
lang
.
cce
.
vmax
(
quant_min
,
zp_from_min
))
nudge_zp
=
te
.
lang
.
cce
.
floor
(
te
.
lang
.
cce
.
vadds
(
nudge_zp_
,
0.5
))
nudge_min
=
te
.
lang
.
cce
.
vmul
(
te
.
lang
.
cce
.
vsub
(
quant_min
,
nudge_zp
),
scale
)
nudge_max
=
te
.
lang
.
cce
.
vmul
(
te
.
lang
.
cce
.
vsub
(
quant_max
,
nudge_zp
),
scale
)
# FakeQuant
nudge_min_b
=
te
.
lang
.
cce
.
broadcast
(
nudge_min
,
x_shape
)
nudge_max_b
=
te
.
lang
.
cce
.
broadcast
(
nudge_max
,
x_shape
)
scale_b
=
te
.
lang
.
cce
.
broadcast
(
scale
,
x_shape
)
input_x
=
te
.
lang
.
cce
.
vmin
(
nudge_max_b
,
te
.
lang
.
cce
.
vmax
(
nudge_min_b
,
x
))
nudge_input_
=
te
.
lang
.
cce
.
vdiv
(
te
.
lang
.
cce
.
vsub
(
input_x
,
nudge_min_b
),
scale_b
)
nudge_input
=
te
.
lang
.
cce
.
floor
(
te
.
lang
.
cce
.
vadds
(
nudge_input_
,
0.5
))
res
=
te
.
lang
.
cce
.
vadd
(
te
.
lang
.
cce
.
vmul
(
nudge_input
,
scale_b
),
nudge_min_b
)
return
res
@
util
.
check_input_type
(
dict
,
dict
,
dict
,
dict
,
bool
,
bool
,
int
,
int
,
str
)
def
fake_quant_perchannel
(
x
,
min_val
,
max_val
,
y
,
symmetric
,
narrow_range
,
num_bits
,
channel_axis
,
kernel_name
=
"fake_quant_perchannel"
):
"""FakeQuantPerChannel"""
x_shape
=
x
.
get
(
"shape"
)
x_format
=
x
.
get
(
"format"
)
x_dtype
=
x
.
get
(
"dtype"
)
min_shape
=
min_val
.
get
(
"ori_shape"
)
min_dtype
=
min_val
.
get
(
"dtype"
)
max_shape
=
max_val
.
get
(
"ori_shape"
)
max_dtype
=
max_val
.
get
(
"dtype"
)
util
.
check_kernel_name
(
kernel_name
)
util
.
check_shape_rule
(
x_shape
)
util
.
check_shape_rule
(
min_shape
,
1
,
1
,
x_shape
[
channel_axis
])
util
.
check_shape_rule
(
max_shape
,
1
,
1
,
x_shape
[
channel_axis
])
util
.
check_tensor_shape_size
(
x_shape
)
util
.
check_tensor_shape_size
(
min_shape
)
util
.
check_tensor_shape_size
(
max_shape
)
check_list
=
[
"float32"
,
"float16"
]
x_dtype
=
x_dtype
.
lower
()
min_dtype
=
min_dtype
.
lower
()
max_dtype
=
max_dtype
.
lower
()
util
.
check_dtype_rule
(
x_dtype
,
check_list
)
util
.
check_dtype_rule
(
min_dtype
,
check_list
)
util
.
check_dtype_rule
(
max_dtype
,
check_list
)
if
symmetric
:
quant_min
=
0
-
2
**
(
num_bits
-
1
)
quant_max
=
2
**
(
num_bits
-
1
)
-
1
else
:
quant_min
=
0
quant_max
=
2
**
num_bits
-
1
if
narrow_range
:
quant_min
=
quant_min
+
1
shape_c
=
[
1
]
*
len
(
x_shape
)
shape_c
[
channel_axis
]
=
min_val
.
get
(
"ori_shape"
)[
0
]
if
x_format
==
"NC1HWC0"
and
channel_axis
==
1
:
shape_c
=
min_val
.
get
(
"shape"
)
input_data
=
tvm
.
placeholder
(
x_shape
,
name
=
"x"
,
dtype
=
x_dtype
)
min_data
=
tvm
.
placeholder
(
shape_c
,
name
=
"min_val"
,
dtype
=
x_dtype
)
max_data
=
tvm
.
placeholder
(
shape_c
,
name
=
"max_val"
,
dtype
=
x_dtype
)
res
=
fake_quant_perchannel_compute
(
input_data
,
min_data
,
max_data
,
y
,
quant_min
,
quant_max
,
kernel_name
)
with
tvm
.
target
.
cce
():
sch
=
generic
.
auto_schedule
(
res
)
tensor_list
=
[
input_data
,
min_data
,
max_data
,
res
]
config
=
{
"print_ir"
:
False
,
"name"
:
kernel_name
,
"tensor_list"
:
tensor_list
}
te
.
lang
.
cce
.
cce_build_code
(
sch
,
config
)
mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py
0 → 100644
浏览文件 @
b7db3e9a
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FakeQuantPerChannelGrad op"""
import
te.lang.cce
from
te
import
tvm
from
te.platform.fusion_manager
import
fusion_manager
from
topi
import
generic
from
topi.cce
import
util
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
SHAPE_SIZE_LIMIT
=
2147483648
D_TYPE
=
'float32'
fake_quant_perchannel_grad_op_info
=
TBERegOp
(
"FakeQuantPerChannelGrad"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"fake_quant_perchannel_grad.so"
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"fake_quant_perchannel_grad"
)
\
.
partial_flag
(
True
)
\
.
attr
(
"symmetric"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"narrow_range"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"num_bits"
,
"optional"
,
"int"
,
"all"
)
\
.
attr
(
"channel_axis"
,
"optional"
,
"int"
,
"all"
)
\
.
input
(
0
,
"dout"
,
None
,
"required"
,
None
)
\
.
input
(
1
,
"x"
,
None
,
"required"
,
None
)
\
.
input
(
2
,
"min"
,
None
,
"required"
,
None
)
\
.
input
(
3
,
"max"
,
None
,
"required"
,
None
)
\
.
output
(
0
,
"dx"
,
True
,
"required"
,
"all"
)
\
.
dtype_format
(
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
,
DataType
.
F16_Default
)
\
.
dtype_format
(
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
,
DataType
.
F16_5HD
)
\
.
dtype_format
(
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
,
DataType
.
F32_Default
)
\
.
dtype_format
(
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
,
DataType
.
F32_5HD
)
\
.
get_op_info
()
def
_less_compare_float32
(
data_x
,
data_y
):
"""_less_compare_float32 compute"""
input_shape
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
data_x
.
shape
)
min_value
=
tvm
.
const
(
2
**
(
-
126
),
dtype
=
D_TYPE
)
max_value
=
tvm
.
const
(
2
**
62
,
dtype
=
D_TYPE
)
factor_value
=
tvm
.
const
(
2
**
2
,
dtype
=
D_TYPE
)
data_zero
=
te
.
lang
.
cce
.
broadcast
(
tvm
.
const
(
0
,
dtype
=
D_TYPE
),
input_shape
,
D_TYPE
)
min_value_tensor
=
te
.
lang
.
cce
.
vadds
(
data_zero
,
min_value
)
res_sub
=
te
.
lang
.
cce
.
vsub
(
data_y
,
data_x
)
res_min
=
te
.
lang
.
cce
.
vmin
(
res_sub
,
min_value_tensor
)
res_max
=
te
.
lang
.
cce
.
vmax
(
res_min
,
data_zero
)
res_max_mul
=
te
.
lang
.
cce
.
vmuls
(
res_max
,
max_value
)
res_max_mul_max
=
te
.
lang
.
cce
.
vmuls
(
res_max_mul
,
max_value
)
res
=
te
.
lang
.
cce
.
vmuls
(
res_max_mul_max
,
factor_value
)
return
res
@
op_info_register
(
fake_quant_perchannel_grad_op_info
)
def
_fake_quant_perchannel_grad_tbe
():
"""FakeQuantPerChannelGrad TBE register"""
return
@
fusion_manager
.
register
(
"fake_quant_perchannel_grad"
)
def
fake_quant_perchannel_grad_compute
(
dout
,
x
,
min_val
,
max_val
,
quant_min
,
quant_max
,
kernel_name
=
"fake_quant_perchannel_grad"
):
"""FakeQuantPerChannelGrad"""
x_shape
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
x
.
shape
)
minmax_shape
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
min_val
.
shape
)
quant_min
=
tvm
.
const
(
quant_min
,
x
.
dtype
)
quant_max
=
tvm
.
const
(
quant_max
,
x
.
dtype
)
quant_min
=
te
.
lang
.
cce
.
broadcast
(
quant_min
,
minmax_shape
,
x
.
dtype
)
quant_max
=
te
.
lang
.
cce
.
broadcast
(
quant_max
,
minmax_shape
,
x
.
dtype
)
# CalNudge(NudgeMinMax)
scale
=
te
.
lang
.
cce
.
vdiv
(
te
.
lang
.
cce
.
vsub
(
max_val
,
min_val
),
te
.
lang
.
cce
.
vsub
(
quant_max
,
quant_min
))
zp_from_min
=
te
.
lang
.
cce
.
vsub
(
quant_min
,
te
.
lang
.
cce
.
vdiv
(
min_val
,
scale
))
# Nudge zero point
nudge_zp_
=
te
.
lang
.
cce
.
vmin
(
quant_max
,
te
.
lang
.
cce
.
vmax
(
quant_min
,
zp_from_min
))
nudge_zp
=
te
.
lang
.
cce
.
floor
(
te
.
lang
.
cce
.
vadds
(
nudge_zp_
,
0.5
))
nudge_min
=
te
.
lang
.
cce
.
vmul
(
te
.
lang
.
cce
.
vsub
(
quant_min
,
nudge_zp
),
scale
)
nudge_max
=
te
.
lang
.
cce
.
vmul
(
te
.
lang
.
cce
.
vsub
(
quant_max
,
nudge_zp
),
scale
)
# FakeQuant Grad
nudge_min_b
=
te
.
lang
.
cce
.
broadcast
(
nudge_min
,
x_shape
)
nudge_max_b
=
te
.
lang
.
cce
.
broadcast
(
nudge_max
,
x_shape
)
bool_over_min
=
_less_compare_float32
(
nudge_min_b
,
x
)
bool_less_max
=
_less_compare_float32
(
x
,
nudge_max_b
)
bool_between
=
te
.
lang
.
cce
.
vmul
(
bool_over_min
,
bool_less_max
)
res
=
te
.
lang
.
cce
.
vmul
(
dout
,
bool_between
)
return
res
@
util
.
check_input_type
(
dict
,
dict
,
dict
,
dict
,
dict
,
bool
,
bool
,
int
,
int
,
str
)
def
fake_quant_perchannel_grad
(
dout
,
x
,
min_val
,
max_val
,
dx
,
symmetric
,
narrow_range
,
num_bits
,
channel_axis
,
kernel_name
=
"fake_quant_perchannel_grad"
):
"""FakeQuantPerChannelGrad"""
x_shape
=
x
.
get
(
"shape"
)
x_format
=
x
.
get
(
"format"
)
x_dtype
=
x
.
get
(
"dtype"
)
min_shape
=
min_val
.
get
(
"ori_shape"
)
min_dtype
=
min_val
.
get
(
"dtype"
)
max_shape
=
max_val
.
get
(
"ori_shape"
)
max_dtype
=
max_val
.
get
(
"dtype"
)
util
.
check_kernel_name
(
kernel_name
)
util
.
check_shape_rule
(
x_shape
)
util
.
check_shape_rule
(
min_shape
,
1
,
1
,
x_shape
[
channel_axis
])
util
.
check_shape_rule
(
max_shape
,
1
,
1
,
x_shape
[
channel_axis
])
util
.
check_tensor_shape_size
(
x_shape
)
util
.
check_tensor_shape_size
(
min_shape
)
util
.
check_tensor_shape_size
(
max_shape
)
check_list
=
[
"float32"
,
"float16"
]
x_dtype
=
x_dtype
.
lower
()
min_dtype
=
min_dtype
.
lower
()
max_dtype
=
max_dtype
.
lower
()
util
.
check_dtype_rule
(
x_dtype
,
check_list
)
util
.
check_dtype_rule
(
min_dtype
,
check_list
)
util
.
check_dtype_rule
(
max_dtype
,
check_list
)
if
symmetric
:
quant_min
=
0
-
2
**
(
num_bits
-
1
)
quant_max
=
2
**
(
num_bits
-
1
)
-
1
else
:
quant_min
=
0
quant_max
=
2
**
num_bits
-
1
if
narrow_range
:
quant_min
=
quant_min
+
1
shape_c
=
[
1
]
*
len
(
x_shape
)
shape_c
[
channel_axis
]
=
min_val
.
get
(
"ori_shape"
)[
0
]
if
x_format
==
"NC1HWC0"
and
channel_axis
==
1
:
shape_c
=
min_val
.
get
(
"shape"
)
dout_data
=
tvm
.
placeholder
(
x_shape
,
name
=
"dout"
,
dtype
=
x_dtype
)
input_data
=
tvm
.
placeholder
(
x_shape
,
name
=
"x"
,
dtype
=
x_dtype
)
min_data
=
tvm
.
placeholder
(
shape_c
,
name
=
"min_val"
,
dtype
=
x_dtype
)
max_data
=
tvm
.
placeholder
(
shape_c
,
name
=
"max_val"
,
dtype
=
x_dtype
)
res
=
fake_quant_perchannel_grad_compute
(
dout_data
,
input_data
,
min_data
,
max_data
,
quant_min
,
quant_max
,
kernel_name
)
with
tvm
.
target
.
cce
():
sch
=
generic
.
auto_schedule
(
res
)
tensor_list
=
[
dout_data
,
input_data
,
min_data
,
max_data
,
res
]
config
=
{
"print_ir"
:
False
,
"name"
:
kernel_name
,
"tensor_list"
:
tensor_list
}
te
.
lang
.
cce
.
cce_build_code
(
sch
,
config
)
mindspore/ops/_op_impl/_custom_op/fake_quant_
with_min_max
.py
→
mindspore/ops/_op_impl/_custom_op/fake_quant_
perlayer
.py
浏览文件 @
b7db3e9a
...
@@ -13,8 +13,7 @@
...
@@ -13,8 +13,7 @@
# limitations under the License.
# limitations under the License.
# ============================================================================
# ============================================================================
"""FakeQuantWithMinMax op"""
"""FakeQuantPerLayer op"""
from
functools
import
reduce
as
functools_reduce
from
functools
import
reduce
as
functools_reduce
import
te.lang.cce
import
te.lang.cce
from
te
import
tvm
from
te
import
tvm
...
@@ -23,20 +22,16 @@ from topi import generic
...
@@ -23,20 +22,16 @@ from topi import generic
from
topi.cce
import
util
from
topi.cce
import
util
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
from
mindspore.ops.op_info_register
import
op_info_register
,
TBERegOp
,
DataType
fake_quant_
op_info
=
TBERegOp
(
"FakeQuantWithMinMax
"
)
\
fake_quant_
per_layer_op_info
=
TBERegOp
(
"FakeQuantPerLayer
"
)
\
.
fusion_type
(
"ELEMWISE"
)
\
.
fusion_type
(
"ELEMWISE"
)
\
.
async_flag
(
False
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"fake_quant_
with_min_max_vars_ema
.so"
)
\
.
binfile_name
(
"fake_quant_
per_layer
.so"
)
\
.
compute_cost
(
10
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"fake_quant_
with_min_max_vars_ema
"
)
\
.
kernel_name
(
"fake_quant_
per_layer
"
)
\
.
partial_flag
(
True
)
\
.
partial_flag
(
True
)
\
.
attr
(
"ema"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"ema_decay"
,
"optional"
,
"float"
,
"all"
)
\
.
attr
(
"symmetric"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"symmetric"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"narrow_range"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"narrow_range"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"training"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"num_bits"
,
"optional"
,
"int"
,
"all"
)
\
.
attr
(
"num_bits"
,
"optional"
,
"int"
,
"all"
)
\
.
attr
(
"quant_delay"
,
"optional"
,
"int"
,
"all"
)
\
.
input
(
0
,
"x"
,
None
,
"required"
,
None
)
\
.
input
(
0
,
"x"
,
None
,
"required"
,
None
)
\
.
input
(
1
,
"min"
,
None
,
"required"
,
None
)
\
.
input
(
1
,
"min"
,
None
,
"required"
,
None
)
\
.
input
(
2
,
"max"
,
None
,
"required"
,
None
)
\
.
input
(
2
,
"max"
,
None
,
"required"
,
None
)
\
...
@@ -49,15 +44,15 @@ fake_quant_op_info = TBERegOp("FakeQuantWithMinMax") \
...
@@ -49,15 +44,15 @@ fake_quant_op_info = TBERegOp("FakeQuantWithMinMax") \
@
op_info_register
(
fake_quant_op_info
)
@
op_info_register
(
fake_quant_op_info
)
def
_fake_quant_tbe
():
def
_fake_quant_
per_layer_
tbe
():
"""FakeQuant
WithMinMax
TBE register"""
"""FakeQuant
PerLayer
TBE register"""
return
return
@
fusion_manager
.
register
(
"fake_quant_
with_min_max_vars_ema
"
)
@
fusion_manager
.
register
(
"fake_quant_
per_layer
"
)
def
fake_quant_
with_min_max_vars_ema
_compute
(
x
,
min_val
,
max_val
,
y
,
quant_min
,
quant_max
,
def
fake_quant_
per_layer
_compute
(
x
,
min_val
,
max_val
,
y
,
quant_min
,
quant_max
,
kernel_name
=
"correction_mul
"
):
kernel_name
=
"fake_quant_per_layer
"
):
"""FakeQuant
WithMinMax
"""
"""FakeQuant
PerLayer
"""
shape
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
x
.
shape
)
shape
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
x
.
shape
)
shape_min
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
min_val
.
shape
)
shape_min
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
min_val
.
shape
)
quant_min
=
te
.
lang
.
cce
.
broadcast
(
quant_min
,
shape_min
,
x
.
dtype
)
quant_min
=
te
.
lang
.
cce
.
broadcast
(
quant_min
,
shape_min
,
x
.
dtype
)
...
@@ -66,10 +61,13 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min,
...
@@ -66,10 +61,13 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min,
max_val
=
te
.
lang
.
cce
.
broadcast
(
max_val
,
shape_min
,
x
.
dtype
)
max_val
=
te
.
lang
.
cce
.
broadcast
(
max_val
,
shape_min
,
x
.
dtype
)
# CalNudge(NudgeMinMax)
# CalNudge(NudgeMinMax)
scale
=
te
.
lang
.
cce
.
vdiv
(
te
.
lang
.
cce
.
vsub
(
max_val
,
min_val
),
te
.
lang
.
cce
.
vsub
(
quant_max
,
quant_min
))
scale
=
te
.
lang
.
cce
.
vdiv
(
te
.
lang
.
cce
.
vsub
(
max_val
,
min_val
),
te
.
lang
.
cce
.
vsub
(
quant_max
,
quant_min
))
zp_from_min
=
te
.
lang
.
cce
.
vsub
(
quant_min
,
te
.
lang
.
cce
.
vdiv
(
min_val
,
scale
))
zp_from_min
=
te
.
lang
.
cce
.
vsub
(
quant_min
,
te
.
lang
.
cce
.
vdiv
(
min_val
,
scale
))
# Nudge zero point
# Nudge zero point
nudge_zp
=
te
.
lang
.
cce
.
round
(
te
.
lang
.
cce
.
vmin
(
quant_max
,
te
.
lang
.
cce
.
vmax
(
quant_min
,
zp_from_min
)))
nudge_zp_
=
te
.
lang
.
cce
.
vmin
(
quant_max
,
te
.
lang
.
cce
.
vmax
(
quant_min
,
zp_from_min
))
nudge_zp
=
te
.
lang
.
cce
.
floor
(
te
.
lang
.
cce
.
vadds
(
nudge_zp_
,
0.5
))
nudge_min
=
te
.
lang
.
cce
.
vmul
(
te
.
lang
.
cce
.
vsub
(
quant_min
,
nudge_zp
),
scale
)
nudge_min
=
te
.
lang
.
cce
.
vmul
(
te
.
lang
.
cce
.
vsub
(
quant_min
,
nudge_zp
),
scale
)
nudge_max
=
te
.
lang
.
cce
.
vmul
(
te
.
lang
.
cce
.
vsub
(
quant_max
,
nudge_zp
),
scale
)
nudge_max
=
te
.
lang
.
cce
.
vmul
(
te
.
lang
.
cce
.
vsub
(
quant_max
,
nudge_zp
),
scale
)
...
@@ -80,17 +78,19 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min,
...
@@ -80,17 +78,19 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min,
# FakeQuant
# FakeQuant
input_x
=
te
.
lang
.
cce
.
vmin
(
nudge_max
,
te
.
lang
.
cce
.
vmax
(
nudge_min
,
x
))
input_x
=
te
.
lang
.
cce
.
vmin
(
nudge_max
,
te
.
lang
.
cce
.
vmax
(
nudge_min
,
x
))
nudge_input
=
te
.
lang
.
cce
.
round
(
te
.
lang
.
cce
.
vdiv
(
te
.
lang
.
cce
.
vsub
(
input_x
,
nudge_min
),
scale
))
nudge_input_
=
te
.
lang
.
cce
.
vdiv
(
te
.
lang
.
cce
.
vsub
(
input_x
,
nudge_min
),
scale
)
nudge_input
=
te
.
lang
.
cce
.
floor
(
te
.
lang
.
cce
.
vadds
(
nudge_input_
,
0.5
))
res
=
te
.
lang
.
cce
.
vadd
(
te
.
lang
.
cce
.
vmul
(
nudge_input
,
scale
),
nudge_min
)
res
=
te
.
lang
.
cce
.
vadd
(
te
.
lang
.
cce
.
vmul
(
nudge_input
,
scale
),
nudge_min
)
return
res
return
res
@
util
.
check_input_type
(
dict
,
dict
,
dict
,
dict
,
bool
,
float
,
bool
,
bool
,
bool
,
int
,
int
,
str
)
@
util
.
check_input_type
(
dict
,
dict
,
dict
,
dict
,
bool
,
bool
,
int
,
str
)
def
fake_quant_
with_min_max_vars_ema
(
x
,
min_val
,
max_val
,
y
,
def
fake_quant_
per_layer
(
x
,
min_val
,
max_val
,
y
,
ema
,
ema_decay
,
symmetric
,
narrow_range
,
training
,
num_bits
,
quant_delay
,
symmetric
,
narrow_range
,
num_bits
,
kernel_name
=
"fake_quant
"
):
kernel_name
=
"fake_quant_per_layer
"
):
"""FakeQuant
WithMinMax
"""
"""FakeQuant
PerLayer
"""
input_shape
=
x
.
get
(
"shape"
)
input_shape
=
x
.
get
(
"shape"
)
input_dtype
=
x
.
get
(
"dtype"
)
input_dtype
=
x
.
get
(
"dtype"
)
min_shape
=
min_val
.
get
(
"ori_shape"
)
min_shape
=
min_val
.
get
(
"ori_shape"
)
...
@@ -131,8 +131,8 @@ def fake_quant_with_min_max_vars_ema(x, min_val, max_val, y,
...
@@ -131,8 +131,8 @@ def fake_quant_with_min_max_vars_ema(x, min_val, max_val, y,
input_data
=
tvm
.
placeholder
(
input_shape
,
name
=
"x"
,
dtype
=
x_dtype
)
input_data
=
tvm
.
placeholder
(
input_shape
,
name
=
"x"
,
dtype
=
x_dtype
)
min_data
=
tvm
.
placeholder
(
shape_min
,
name
=
"min_data"
,
dtype
=
min_dtype
)
min_data
=
tvm
.
placeholder
(
shape_min
,
name
=
"min_data"
,
dtype
=
min_dtype
)
max_data
=
tvm
.
placeholder
(
shape_min
,
name
=
"max_data"
,
dtype
=
max_dtype
)
max_data
=
tvm
.
placeholder
(
shape_min
,
name
=
"max_data"
,
dtype
=
max_dtype
)
res
=
fake_quant_
with_min_max_vars_ema
_compute
(
input_data
,
min_data
,
max_data
,
y
,
res
=
fake_quant_
per_layer
_compute
(
input_data
,
min_data
,
max_data
,
y
,
quant_min
,
quant_max
,
kernel_name
)
quant_min
,
quant_max
,
kernel_name
)
with
tvm
.
target
.
cce
():
with
tvm
.
target
.
cce
():
sch
=
generic
.
auto_schedule
(
res
)
sch
=
generic
.
auto_schedule
(
res
)
...
...
mindspore/ops/_op_impl/_custom_op/fake_quant_
with_min_max
_grad.py
→
mindspore/ops/_op_impl/_custom_op/fake_quant_
perlayer
_grad.py
浏览文件 @
b7db3e9a
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
# ============================================================================
# ============================================================================
"""FakeQuant
WithMinMax
Grad op"""
"""FakeQuant
PerLayer
Grad op"""
from
functools
import
reduce
as
functools_reduce
from
functools
import
reduce
as
functools_reduce
import
te.lang.cce
import
te.lang.cce
...
@@ -26,15 +26,14 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
...
@@ -26,15 +26,14 @@ from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
SHAPE_SIZE_LIMIT
=
2147483648
SHAPE_SIZE_LIMIT
=
2147483648
D_TYPE
=
'float32'
D_TYPE
=
'float32'
fake_quant_
grad_op_info
=
TBERegOp
(
"FakeQuantWithMinMax
Grad"
)
\
fake_quant_
per_layer_grad_op_info
=
TBERegOp
(
"FakeQuantPerLayer
Grad"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
fusion_type
(
"OPAQUE"
)
\
.
async_flag
(
False
)
\
.
async_flag
(
False
)
\
.
binfile_name
(
"fake_quant_
with_min_max
_grad.so"
)
\
.
binfile_name
(
"fake_quant_
per_layer
_grad.so"
)
\
.
compute_cost
(
10
)
\
.
compute_cost
(
10
)
\
.
kernel_name
(
"fake_quant_
with_min_max
_grad"
)
\
.
kernel_name
(
"fake_quant_
per_layer
_grad"
)
\
.
partial_flag
(
True
)
\
.
partial_flag
(
True
)
\
.
attr
(
"num_bits"
,
"optional"
,
"int"
,
"all"
)
\
.
attr
(
"num_bits"
,
"optional"
,
"int"
,
"all"
)
\
.
attr
(
"quant_delay"
,
"optional"
,
"int"
,
"all"
)
\
.
attr
(
"symmetric"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"symmetric"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"narrow_range"
,
"optional"
,
"bool"
,
"all"
)
\
.
attr
(
"narrow_range"
,
"optional"
,
"bool"
,
"all"
)
\
.
input
(
0
,
"dout"
,
None
,
"required"
,
None
)
\
.
input
(
0
,
"dout"
,
None
,
"required"
,
None
)
\
...
@@ -57,7 +56,8 @@ def _less_compare_float32(data_x, data_y):
...
@@ -57,7 +56,8 @@ def _less_compare_float32(data_x, data_y):
min_value
=
tvm
.
const
(
2
**
(
-
126
),
dtype
=
D_TYPE
)
min_value
=
tvm
.
const
(
2
**
(
-
126
),
dtype
=
D_TYPE
)
max_value
=
tvm
.
const
(
2
**
62
,
dtype
=
D_TYPE
)
max_value
=
tvm
.
const
(
2
**
62
,
dtype
=
D_TYPE
)
factor_value
=
tvm
.
const
(
2
**
2
,
dtype
=
D_TYPE
)
factor_value
=
tvm
.
const
(
2
**
2
,
dtype
=
D_TYPE
)
data_zero
=
te
.
lang
.
cce
.
broadcast
(
tvm
.
const
(
0
,
dtype
=
D_TYPE
),
shape_inputs
,
D_TYPE
)
data_zero
=
te
.
lang
.
cce
.
broadcast
(
tvm
.
const
(
0
,
dtype
=
D_TYPE
),
shape_inputs
,
D_TYPE
)
min_value_tensor
=
te
.
lang
.
cce
.
vadds
(
data_zero
,
min_value
)
min_value_tensor
=
te
.
lang
.
cce
.
vadds
(
data_zero
,
min_value
)
res_sub
=
te
.
lang
.
cce
.
vsub
(
data_y
,
data_x
)
res_sub
=
te
.
lang
.
cce
.
vsub
(
data_y
,
data_x
)
...
@@ -71,16 +71,16 @@ def _less_compare_float32(data_x, data_y):
...
@@ -71,16 +71,16 @@ def _less_compare_float32(data_x, data_y):
return
res
return
res
@
op_info_register
(
fake_quant_grad_op_info
)
@
op_info_register
(
fake_quant_
per_layer_
grad_op_info
)
def
_fake_quant_grad_tbe
():
def
_fake_quant_
per_layer_
grad_tbe
():
"""FakeQuant
WithMinMax
Grad TBE register"""
"""FakeQuant
PerLayer
Grad TBE register"""
return
return
@
fusion_manager
.
register
(
"fake_quant_
with_min_max
_grad"
)
@
fusion_manager
.
register
(
"fake_quant_
per_layer
_grad"
)
def
fake_quant_
with_min_max
_grad_compute
(
dout
,
x
,
min_val
,
max_val
,
quant_min
,
quant_max
,
def
fake_quant_
per_layer
_grad_compute
(
dout
,
x
,
min_val
,
max_val
,
quant_min
,
quant_max
,
kernel_name
=
"fake_quant_with_min_max
_grad"
):
kernel_name
=
"fake_quant_per_layer
_grad"
):
"""FakeQuant
WithMinMax
Grad"""
"""FakeQuant
PerLayer
Grad"""
shape
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
x
.
shape
)
shape
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
x
.
shape
)
shape_min
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
min_val
.
shape
)
shape_min
=
te
.
lang
.
cce
.
util
.
shape_to_list
(
min_val
.
shape
)
quant_min
=
tvm
.
const
(
quant_min
,
x
.
dtype
)
quant_min
=
tvm
.
const
(
quant_min
,
x
.
dtype
)
...
@@ -89,10 +89,13 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q
...
@@ -89,10 +89,13 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q
quant_max
=
te
.
lang
.
cce
.
broadcast
(
quant_max
,
shape_min
)
quant_max
=
te
.
lang
.
cce
.
broadcast
(
quant_max
,
shape_min
)
# CalNudge(NudgeMinMax)
# CalNudge(NudgeMinMax)
scale
=
te
.
lang
.
cce
.
vdiv
(
te
.
lang
.
cce
.
vsub
(
max_val
,
min_val
),
te
.
lang
.
cce
.
vsub
(
quant_max
,
quant_min
))
scale
=
te
.
lang
.
cce
.
vdiv
(
te
.
lang
.
cce
.
vsub
(
max_val
,
min_val
),
te
.
lang
.
cce
.
vsub
(
quant_max
,
quant_min
))
zp_from_min
=
te
.
lang
.
cce
.
vsub
(
quant_min
,
te
.
lang
.
cce
.
vdiv
(
min_val
,
scale
))
zp_from_min
=
te
.
lang
.
cce
.
vsub
(
quant_min
,
te
.
lang
.
cce
.
vdiv
(
min_val
,
scale
))
# Nudge zero point
# Nudge zero point
nudge_zp
=
te
.
lang
.
cce
.
round
(
te
.
lang
.
cce
.
vmin
(
quant_max
,
te
.
lang
.
cce
.
vmax
(
quant_min
,
zp_from_min
)))
nudge_zp_
=
te
.
lang
.
cce
.
vmin
(
quant_max
,
te
.
lang
.
cce
.
vmax
(
quant_min
,
zp_from_min
))
nudge_zp
=
te
.
lang
.
cce
.
floor
(
te
.
lang
.
cce
.
vadds
(
nudge_zp_
,
0.5
))
nudge_min
=
te
.
lang
.
cce
.
vmul
(
te
.
lang
.
cce
.
vsub
(
quant_min
,
nudge_zp
),
scale
)
nudge_min
=
te
.
lang
.
cce
.
vmul
(
te
.
lang
.
cce
.
vsub
(
quant_min
,
nudge_zp
),
scale
)
nudge_max
=
te
.
lang
.
cce
.
vmul
(
te
.
lang
.
cce
.
vsub
(
quant_max
,
nudge_zp
),
scale
)
nudge_max
=
te
.
lang
.
cce
.
vmul
(
te
.
lang
.
cce
.
vsub
(
quant_max
,
nudge_zp
),
scale
)
nudge_min
=
te
.
lang
.
cce
.
broadcast
(
nudge_min
,
shape
)
nudge_min
=
te
.
lang
.
cce
.
broadcast
(
nudge_min
,
shape
)
...
@@ -106,11 +109,11 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q
...
@@ -106,11 +109,11 @@ def fake_quant_with_min_max_grad_compute(dout, x, min_val, max_val, quant_min, q
return
res
return
res
@
util
.
check_input_type
(
dict
,
dict
,
dict
,
dict
,
dict
,
int
,
int
,
bool
,
bool
,
str
)
@
util
.
check_input_type
(
dict
,
dict
,
dict
,
dict
,
dict
,
int
,
bool
,
bool
,
str
)
def
fake_quant_
with_min_max
_grad
(
dout
,
x
,
min_val
,
max_val
,
dx
,
def
fake_quant_
per_layer
_grad
(
dout
,
x
,
min_val
,
max_val
,
dx
,
num_bits
,
quant_delay
,
symmetric
,
narrow_range
,
num_bits
,
symmetric
,
narrow_range
,
kernel_name
=
"fake_quant_with_min_max
_grad"
):
kernel_name
=
"fake_quant_per_layer
_grad"
):
"""FakeQuant
WithMinMax
Grad"""
"""FakeQuant
PerLayer
Grad"""
input_shape
=
x
.
get
(
"shape"
)
input_shape
=
x
.
get
(
"shape"
)
input_dtype
=
x
.
get
(
"dtype"
)
input_dtype
=
x
.
get
(
"dtype"
)
min_shape
=
min_val
.
get
(
"ori_shape"
)
min_shape
=
min_val
.
get
(
"ori_shape"
)
...
@@ -152,8 +155,8 @@ def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx,
...
@@ -152,8 +155,8 @@ def fake_quant_with_min_max_grad(dout, x, min_val, max_val, dx,
input_data
=
tvm
.
placeholder
(
input_shape
,
name
=
"x"
,
dtype
=
x_dtype
)
input_data
=
tvm
.
placeholder
(
input_shape
,
name
=
"x"
,
dtype
=
x_dtype
)
min_data
=
tvm
.
placeholder
(
shape_min
,
name
=
"min_data"
,
dtype
=
min_dtype
)
min_data
=
tvm
.
placeholder
(
shape_min
,
name
=
"min_data"
,
dtype
=
min_dtype
)
max_data
=
tvm
.
placeholder
(
shape_min
,
name
=
"max_data"
,
dtype
=
max_dtype
)
max_data
=
tvm
.
placeholder
(
shape_min
,
name
=
"max_data"
,
dtype
=
max_dtype
)
res
=
fake_quant_
with_min_max
_grad_compute
(
dout_data
,
input_data
,
min_data
,
max_data
,
quant_min
,
res
=
fake_quant_
per_layer
_grad_compute
(
dout_data
,
input_data
,
min_data
,
max_data
,
quant_min
,
quant_max
,
kernel_name
)
quant_max
,
kernel_name
)
with
tvm
.
target
.
cce
():
with
tvm
.
target
.
cce
():
sch
=
generic
.
auto_schedule
(
res
)
sch
=
generic
.
auto_schedule
(
res
)
...
...
mindspore/ops/operations/_quant_ops.py
浏览文件 @
b7db3e9a
...
@@ -20,10 +20,12 @@ from ..._checkparam import Rel
...
@@ -20,10 +20,12 @@ from ..._checkparam import Rel
from
..primitive
import
PrimitiveWithInfer
,
prim_attr_register
from
..primitive
import
PrimitiveWithInfer
,
prim_attr_register
from
...common
import
dtype
as
mstype
from
...common
import
dtype
as
mstype
__all__
=
[
"FakeQuantWithMinMax"
,
__all__
=
[
"FakeQuantPerLayer"
,
"FakeQuantWithMinMaxGrad"
,
"FakeQuantPerLayerGrad"
,
"FakeQuantWithMinMaxPerChannel"
,
"FakeQuantPerChannel"
,
"FakeQuantWithMinMaxPerChannelGrad"
,
"FakeQuantPerChannelGrad"
,
"FakeQuantMinMaxPerLayerUpdate"
,
"FakeQuantMinMaxPerChannelUpdate"
,
"BatchNormFold"
,
"BatchNormFold"
,
"BatchNormFoldGrad"
,
"BatchNormFoldGrad"
,
"CorrectionMul"
,
"CorrectionMul"
,
...
@@ -36,11 +38,10 @@ __all__ = ["FakeQuantWithMinMax",
...
@@ -36,11 +38,10 @@ __all__ = ["FakeQuantWithMinMax",
"BatchNormFold2_D"
,
"BatchNormFold2_D"
,
"BatchNormFold2GradD"
,
"BatchNormFold2GradD"
,
"BatchNormFold2GradReduce"
,
"BatchNormFold2GradReduce"
,
"FakeQuantWithMinMaxUpdate"
,
]
]
class
FakeQuant
WithMinMax
(
PrimitiveWithInfer
):
class
FakeQuant
PerLayer
(
PrimitiveWithInfer
):
r
"""
r
"""
Simulate the quantize and dequantize operations in training time.
Simulate the quantize and dequantize operations in training time.
...
@@ -67,49 +68,67 @@ class FakeQuantWithMinMax(PrimitiveWithInfer):
...
@@ -67,49 +68,67 @@ class FakeQuantWithMinMax(PrimitiveWithInfer):
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> input_tensor = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> min_tensor = Tensor(np.array([-6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> max_tensor = Tensor(np.array([6]), mstype.float32)
>>> output_tensor = P.FakeQuant
WithMinMax
(num_bits=8)(input_tensor, min_tensor, max_tensor)
>>> output_tensor = P.FakeQuant
PerLayer
(num_bits=8)(input_tensor, min_tensor, max_tensor)
"""
"""
support_quant_bit
=
[
4
,
7
,
8
]
support_quant_bit
=
[
4
,
7
,
8
]
@
prim_attr_register
@
prim_attr_register
def
__init__
(
self
,
num_bits
=
8
,
ema
=
False
,
ema_decay
=
0.999
,
quant_delay
=
0
,
symmetric
=
False
,
narrow_range
=
False
,
def
__init__
(
self
,
num_bits
=
8
,
ema
=
False
,
ema_decay
=
0.999
,
quant_delay
=
0
,
symmetric
=
False
,
narrow_range
=
False
,
training
=
True
):
training
=
True
):
"""init FakeQuant
WithMinMax
OP"""
"""init FakeQuant
PerLayer
OP"""
if
num_bits
not
in
self
.
support_quant_bit
:
if
num_bits
not
in
self
.
support_quant_bit
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
num_bits
\'
is not support."
)
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
num_bits
\'
is not support."
)
if
ema
and
not
ema_decay
:
if
ema
and
not
ema_decay
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
ema
\'
and
\'
ema_decay
\'
should set together."
)
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
ema
\'
and
\'
ema_decay
\'
should set together."
)
self
.
ema
=
validator
.
check_value_type
(
'ema'
,
ema
,
(
bool
,),
self
.
name
)
self
.
ema
=
validator
.
check_value_type
(
'ema'
,
ema
,
(
bool
,),
self
.
name
)
self
.
symmetric
=
validator
.
check_value_type
(
'symmetric'
,
symmetric
,
(
bool
,),
self
.
name
)
self
.
symmetric
=
validator
.
check_value_type
(
self
.
narrow_range
=
validator
.
check_value_type
(
'narrow_range'
,
narrow_range
,
(
bool
,),
self
.
name
)
'symmetric'
,
symmetric
,
(
bool
,),
self
.
name
)
self
.
training
=
validator
.
check_value_type
(
'training'
,
training
,
(
bool
,),
self
.
name
)
self
.
narrow_range
=
validator
.
check_value_type
(
self
.
ema_decay
=
validator
.
check_number_range
(
'ema_decay'
,
ema_decay
,
0
,
1
,
Rel
.
INC_BOTH
,
self
.
name
)
'narrow_range'
,
narrow_range
,
(
bool
,),
self
.
name
)
self
.
num_bits
=
validator
.
check_integer
(
'num_bits'
,
num_bits
,
0
,
Rel
.
GT
,
self
.
name
)
self
.
training
=
validator
.
check_value_type
(
self
.
quant_delay
=
validator
.
check_value_type
(
'quant_delay'
,
quant_delay
,
(
int
,),
self
.
name
)
'training'
,
training
,
(
bool
,),
self
.
name
)
self
.
ema_decay
=
validator
.
check_number_range
(
'ema_decay'
,
ema_decay
,
0
,
1
,
Rel
.
INC_BOTH
,
self
.
name
)
self
.
num_bits
=
validator
.
check_integer
(
'num_bits'
,
num_bits
,
0
,
Rel
.
GT
,
self
.
name
)
self
.
quant_delay
=
validator
.
check_value_type
(
'quant_delay'
,
quant_delay
,
(
int
,),
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'min'
,
'max'
],
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'min'
,
'max'
],
outputs
=
[
'out'
])
outputs
=
[
'out'
])
def
infer_shape
(
self
,
x_shape
,
min_shape
,
max_shape
):
def
infer_shape
(
self
,
x_shape
,
min_shape
,
max_shape
):
validator
.
check_integer
(
"x rank"
,
len
(
x_shape
),
1
,
Rel
.
GE
,
self
.
name
)
validator
.
check_integer
(
"x rank"
,
len
(
x_shape
),
1
,
Rel
.
GE
,
self
.
name
)
validator
.
check
(
"min shape"
,
min_shape
,
"max shape"
,
max_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"min shape"
,
min_shape
,
"max shape"
,
validator
.
check_integer
(
"min rank"
,
len
(
min_shape
),
1
,
Rel
.
EQ
,
self
.
name
)
max_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"min rank"
,
len
(
min_shape
),
1
,
Rel
.
EQ
,
self
.
name
)
return
x_shape
return
x_shape
def
infer_dtype
(
self
,
x_type
,
min_type
,
max_type
):
def
infer_dtype
(
self
,
x_type
,
min_type
,
max_type
):
valid_types
=
(
mstype
.
float16
,
mstype
.
float32
)
valid_types
=
(
mstype
.
float16
,
mstype
.
float32
)
validator
.
check_tensor_type_same
({
"x"
:
x_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
({
"x"
:
x_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
({
"min"
:
min_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
validator
.
check_tensor_type_same
({
"max"
:
max_type
},
valid_types
,
self
.
name
)
{
"min"
:
min_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
{
"max"
:
max_type
},
valid_types
,
self
.
name
)
return
x_type
return
x_type
class
FakeQuant
WithMinMax
Grad
(
PrimitiveWithInfer
):
class
FakeQuant
PerLayer
Grad
(
PrimitiveWithInfer
):
r
"""
r
"""
Performs grad of FakeQuant
WithMinMax
operation.
Performs grad of FakeQuant
PerLayerGrad
operation.
Examples:
Examples:
>>> fake_min_max_grad = P.FakeQuant
WithMinMax
Grad()
>>> fake_min_max_grad = P.FakeQuant
PerLayer
Grad()
>>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
>>> dout = Tensor(np.array([[-2.3, 1.2], [5.7, 0.2]]), mindspore.float32)
>>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
>>> input_x = Tensor(np.array([[18, -23], [0.2, 6]]), mindspore.float32)
>>> _min = Tensor(np.array([-4]), mindspore.float32)
>>> _min = Tensor(np.array([-4]), mindspore.float32)
...
@@ -119,32 +138,48 @@ class FakeQuantWithMinMaxGrad(PrimitiveWithInfer):
...
@@ -119,32 +138,48 @@ class FakeQuantWithMinMaxGrad(PrimitiveWithInfer):
support_quant_bit
=
[
4
,
7
,
8
]
support_quant_bit
=
[
4
,
7
,
8
]
@
prim_attr_register
@
prim_attr_register
def
__init__
(
self
,
num_bits
=
8
,
quant_delay
=
0
,
symmetric
=
False
,
narrow_range
=
False
):
def
__init__
(
self
,
num_bits
=
8
,
quant_delay
=
0
,
symmetric
=
False
,
narrow_range
=
False
):
if
num_bits
not
in
self
.
support_quant_bit
:
if
num_bits
not
in
self
.
support_quant_bit
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
num_bits
\'
is not support."
)
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
num_bits
\'
is not support."
)
self
.
num_bits
=
validator
.
check_integer
(
'num_bits'
,
num_bits
,
0
,
Rel
.
GT
,
self
.
name
)
self
.
quant_delay
=
validator
.
check_value_type
(
'quant_delay'
,
quant_delay
,
(
int
,),
self
.
name
)
self
.
num_bits
=
validator
.
check_integer
(
self
.
symmetric
=
validator
.
check_value_type
(
'symmetric'
,
symmetric
,
(
bool
,),
self
.
name
)
'num_bits'
,
num_bits
,
0
,
Rel
.
GT
,
self
.
name
)
self
.
narrow_range
=
validator
.
check_value_type
(
'narrow_range'
,
narrow_range
,
(
bool
,),
self
.
name
)
self
.
quant_delay
=
validator
.
check_value_type
(
self
.
init_prim_io_names
(
inputs
=
[
'dout'
,
'x'
,
'min'
,
'max'
],
outputs
=
[
'dx'
])
'quant_delay'
,
quant_delay
,
(
int
,),
self
.
name
)
self
.
symmetric
=
validator
.
check_value_type
(
'symmetric'
,
symmetric
,
(
bool
,),
self
.
name
)
self
.
narrow_range
=
validator
.
check_value_type
(
'narrow_range'
,
narrow_range
,
(
bool
,),
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'dout'
,
'x'
,
'min'
,
'max'
],
outputs
=
[
'dx'
])
def
infer_shape
(
self
,
dout_shape
,
x_shape
,
min_shape
,
max_shape
):
def
infer_shape
(
self
,
dout_shape
,
x_shape
,
min_shape
,
max_shape
):
validator
.
check
(
"dout shape"
,
dout_shape
,
"x shape"
,
x_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"dout shape"
,
dout_shape
,
"x shape"
,
validator
.
check
(
"min shape"
,
min_shape
,
"max shape"
,
max_shape
,
Rel
.
EQ
,
self
.
name
)
x_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"min rank"
,
len
(
min_shape
),
1
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"min shape"
,
min_shape
,
"max shape"
,
max_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"min rank"
,
len
(
min_shape
),
1
,
Rel
.
EQ
,
self
.
name
)
return
dout_shape
return
dout_shape
def
infer_dtype
(
self
,
dout_type
,
x_type
,
min_type
,
max_type
):
def
infer_dtype
(
self
,
dout_type
,
x_type
,
min_type
,
max_type
):
valid_types
=
(
mstype
.
float16
,
mstype
.
float32
)
valid_types
=
(
mstype
.
float16
,
mstype
.
float32
)
validator
.
check_tensor_type_same
({
"dout"
:
dout_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
{
"dout"
:
dout_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
({
"x"
:
x_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
({
"x"
:
x_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
({
"min"
:
min_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
validator
.
check_tensor_type_same
({
"max"
:
max_type
},
valid_types
,
self
.
name
)
{
"min"
:
min_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
{
"max"
:
max_type
},
valid_types
,
self
.
name
)
return
dout_type
return
dout_type
class
FakeQuant
WithMinMax
PerChannel
(
PrimitiveWithInfer
):
class
FakeQuantPerChannel
(
PrimitiveWithInfer
):
r
"""
r
"""
Simulate the quantize and dequantize operations in training time base on per channel.
Simulate the quantize and dequantize operations in training time base on per channel.
...
@@ -168,53 +203,73 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
...
@@ -168,53 +203,73 @@ class FakeQuantWithMinMaxPerChannel(PrimitiveWithInfer):
- Tensor, has the same type as input.
- Tensor, has the same type as input.
Examples:
Examples:
>>> fake_quant = P.FakeQuant
WithMinMax
PerChannel()
>>> fake_quant = P.FakeQuantPerChannel()
>>> input_x = Tensor(np.array([3, 4, 5, -2, -3, -1]).reshape(3, 2), mindspore.float32)
>>> input_x = Tensor(np.array([3, 4, 5, -2, -3, -1]).reshape(3, 2), mindspore.float32)
>>> _min = Tensor(np.linspace(-2, 2, 12).reshape(3, 2, 2), mindspore.float32)
>>> _min = Tensor(np.linspace(-2, 2, 12).reshape(3, 2, 2), mindspore.float32)
>>> _max = Tensor(np.linspace(8, 12, 12).reshape(3, 2, 2), mindspore.float32)
>>> _max = Tensor(np.linspace(8, 12, 12).reshape(3, 2, 2), mindspore.float32)
>>> result = fake_quant(input_x, _min, _max)
>>> result = fake_quant(input_x, _min, _max)
"""
"""
support_quant_bit
=
[
4
,
7
,
8
]
support_quant_bit
=
[
4
,
7
,
8
]
channel_axis
=
0
@
prim_attr_register
@
prim_attr_register
def
__init__
(
self
,
num_bits
=
8
,
ema
=
False
,
ema_decay
=
0.999
,
quant_delay
=
0
,
symmetric
=
False
,
narrow_range
=
False
,
def
__init__
(
self
,
training
=
True
):
num_bits
=
8
,
"""init FakeQuantWithMinMaxPerChannel OP"""
ema
=
False
,
ema_decay
=
0.999
,
quant_delay
=
0
,
symmetric
=
False
,
narrow_range
=
False
,
training
=
True
,
channel_axis
=
1
):
"""init FakeQuantPerChannel OP"""
if
num_bits
not
in
self
.
support_quant_bit
:
if
num_bits
not
in
self
.
support_quant_bit
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' Attr
\'
num_bits
\'
is not support."
)
raise
ValueError
(
f
"For '
{
self
.
name
}
' Attr
\'
num_bits
\'
is not support."
)
if
ema
and
not
ema_decay
:
if
ema
and
not
ema_decay
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
ema
\'
and
\'
ema_decay
\'
should set together."
)
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
ema
\'
and
\'
ema_decay
\'
should set together."
)
self
.
ema
=
validator
.
check_value_type
(
'ema'
,
ema
,
(
bool
,),
self
.
name
)
self
.
ema
=
validator
.
check_value_type
(
'ema'
,
ema
,
(
bool
,),
self
.
name
)
self
.
symmetric
=
validator
.
check_value_type
(
'symmetric'
,
symmetric
,
(
bool
,),
self
.
name
)
self
.
symmetric
=
validator
.
check_value_type
(
self
.
narrow_range
=
validator
.
check_value_type
(
'narrow_range'
,
narrow_range
,
(
bool
,),
self
.
name
)
'symmetric'
,
symmetric
,
(
bool
,),
self
.
name
)
self
.
training
=
validator
.
check_value_type
(
'training'
,
training
,
(
bool
,),
self
.
name
)
self
.
narrow_range
=
validator
.
check_value_type
(
self
.
ema_decay
=
validator
.
check_number_range
(
'ema_decay'
,
ema_decay
,
0
,
1
,
Rel
.
INC_BOTH
,
self
.
name
)
'narrow_range'
,
narrow_range
,
(
bool
,),
self
.
name
)
self
.
num_bits
=
validator
.
check_integer
(
'num_bits'
,
num_bits
,
0
,
Rel
.
GT
,
self
.
name
)
self
.
training
=
validator
.
check_value_type
(
self
.
quant_delay
=
validator
.
check_value_type
(
'quant_delay'
,
quant_delay
,
(
int
,),
self
.
name
)
'training'
,
training
,
(
bool
,),
self
.
name
)
self
.
ema_decay
=
validator
.
check_number_range
(
'ema_decay'
,
ema_decay
,
0
,
1
,
Rel
.
INC_BOTH
,
self
.
name
)
self
.
num_bits
=
validator
.
check_integer
(
'num_bits'
,
num_bits
,
0
,
Rel
.
GT
,
self
.
name
)
self
.
quant_delay
=
validator
.
check_value_type
(
'quant_delay'
,
quant_delay
,
(
int
,),
self
.
name
)
self
.
channel_axis
=
validator
.
check_integer
(
'channel_axis'
,
channel_axis
,
0
,
Rel
.
GE
,
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'min'
,
'max'
],
outputs
=
[
'out'
])
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'min'
,
'max'
],
outputs
=
[
'out'
])
def
infer_shape
(
self
,
x_shape
,
min_shape
,
max_shape
):
def
infer_shape
(
self
,
x_shape
,
min_shape
,
max_shape
):
validator
.
check_integer
(
"x rank"
,
len
(
x_shape
),
1
,
Rel
.
GE
,
self
.
name
)
validator
.
check_integer
(
"x rank"
,
len
(
x_shape
),
1
,
Rel
.
GE
,
self
.
name
)
validator
.
check_integer
(
"min shape[0]"
,
min_shape
[
0
],
x_shape
[
self
.
channel_axis
],
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
validator
.
check_integer
(
"max shape[0]"
,
max_shape
[
0
],
x_shape
[
self
.
channel_axis
],
Rel
.
EQ
,
self
.
name
)
"min shape[0]"
,
min_shape
[
0
],
x_shape
[
self
.
channel_axis
],
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"max shape[0]"
,
max_shape
[
0
],
x_shape
[
self
.
channel_axis
],
Rel
.
EQ
,
self
.
name
)
return
x_shape
return
x_shape
def
infer_dtype
(
self
,
x_type
,
min_type
,
max_type
):
def
infer_dtype
(
self
,
x_type
,
min_type
,
max_type
):
valid_types
=
(
mstype
.
float16
,
mstype
.
float32
)
valid_types
=
(
mstype
.
float16
,
mstype
.
float32
)
validator
.
check_tensor_type_same
({
"x"
:
x_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
({
"x"
:
x_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
({
"min"
:
min_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
validator
.
check_tensor_type_same
({
"max"
:
max_type
},
valid_types
,
self
.
name
)
{
"min"
:
min_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
{
"max"
:
max_type
},
valid_types
,
self
.
name
)
return
x_type
return
x_type
class
FakeQuant
WithMinMax
PerChannelGrad
(
PrimitiveWithInfer
):
class
FakeQuantPerChannelGrad
(
PrimitiveWithInfer
):
r
"""
r
"""
Performs grad of FakeQuant
WithMinMaxPerChannel
operation.
Performs grad of FakeQuant
PerChannelGrad
operation.
Examples:
Examples:
>>> fqmmpc_grad = P.FakeQuant
WithMinMax
PerChannelGrad()
>>> fqmmpc_grad = P.FakeQuantPerChannelGrad()
>>> input_x = Tensor(np.random.randint(-4, 4, (2, 3, 4)), mindspore.float32)
>>> input_x = Tensor(np.random.randint(-4, 4, (2, 3, 4)), mindspore.float32)
>>> dout = Tensor(np.random.randint(-2, 2, (2, 3, 4)), mindspore.float32)
>>> dout = Tensor(np.random.randint(-2, 2, (2, 3, 4)), mindspore.float32)
>>> _min = Tensor(np.random.randint(-8, 2, (2, 3, 4)), mindspore.float32)
>>> _min = Tensor(np.random.randint(-8, 2, (2, 3, 4)), mindspore.float32)
...
@@ -224,16 +279,29 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
...
@@ -224,16 +279,29 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
support_quant_bit
=
[
4
,
7
,
8
]
support_quant_bit
=
[
4
,
7
,
8
]
@
prim_attr_register
@
prim_attr_register
def
__init__
(
self
,
num_bits
=
8
,
quant_delay
=
0
,
symmetric
=
False
,
narrow_range
=
False
):
def
__init__
(
self
,
"""init FakeQuantWithMinMaxPerChannel Fill"""
num_bits
=
8
,
quant_delay
=
0
,
symmetric
=
False
,
narrow_range
=
False
,
channel_axis
=
1
):
"""init FakeQuantPerChannelGrad Fill"""
if
num_bits
not
in
self
.
support_quant_bit
:
if
num_bits
not
in
self
.
support_quant_bit
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
num_bits
\'
is not support."
)
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
num_bits
\'
is not support."
)
self
.
num_bits
=
validator
.
check_integer
(
'num_bits'
,
num_bits
,
0
,
Rel
.
GT
,
self
.
name
)
self
.
quant_delay
=
validator
.
check_value_type
(
'quant_delay'
,
quant_delay
,
(
int
,),
self
.
name
)
self
.
num_bits
=
validator
.
check_integer
(
self
.
symmetric
=
validator
.
check_value_type
(
'symmetric'
,
symmetric
,
(
bool
,),
self
.
name
)
'num_bits'
,
num_bits
,
0
,
Rel
.
GT
,
self
.
name
)
self
.
narrow_range
=
validator
.
check_value_type
(
'narrow_range'
,
narrow_range
,
(
bool
,),
self
.
name
)
self
.
quant_delay
=
validator
.
check_value_type
(
self
.
init_prim_io_names
(
inputs
=
[
'dout'
,
'x'
,
'min'
,
'max'
],
outputs
=
[
'dx'
])
'quant_delay'
,
quant_delay
,
(
int
,),
self
.
name
)
self
.
symmetric
=
validator
.
check_value_type
(
'symmetric'
,
symmetric
,
(
bool
,),
self
.
name
)
self
.
narrow_range
=
validator
.
check_value_type
(
'narrow_range'
,
narrow_range
,
(
bool
,),
self
.
name
)
self
.
channel_axis
=
validator
.
check_integer
(
'channel axis'
,
channel_axis
,
0
,
Rel
.
GE
,
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'dout'
,
'x'
,
'min'
,
'max'
],
outputs
=
[
'dx'
])
def
infer_shape
(
self
,
dout_shape
,
x_shape
,
min_shape
,
max_shape
):
def
infer_shape
(
self
,
dout_shape
,
x_shape
,
min_shape
,
max_shape
):
validator
.
check
(
"dout shape"
,
dout_shape
,
"x shape"
,
x_shape
)
validator
.
check
(
"dout shape"
,
dout_shape
,
"x shape"
,
x_shape
)
...
@@ -242,10 +310,13 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
...
@@ -242,10 +310,13 @@ class FakeQuantWithMinMaxPerChannelGrad(PrimitiveWithInfer):
def
infer_dtype
(
self
,
dout_type
,
x_type
,
min_type
,
max_type
):
def
infer_dtype
(
self
,
dout_type
,
x_type
,
min_type
,
max_type
):
valid_types
=
(
mstype
.
float16
,
mstype
.
float32
)
valid_types
=
(
mstype
.
float16
,
mstype
.
float32
)
validator
.
check_tensor_type_same
({
"dout"
:
dout_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
{
"dout"
:
dout_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
({
"x"
:
x_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
({
"x"
:
x_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
({
"min"
:
min_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
validator
.
check_tensor_type_same
({
"max"
:
max_type
},
valid_types
,
self
.
name
)
{
"min"
:
min_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
{
"max"
:
max_type
},
valid_types
,
self
.
name
)
return
dout_type
return
dout_type
...
@@ -744,17 +815,14 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
...
@@ -744,17 +815,14 @@ class BatchNormFold2GradReduce(PrimitiveWithInfer):
return
dout_type
,
dout_type
return
dout_type
,
dout_type
class
FakeQuant
WithMinMax
Update
(
PrimitiveWithInfer
):
class
FakeQuant
MinMaxPerLayer
Update
(
PrimitiveWithInfer
):
r
"""
r
"""
Simulate the quantize and dequantize operations in training time
.
Update min and max value for fake quant per layer op
.
Args:
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
num_bits (int) : Number bits for aware quantilization. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
quant_delay (int): Quantilization delay parameter. Before delay step in training time not update
simulate aware quantize funcion. After delay step in training time begin simulate the aware
quantize funcion. Default: 0.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
training (bool): Training the network or not. Default: True.
...
@@ -776,36 +844,121 @@ class FakeQuantWithMinMaxUpdate(PrimitiveWithInfer):
...
@@ -776,36 +844,121 @@ class FakeQuantWithMinMaxUpdate(PrimitiveWithInfer):
support_quant_bit
=
[
4
,
7
,
8
]
support_quant_bit
=
[
4
,
7
,
8
]
@
prim_attr_register
@
prim_attr_register
def
__init__
(
self
,
num_bits
=
8
,
ema
=
False
,
ema_decay
=
0.999
,
quant_delay
=
0
,
symmetric
=
False
,
narrow_range
=
False
,
def
__init__
(
self
,
num_bits
=
8
,
ema
=
False
,
ema_decay
=
0.999
,
symmetric
=
False
,
narrow_range
=
False
,
training
=
True
):
training
=
True
):
"""init FakeQuant
WithMinMax
OP"""
"""init FakeQuant
MinMaxPerLayerUpdate
OP"""
from
mindspore.ops._op_impl._custom_op
import
correction_mul
,
correction_mul_grad
from
mindspore.ops._op_impl._custom_op
import
correction_mul
,
correction_mul_grad
from
mindspore.ops._op_impl._custom_op
import
fake_quant_with_min_max
,
fake_quant_with_min_max_grad
from
mindspore.ops._op_impl._custom_op
import
fake_quant_with_min_max
,
fake_quant_with_min_max_grad
from
mindspore.ops._op_impl._custom_op
import
fake_quant_with_min_max_update
from
mindspore.ops._op_impl._custom_op
import
fake_quant_with_min_max_update
if
num_bits
not
in
self
.
support_quant_bit
:
if
num_bits
not
in
self
.
support_quant_bit
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
num_bits
\'
is not support."
)
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
num_bits
\'
is not support."
)
if
ema
and
not
ema_decay
:
if
ema
and
not
ema_decay
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
ema
\'
and
\'
ema_decay
\'
should set together."
)
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
ema
\'
and
\'
ema_decay
\'
should set together."
)
self
.
ema
=
validator
.
check_value_type
(
'ema'
,
ema
,
(
bool
,),
self
.
name
)
self
.
ema
=
validator
.
check_value_type
(
'ema'
,
ema
,
(
bool
,),
self
.
name
)
self
.
symmetric
=
validator
.
check_value_type
(
'symmetric'
,
symmetric
,
(
bool
,),
self
.
name
)
self
.
symmetric
=
validator
.
check_value_type
(
self
.
narrow_range
=
validator
.
check_value_type
(
'narrow_range'
,
narrow_range
,
(
bool
,),
self
.
name
)
'symmetric'
,
symmetric
,
(
bool
,),
self
.
name
)
self
.
training
=
validator
.
check_value_type
(
'training'
,
training
,
(
bool
,),
self
.
name
)
self
.
narrow_range
=
validator
.
check_value_type
(
self
.
ema_decay
=
validator
.
check_number_range
(
'ema_decay'
,
ema_decay
,
0
,
1
,
Rel
.
INC_BOTH
,
self
.
name
)
'narrow_range'
,
narrow_range
,
(
bool
,),
self
.
name
)
self
.
num_bits
=
validator
.
check_integer
(
'num_bits'
,
num_bits
,
0
,
Rel
.
GT
,
self
.
name
)
self
.
training
=
validator
.
check_value_type
(
self
.
quant_delay
=
validator
.
check_value_type
(
'quant_delay'
,
quant_delay
,
(
int
,),
self
.
name
)
'training'
,
training
,
(
bool
,),
self
.
name
)
self
.
ema_decay
=
validator
.
check_number_range
(
'ema_decay'
,
ema_decay
,
0
,
1
,
Rel
.
INC_BOTH
,
self
.
name
)
self
.
num_bits
=
validator
.
check_integer
(
'num_bits'
,
num_bits
,
0
,
Rel
.
GT
,
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'min'
,
'max'
],
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'min'
,
'max'
],
outputs
=
[
'min_up'
,
'max_up'
])
outputs
=
[
'min_up'
,
'max_up'
])
def
infer_shape
(
self
,
x_shape
,
min_shape
,
max_shape
):
def
infer_shape
(
self
,
x_shape
,
min_shape
,
max_shape
):
validator
.
check_integer
(
"x rank"
,
len
(
x_shape
),
1
,
Rel
.
GE
,
self
.
name
)
validator
.
check_integer
(
"x rank"
,
len
(
x_shape
),
1
,
Rel
.
GE
,
self
.
name
)
validator
.
check
(
"min shape"
,
min_shape
,
"max shape"
,
max_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"min shape"
,
min_shape
,
"max shape"
,
validator
.
check_integer
(
"min rank"
,
len
(
min_shape
),
1
,
Rel
.
EQ
,
self
.
name
)
max_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"min rank"
,
len
(
min_shape
),
1
,
Rel
.
EQ
,
self
.
name
)
return
min_shape
,
max_shape
return
min_shape
,
max_shape
def
infer_dtype
(
self
,
x_type
,
min_type
,
max_type
):
def
infer_dtype
(
self
,
x_type
,
min_type
,
max_type
):
valid_types
=
(
mstype
.
float16
,
mstype
.
float32
)
valid_types
=
(
mstype
.
float16
,
mstype
.
float32
)
validator
.
check_tensor_type_same
({
"x"
:
x_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
({
"x"
:
x_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
({
"min"
:
min_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
validator
.
check_tensor_type_same
({
"max"
:
max_type
},
valid_types
,
self
.
name
)
{
"min"
:
min_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
{
"max"
:
max_type
},
valid_types
,
self
.
name
)
return
min_type
,
max_type
class
FakeQuantMinMaxPerChannelUpdate
(
PrimitiveWithInfer
):
r
"""
Update min and max value for fake quant per layer op.
Args:
num_bits (int) : Number bits for aware quantilization. Default: 8.
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
channel_axis (int): Channel asis for per channel compute. Default: 1.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
- **min** (Tensor) : Value of the min range of the input data x.
- **max** (Tensor) : Value of the max range of the input data x.
Outputs:
- Tensor: Simulate quantize tensor of x.
Examples:
>>> x = Tensor(np.random.rand(3, 16, 5, 5), mstype.float32)
>>> min = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> max = Tensor(np.random.uniform(-1, 1, size=16), mstype.float32)
>>> output_tensor = P.FakeQuantWithMinMax(num_bits=8)(x, min, max)
"""
support_quant_bit
=
[
4
,
7
,
8
]
@
prim_attr_register
def
__init__
(
self
,
num_bits
=
8
,
ema
=
False
,
ema_decay
=
0.999
,
symmetric
=
False
,
narrow_range
=
False
,
training
=
True
,
channel_axis
=
1
):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
if
num_bits
not
in
self
.
support_quant_bit
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
num_bits
\'
is not support."
)
if
ema
and
not
ema_decay
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' attr
\'
ema
\'
and
\'
ema_decay
\'
should set together."
)
self
.
ema
=
validator
.
check_value_type
(
'ema'
,
ema
,
(
bool
,),
self
.
name
)
self
.
symmetric
=
validator
.
check_value_type
(
'symmetric'
,
symmetric
,
(
bool
,),
self
.
name
)
self
.
narrow_range
=
validator
.
check_value_type
(
'narrow_range'
,
narrow_range
,
(
bool
,),
self
.
name
)
self
.
training
=
validator
.
check_value_type
(
'training'
,
training
,
(
bool
,),
self
.
name
)
self
.
ema_decay
=
validator
.
check_number_range
(
'ema_decay'
,
ema_decay
,
0
,
1
,
Rel
.
INC_BOTH
,
self
.
name
)
self
.
num_bits
=
validator
.
check_integer
(
'num_bits'
,
num_bits
,
0
,
Rel
.
GT
,
self
.
name
)
self
.
channel_axis
=
validator
.
check_integer
(
'channel axis'
,
channel_axis
,
0
,
Rel
.
GE
,
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'min'
,
'max'
],
outputs
=
[
'min_up'
,
'max_up'
])
def
infer_shape
(
self
,
x_shape
,
min_shape
,
max_shape
):
validator
.
check_integer
(
"x rank"
,
len
(
x_shape
),
1
,
Rel
.
GT
,
self
.
name
)
validator
.
check
(
"min shape"
,
min_shape
,
"max shape"
,
max_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"min rank"
,
len
(
min_shape
),
1
,
Rel
.
EQ
,
self
.
name
)
return
min_shape
,
max_shape
def
infer_dtype
(
self
,
x_type
,
min_type
,
max_type
):
valid_types
=
(
mstype
.
float16
,
mstype
.
float32
)
validator
.
check_tensor_type_same
(
{
"x"
:
x_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
{
"min"
:
min_type
},
valid_types
,
self
.
name
)
validator
.
check_tensor_type_same
(
{
"max"
:
max_type
},
valid_types
,
self
.
name
)
return
min_type
,
max_type
return
min_type
,
max_type
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录