Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
e9ee59c7
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e9ee59c7
编写于
6月 22, 2020
作者:
C
chenzupeng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add perchannel quant train
上级
4bbd4414
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
59 addition
and
36 deletion
+59
-36
example/mobilenetv2_quant/Readme.md
example/mobilenetv2_quant/Readme.md
+0
-1
example/mobilenetv2_quant/src/config.py
example/mobilenetv2_quant/src/config.py
+2
-2
example/mobilenetv2_quant/src/mobilenetV2_quant.py
example/mobilenetv2_quant/src/mobilenetV2_quant.py
+15
-9
example/resnet50_quant/README.md
example/resnet50_quant/README.md
+2
-2
example/resnet50_quant/models/resnet_quant.py
example/resnet50_quant/models/resnet_quant.py
+12
-6
example/resnet50_quant/src/config.py
example/resnet50_quant/src/config.py
+2
-2
mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py
...op_impl/_custom_op/fake_quant_minmax_perchannel_update.py
+8
-4
mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py
mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py
+9
-5
mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py
...ore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py
+9
-5
未找到文件。
example/mobilenetv2_quant/Readme.md
浏览文件 @
e9ee59c7
...
@@ -47,7 +47,6 @@ Dataset used: imagenet
...
@@ -47,7 +47,6 @@ Dataset used: imagenet
├──
eval
.
py
├──
eval
.
py
```
```
Notation: Current hyperparameters only test on 4 cards while training, if want to use 8 cards for training, should change parameters like learning rate in 'src/config.py'.
## Training process
## Training process
...
...
example/mobilenetv2_quant/src/config.py
浏览文件 @
e9ee59c7
...
@@ -22,10 +22,10 @@ config_ascend = ed({
...
@@ -22,10 +22,10 @@ config_ascend = ed({
"image_height"
:
224
,
"image_height"
:
224
,
"image_width"
:
224
,
"image_width"
:
224
,
"batch_size"
:
192
,
"batch_size"
:
192
,
"epoch_size"
:
4
0
,
"epoch_size"
:
6
0
,
"start_epoch"
:
200
,
"start_epoch"
:
200
,
"warmup_epochs"
:
1
,
"warmup_epochs"
:
1
,
"lr"
:
0.
15
,
"lr"
:
0.
3
,
"momentum"
:
0.9
,
"momentum"
:
0.9
,
"weight_decay"
:
4e-5
,
"weight_decay"
:
4e-5
,
"label_smooth"
:
0.1
,
"label_smooth"
:
0.1
,
...
...
example/mobilenetv2_quant/src/mobilenetV2_quant.py
浏览文件 @
e9ee59c7
...
@@ -20,7 +20,8 @@ from mindspore.ops.operations import TensorAdd
...
@@ -20,7 +20,8 @@ from mindspore.ops.operations import TensorAdd
__all__
=
[
'mobilenet_v2_quant'
]
__all__
=
[
'mobilenet_v2_quant'
]
_ema_decay
=
0.999
_ema_decay
=
0.999
_symmetric
=
False
_symmetric
=
True
_per_channel
=
True
def
_make_divisible
(
v
,
divisor
,
min_value
=
None
):
def
_make_divisible
(
v
,
divisor
,
min_value
=
None
):
...
@@ -77,10 +78,10 @@ class ConvBNReLU(nn.Cell):
...
@@ -77,10 +78,10 @@ class ConvBNReLU(nn.Cell):
super
(
ConvBNReLU
,
self
).
__init__
()
super
(
ConvBNReLU
,
self
).
__init__
()
padding
=
(
kernel_size
-
1
)
//
2
padding
=
(
kernel_size
-
1
)
//
2
conv
=
nn
.
Conv2dBatchNormQuant
(
in_planes
,
out_planes
,
kernel_size
,
stride
,
pad_mode
=
'pad'
,
padding
=
padding
,
conv
=
nn
.
Conv2dBatchNormQuant
(
in_planes
,
out_planes
,
kernel_size
,
stride
,
pad_mode
=
'pad'
,
padding
=
padding
,
group
=
groups
)
group
=
groups
,
per_channel
=
_per_channel
,
symmetric
=
_symmetric
)
layers
=
[
conv
,
nn
.
ReLU
()]
layers
=
[
conv
,
nn
.
ReLU
()]
self
.
features
=
nn
.
SequentialCell
(
layers
)
self
.
features
=
nn
.
SequentialCell
(
layers
)
self
.
fake
=
nn
.
FakeQuantWithMinMax
(
ema
=
True
,
ema_decay
=
_ema_decay
,
symmetric
=
_symmetric
,
min_init
=
0
)
self
.
fake
=
nn
.
FakeQuantWithMinMax
(
ema
=
True
,
ema_decay
=
_ema_decay
,
min_init
=
0
)
def
construct
(
self
,
x
):
def
construct
(
self
,
x
):
output
=
self
.
features
(
x
)
output
=
self
.
features
(
x
)
...
@@ -119,12 +120,13 @@ class InvertedResidual(nn.Cell):
...
@@ -119,12 +120,13 @@ class InvertedResidual(nn.Cell):
# dw
# dw
ConvBNReLU
(
hidden_dim
,
hidden_dim
,
stride
=
stride
,
groups
=
hidden_dim
),
ConvBNReLU
(
hidden_dim
,
hidden_dim
,
stride
=
stride
,
groups
=
hidden_dim
),
# pw-linear
# pw-linear
nn
.
Conv2dBatchNormQuant
(
hidden_dim
,
oup
,
kernel_size
=
1
,
stride
=
1
,
pad_mode
=
'pad'
,
padding
=
0
,
group
=
1
),
nn
.
Conv2dBatchNormQuant
(
hidden_dim
,
oup
,
kernel_size
=
1
,
stride
=
1
,
pad_mode
=
'pad'
,
padding
=
0
,
group
=
1
,
nn
.
FakeQuantWithMinMax
(
ema
=
True
,
ema_decay
=
_ema_decay
,
symmetric
=
_symmetric
)
per_channel
=
_per_channel
,
symmetric
=
_symmetric
),
nn
.
FakeQuantWithMinMax
(
ema
=
True
,
ema_decay
=
_ema_decay
)
])
])
self
.
conv
=
nn
.
SequentialCell
(
layers
)
self
.
conv
=
nn
.
SequentialCell
(
layers
)
self
.
add
=
TensorAdd
()
self
.
add
=
TensorAdd
()
self
.
add_fake
=
nn
.
FakeQuantWithMinMax
(
ema
=
True
,
ema_decay
=
_ema_decay
,
symmetric
=
_symmetric
)
self
.
add_fake
=
nn
.
FakeQuantWithMinMax
(
ema
=
True
,
ema_decay
=
_ema_decay
)
def
construct
(
self
,
x
):
def
construct
(
self
,
x
):
identity
=
x
identity
=
x
...
@@ -175,7 +177,7 @@ class MobileNetV2Quant(nn.Cell):
...
@@ -175,7 +177,7 @@ class MobileNetV2Quant(nn.Cell):
# building first layer
# building first layer
input_channel
=
_make_divisible
(
input_channel
*
width_mult
,
round_nearest
)
input_channel
=
_make_divisible
(
input_channel
*
width_mult
,
round_nearest
)
self
.
out_channels
=
_make_divisible
(
last_channel
*
max
(
1.0
,
width_mult
),
round_nearest
)
self
.
out_channels
=
_make_divisible
(
last_channel
*
max
(
1.0
,
width_mult
),
round_nearest
)
self
.
input_fake
=
nn
.
FakeQuantWithMinMax
(
ema
=
True
,
ema_decay
=
_ema_decay
,
symmetric
=
_symmetric
)
self
.
input_fake
=
nn
.
FakeQuantWithMinMax
(
ema
=
True
,
ema_decay
=
_ema_decay
)
features
=
[
ConvBNReLU
(
3
,
input_channel
,
stride
=
2
)]
features
=
[
ConvBNReLU
(
3
,
input_channel
,
stride
=
2
)]
# building inverted residual blocks
# building inverted residual blocks
for
t
,
c
,
n
,
s
in
self
.
cfgs
:
for
t
,
c
,
n
,
s
in
self
.
cfgs
:
...
@@ -189,8 +191,12 @@ class MobileNetV2Quant(nn.Cell):
...
@@ -189,8 +191,12 @@ class MobileNetV2Quant(nn.Cell):
# make it nn.CellList
# make it nn.CellList
self
.
features
=
nn
.
SequentialCell
(
features
)
self
.
features
=
nn
.
SequentialCell
(
features
)
# mobilenet head
# mobilenet head
head
=
([
GlobalAvgPooling
(),
nn
.
Dense
(
self
.
out_channels
,
num_classes
,
has_bias
=
True
)]
if
not
has_dropout
else
head
=
([
GlobalAvgPooling
(),
[
GlobalAvgPooling
(),
nn
.
Dropout
(
0.2
),
nn
.
Dense
(
self
.
out_channels
,
num_classes
,
has_bias
=
True
)])
nn
.
DenseQuant
(
self
.
out_channels
,
num_classes
,
has_bias
=
True
,
per_channel
=
_per_channel
,
symmetric
=
_symmetric
)]
if
not
has_dropout
else
[
GlobalAvgPooling
(),
nn
.
Dropout
(
0.2
),
nn
.
DenseQuant
(
self
.
out_channels
,
num_classes
,
has_bias
=
True
,
per_channel
=
_per_channel
,
symmetric
=
_symmetric
)])
self
.
head
=
nn
.
SequentialCell
(
head
)
self
.
head
=
nn
.
SequentialCell
(
head
)
def
construct
(
self
,
x
):
def
construct
(
self
,
x
):
...
...
example/resnet50_quant/README.md
浏览文件 @
e9ee59c7
...
@@ -51,7 +51,7 @@ Parameters for both training and inference can be set in config.py.
...
@@ -51,7 +51,7 @@ Parameters for both training and inference can be set in config.py.
"loss_scale": 1024, # loss scale
"loss_scale": 1024, # loss scale
"momentum": 0.9, # momentum optimizer
"momentum": 0.9, # momentum optimizer
"weight_decay": 1e-4, # weight decay
"weight_decay": 1e-4, # weight decay
"epoch_size": 1
1
0, # only valid for taining, which is always 1 for inference
"epoch_size": 1
2
0, # only valid for taining, which is always 1 for inference
"pretrained_epoch_size": 90, # epoch size that model has been trained before load pretrained checkpoint
"pretrained_epoch_size": 90, # epoch size that model has been trained before load pretrained checkpoint
"buffer_size": 1000, # number of queue size in data preprocessing
"buffer_size": 1000, # number of queue size in data preprocessing
"image_height": 224, # image height
"image_height": 224, # image height
...
@@ -65,7 +65,7 @@ Parameters for both training and inference can be set in config.py.
...
@@ -65,7 +65,7 @@ Parameters for both training and inference can be set in config.py.
"label_smooth": True, # label smooth
"label_smooth": True, # label smooth
"label_smooth_factor": 0.1, # label smooth factor
"label_smooth_factor": 0.1, # label smooth factor
"lr_init": 0, # initial learning rate
"lr_init": 0, # initial learning rate
"lr_max": 0.
1
, # maximum learning rate
"lr_max": 0.
005
, # maximum learning rate
```
```
## Running the example
## Running the example
...
...
example/resnet50_quant/models/resnet_quant.py
浏览文件 @
e9ee59c7
...
@@ -22,6 +22,7 @@ from mindspore.nn import FakeQuantWithMinMax, Conv2dBatchNormQuant
...
@@ -22,6 +22,7 @@ from mindspore.nn import FakeQuantWithMinMax, Conv2dBatchNormQuant
_ema_decay
=
0.999
_ema_decay
=
0.999
_symmetric
=
False
_symmetric
=
False
_fake
=
True
_fake
=
True
_per_channel
=
True
def
_weight_variable
(
shape
,
factor
=
0.01
):
def
_weight_variable
(
shape
,
factor
=
0.01
):
init_value
=
np
.
random
.
randn
(
*
shape
).
astype
(
np
.
float32
)
*
factor
init_value
=
np
.
random
.
randn
(
*
shape
).
astype
(
np
.
float32
)
*
factor
...
@@ -85,7 +86,7 @@ class ConvBNReLU(nn.Cell):
...
@@ -85,7 +86,7 @@ class ConvBNReLU(nn.Cell):
super
(
ConvBNReLU
,
self
).
__init__
()
super
(
ConvBNReLU
,
self
).
__init__
()
padding
=
(
kernel_size
-
1
)
//
2
padding
=
(
kernel_size
-
1
)
//
2
conv
=
Conv2dBatchNormQuant
(
in_planes
,
out_planes
,
kernel_size
,
stride
,
pad_mode
=
'pad'
,
padding
=
padding
,
conv
=
Conv2dBatchNormQuant
(
in_planes
,
out_planes
,
kernel_size
,
stride
,
pad_mode
=
'pad'
,
padding
=
padding
,
group
=
groups
,
fake
=
_fake
)
group
=
groups
,
fake
=
_fake
,
per_channel
=
_per_channel
,
symmetric
=
_symmetric
)
layers
=
[
conv
,
nn
.
ReLUQuant
()]
if
_fake
else
[
conv
,
nn
.
ReLU
()]
layers
=
[
conv
,
nn
.
ReLUQuant
()]
if
_fake
else
[
conv
,
nn
.
ReLU
()]
self
.
features
=
nn
.
SequentialCell
(
layers
)
self
.
features
=
nn
.
SequentialCell
(
layers
)
...
@@ -119,10 +120,13 @@ class ResidualBlock(nn.Cell):
...
@@ -119,10 +120,13 @@ class ResidualBlock(nn.Cell):
channel
=
out_channel
//
self
.
expansion
channel
=
out_channel
//
self
.
expansion
self
.
conv1
=
ConvBNReLU
(
in_channel
,
channel
,
kernel_size
=
1
,
stride
=
1
)
self
.
conv1
=
ConvBNReLU
(
in_channel
,
channel
,
kernel_size
=
1
,
stride
=
1
)
self
.
conv2
=
ConvBNReLU
(
channel
,
channel
,
kernel_size
=
3
,
stride
=
stride
)
self
.
conv2
=
ConvBNReLU
(
channel
,
channel
,
kernel_size
=
3
,
stride
=
stride
)
self
.
conv3
=
nn
.
SequentialCell
([
Conv2dBatchNormQuant
(
channel
,
out_channel
,
fake
=
_fake
,
self
.
conv3
=
nn
.
SequentialCell
([
Conv2dBatchNormQuant
(
channel
,
out_channel
,
fake
=
_fake
,
per_channel
=
_per_channel
,
symmetric
=
_symmetric
,
kernel_size
=
1
,
stride
=
1
,
pad_mode
=
'same'
,
padding
=
0
),
kernel_size
=
1
,
stride
=
1
,
pad_mode
=
'same'
,
padding
=
0
),
FakeQuantWithMinMax
(
ema
=
True
,
ema_decay
=
_ema_decay
,
symmetric
=
False
)
FakeQuantWithMinMax
(
ema
=
True
,
ema_decay
=
_ema_decay
,
symmetric
=
False
)
])
if
_fake
else
Conv2dBatchNormQuant
(
channel
,
out_channel
,
fake
=
_fake
,
])
if
_fake
else
Conv2dBatchNormQuant
(
channel
,
out_channel
,
fake
=
_fake
,
per_channel
=
_per_channel
,
symmetric
=
_symmetric
,
kernel_size
=
1
,
stride
=
1
,
kernel_size
=
1
,
stride
=
1
,
pad_mode
=
'same'
,
padding
=
0
)
pad_mode
=
'same'
,
padding
=
0
)
...
@@ -134,18 +138,22 @@ class ResidualBlock(nn.Cell):
...
@@ -134,18 +138,22 @@ class ResidualBlock(nn.Cell):
if
self
.
down_sample
:
if
self
.
down_sample
:
self
.
down_sample_layer
=
nn
.
SequentialCell
([
Conv2dBatchNormQuant
(
in_channel
,
out_channel
,
self
.
down_sample_layer
=
nn
.
SequentialCell
([
Conv2dBatchNormQuant
(
in_channel
,
out_channel
,
per_channel
=
_per_channel
,
symmetric
=
_symmetric
,
kernel_size
=
1
,
stride
=
stride
,
kernel_size
=
1
,
stride
=
stride
,
pad_mode
=
'same'
,
padding
=
0
),
pad_mode
=
'same'
,
padding
=
0
),
FakeQuantWithMinMax
(
ema
=
True
,
ema_decay
=
_ema_decay
,
FakeQuantWithMinMax
(
ema
=
True
,
ema_decay
=
_ema_decay
,
symmetric
=
False
)
symmetric
=
False
)
])
if
_fake
else
Conv2dBatchNormQuant
(
in_channel
,
out_channel
,
])
if
_fake
else
Conv2dBatchNormQuant
(
in_channel
,
out_channel
,
fake
=
_fake
,
fake
=
_fake
,
per_channel
=
_per_channel
,
symmetric
=
_symmetric
,
kernel_size
=
1
,
kernel_size
=
1
,
stride
=
stride
,
stride
=
stride
,
pad_mode
=
'same'
,
pad_mode
=
'same'
,
padding
=
0
)
padding
=
0
)
self
.
add
=
P
.
TensorAdd
()
self
.
add
=
P
.
TensorAdd
()
self
.
fake
=
FakeQuantWithMinMax
(
ema
=
True
,
ema_decay
=
_ema_decay
,
symmetric
=
False
)
self
.
relu
=
nn
.
ReLUQuant
()
if
_fake
else
P
.
ReLU
(
)
def
construct
(
self
,
x
):
def
construct
(
self
,
x
):
identity
=
x
identity
=
x
...
@@ -157,9 +165,7 @@ class ResidualBlock(nn.Cell):
...
@@ -157,9 +165,7 @@ class ResidualBlock(nn.Cell):
identity
=
self
.
down_sample_layer
(
identity
)
identity
=
self
.
down_sample_layer
(
identity
)
out
=
self
.
add
(
out
,
identity
)
out
=
self
.
add
(
out
,
identity
)
out
=
P
.
ReLU
()(
out
)
out
=
self
.
relu
(
out
)
if
_fake
:
out
=
self
.
fake
(
out
)
return
out
return
out
...
...
example/resnet50_quant/src/config.py
浏览文件 @
e9ee59c7
...
@@ -23,7 +23,7 @@ config = ed({
...
@@ -23,7 +23,7 @@ config = ed({
"loss_scale"
:
1024
,
"loss_scale"
:
1024
,
"momentum"
:
0.9
,
"momentum"
:
0.9
,
"weight_decay"
:
1e-4
,
"weight_decay"
:
1e-4
,
"epoch_size"
:
1
1
0
,
"epoch_size"
:
1
2
0
,
"pretrained_epoch_size"
:
90
,
"pretrained_epoch_size"
:
90
,
"buffer_size"
:
1000
,
"buffer_size"
:
1000
,
"image_height"
:
224
,
"image_height"
:
224
,
...
@@ -37,6 +37,6 @@ config = ed({
...
@@ -37,6 +37,6 @@ config = ed({
"use_label_smooth"
:
True
,
"use_label_smooth"
:
True
,
"label_smooth_factor"
:
0.1
,
"label_smooth_factor"
:
0.1
,
"lr_init"
:
0
,
"lr_init"
:
0
,
"lr_max"
:
0.
1
"lr_max"
:
0.
005
})
})
mindspore/ops/_op_impl/_custom_op/fake_quant_minmax_perchannel_update.py
浏览文件 @
e9ee59c7
...
@@ -91,11 +91,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
...
@@ -91,11 +91,15 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
min_dtype
=
min_val
.
get
(
"dtype"
)
min_dtype
=
min_val
.
get
(
"dtype"
)
max_shape
=
max_val
.
get
(
"ori_shape"
)
max_shape
=
max_val
.
get
(
"ori_shape"
)
max_dtype
=
max_val
.
get
(
"dtype"
)
max_dtype
=
max_val
.
get
(
"dtype"
)
# for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
if
channel_axis
==
0
and
x_shape
[
0
]
!=
min_shape
[
0
]
and
x_shape
[
1
]
==
min_shape
[
0
]:
channel_axis_
=
1
else
:
channel_axis_
=
channel_axis
util
.
check_kernel_name
(
kernel_name
)
util
.
check_kernel_name
(
kernel_name
)
util
.
check_shape_rule
(
x_shape
)
util
.
check_shape_rule
(
x_shape
)
util
.
check_shape_rule
(
min_shape
,
1
,
1
,
x_shape
[
channel_axis
])
util
.
check_shape_rule
(
min_shape
,
1
,
1
,
x_shape
[
channel_axis
_
])
util
.
check_shape_rule
(
max_shape
,
1
,
1
,
x_shape
[
channel_axis
])
util
.
check_shape_rule
(
max_shape
,
1
,
1
,
x_shape
[
channel_axis
_
])
util
.
check_tensor_shape_size
(
x_shape
)
util
.
check_tensor_shape_size
(
x_shape
)
util
.
check_tensor_shape_size
(
min_shape
)
util
.
check_tensor_shape_size
(
min_shape
)
util
.
check_tensor_shape_size
(
max_shape
)
util
.
check_tensor_shape_size
(
max_shape
)
...
@@ -122,7 +126,7 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
...
@@ -122,7 +126,7 @@ def fake_quant_min_max_per_channel_update(x, min_val, max_val, min_up, max_up,
min_data
=
tvm
.
placeholder
(
shape_c
,
name
=
"min_val"
,
dtype
=
x_dtype
)
min_data
=
tvm
.
placeholder
(
shape_c
,
name
=
"min_val"
,
dtype
=
x_dtype
)
max_data
=
tvm
.
placeholder
(
shape_c
,
name
=
"max_val"
,
dtype
=
x_dtype
)
max_data
=
tvm
.
placeholder
(
shape_c
,
name
=
"max_val"
,
dtype
=
x_dtype
)
res_list
=
fake_quant_min_max_per_channel_update_compute
(
input_data
,
min_data
,
max_data
,
res_list
=
fake_quant_min_max_per_channel_update_compute
(
input_data
,
min_data
,
max_data
,
ema
,
ema_decay
,
quant_min
,
quant_max
,
training
,
channel_axis
,
kernel_name
)
ema
,
ema_decay
,
quant_min
,
quant_max
,
training
,
channel_axis
_
,
kernel_name
)
with
tvm
.
target
.
cce
():
with
tvm
.
target
.
cce
():
sch
=
generic
.
auto_schedule
(
res_list
)
sch
=
generic
.
auto_schedule
(
res_list
)
...
...
mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel.py
浏览文件 @
e9ee59c7
...
@@ -99,11 +99,15 @@ def fake_quant_perchannel(x, min_val, max_val, y,
...
@@ -99,11 +99,15 @@ def fake_quant_perchannel(x, min_val, max_val, y,
min_dtype
=
min_val
.
get
(
"dtype"
)
min_dtype
=
min_val
.
get
(
"dtype"
)
max_shape
=
max_val
.
get
(
"ori_shape"
)
max_shape
=
max_val
.
get
(
"ori_shape"
)
max_dtype
=
max_val
.
get
(
"dtype"
)
max_dtype
=
max_val
.
get
(
"dtype"
)
# for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
if
channel_axis
==
0
and
x_shape_
[
0
]
!=
min_shape
[
0
]
and
x_shape_
[
1
]
==
min_shape
[
0
]:
channel_axis_
=
1
else
:
channel_axis_
=
channel_axis
util
.
check_kernel_name
(
kernel_name
)
util
.
check_kernel_name
(
kernel_name
)
util
.
check_shape_rule
(
x_shape
)
util
.
check_shape_rule
(
x_shape
)
util
.
check_shape_rule
(
min_shape
,
1
,
1
,
x_shape_
[
channel_axis
])
util
.
check_shape_rule
(
min_shape
,
1
,
1
,
x_shape_
[
channel_axis
_
])
util
.
check_shape_rule
(
max_shape
,
1
,
1
,
x_shape_
[
channel_axis
])
util
.
check_shape_rule
(
max_shape
,
1
,
1
,
x_shape_
[
channel_axis
_
])
util
.
check_tensor_shape_size
(
x_shape
)
util
.
check_tensor_shape_size
(
x_shape
)
util
.
check_tensor_shape_size
(
min_shape
)
util
.
check_tensor_shape_size
(
min_shape
)
util
.
check_tensor_shape_size
(
max_shape
)
util
.
check_tensor_shape_size
(
max_shape
)
...
@@ -126,8 +130,8 @@ def fake_quant_perchannel(x, min_val, max_val, y,
...
@@ -126,8 +130,8 @@ def fake_quant_perchannel(x, min_val, max_val, y,
quant_min
=
quant_min
+
1
quant_min
=
quant_min
+
1
shape_c
=
[
1
]
*
len
(
x_shape
)
shape_c
=
[
1
]
*
len
(
x_shape
)
shape_c
[
channel_axis
]
=
min_val
.
get
(
"ori_shape"
)[
0
]
shape_c
[
channel_axis
_
]
=
min_val
.
get
(
"ori_shape"
)[
0
]
if
x_format
==
"NC1HWC0"
and
channel_axis
==
1
:
if
x_format
==
"NC1HWC0"
and
channel_axis
_
==
1
:
shape_c
=
min_val
.
get
(
"shape"
)
shape_c
=
min_val
.
get
(
"shape"
)
input_data
=
tvm
.
placeholder
(
x_shape
,
name
=
"x"
,
dtype
=
x_dtype
)
input_data
=
tvm
.
placeholder
(
x_shape
,
name
=
"x"
,
dtype
=
x_dtype
)
min_data
=
tvm
.
placeholder
(
shape_c
,
name
=
"min_val"
,
dtype
=
x_dtype
)
min_data
=
tvm
.
placeholder
(
shape_c
,
name
=
"min_val"
,
dtype
=
x_dtype
)
...
...
mindspore/ops/_op_impl/_custom_op/fake_quant_perchannel_grad.py
浏览文件 @
e9ee59c7
...
@@ -124,11 +124,15 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
...
@@ -124,11 +124,15 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
min_dtype
=
min_val
.
get
(
"dtype"
)
min_dtype
=
min_val
.
get
(
"dtype"
)
max_shape
=
max_val
.
get
(
"ori_shape"
)
max_shape
=
max_val
.
get
(
"ori_shape"
)
max_dtype
=
max_val
.
get
(
"dtype"
)
max_dtype
=
max_val
.
get
(
"dtype"
)
# for Dense weight quant, 2d[co,ci] -> 4d[1,co,ci,1], channel_axis_ need change to 1.
if
channel_axis
==
0
and
x_shape_
[
0
]
!=
min_shape
[
0
]
and
x_shape_
[
1
]
==
min_shape
[
0
]:
channel_axis_
=
1
else
:
channel_axis_
=
channel_axis
util
.
check_kernel_name
(
kernel_name
)
util
.
check_kernel_name
(
kernel_name
)
util
.
check_shape_rule
(
x_shape
)
util
.
check_shape_rule
(
x_shape
)
util
.
check_shape_rule
(
min_shape
,
1
,
1
,
x_shape_
[
channel_axis
])
util
.
check_shape_rule
(
min_shape
,
1
,
1
,
x_shape_
[
channel_axis
_
])
util
.
check_shape_rule
(
max_shape
,
1
,
1
,
x_shape_
[
channel_axis
])
util
.
check_shape_rule
(
max_shape
,
1
,
1
,
x_shape_
[
channel_axis
_
])
util
.
check_tensor_shape_size
(
x_shape
)
util
.
check_tensor_shape_size
(
x_shape
)
util
.
check_tensor_shape_size
(
min_shape
)
util
.
check_tensor_shape_size
(
min_shape
)
util
.
check_tensor_shape_size
(
max_shape
)
util
.
check_tensor_shape_size
(
max_shape
)
...
@@ -151,8 +155,8 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
...
@@ -151,8 +155,8 @@ def fake_quant_perchannel_grad(dout, x, min_val, max_val, dx,
quant_min
=
quant_min
+
1
quant_min
=
quant_min
+
1
shape_c
=
[
1
]
*
len
(
x_shape
)
shape_c
=
[
1
]
*
len
(
x_shape
)
shape_c
[
channel_axis
]
=
min_val
.
get
(
"ori_shape"
)[
0
]
shape_c
[
channel_axis
_
]
=
min_val
.
get
(
"ori_shape"
)[
0
]
if
x_format
==
"NC1HWC0"
and
channel_axis
==
1
:
if
x_format
==
"NC1HWC0"
and
channel_axis
_
==
1
:
shape_c
=
min_val
.
get
(
"shape"
)
shape_c
=
min_val
.
get
(
"shape"
)
dout_data
=
tvm
.
placeholder
(
x_shape
,
name
=
"dout"
,
dtype
=
x_dtype
)
dout_data
=
tvm
.
placeholder
(
x_shape
,
name
=
"dout"
,
dtype
=
x_dtype
)
input_data
=
tvm
.
placeholder
(
x_shape
,
name
=
"x"
,
dtype
=
x_dtype
)
input_data
=
tvm
.
placeholder
(
x_shape
,
name
=
"x"
,
dtype
=
x_dtype
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录