magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit 1c3e5796
Authored on Jun 02, 2020 by wandongdong

fix bug in quant and correction_mul_grad

Parent: 75f791d8
Showing 4 changed files with 14 additions and 16 deletions (+14 −16):

mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_impl.cu  (+3 −4)
mindspore/nn/layer/quant.py  (+8 −8)
mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py  (+2 −2)
mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max.py  (+1 −2)
mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_impl.cu

```diff
@@ -21,7 +21,7 @@
 #include "fake_quant_impl.cuh"

 __global__ void FakeQuantize(const float *input, float *output, const int size, const float *nudge_min,
-                             const float *nudge_max, const float *scale, bool symmetric) {
+                             const float *nudge_max, const float *scale) {
   float input_x = 0.f;
   int nudge_input = 0;
@@ -35,7 +35,7 @@ __global__ void FakeQuantize(const float *input, float *output, const int size,
       input_x = nudge_max[0];
     }
     // clamp shift
-    nudge_input = floor((input_x - nudge_min[0]) / scale[0] + 0.5f);
+    nudge_input = round((input_x - nudge_min[0]) / scale[0]);
     // quantize
     output[i] = nudge_input * scale[0] + nudge_min[0];
@@ -99,8 +99,7 @@ __global__ void UpdateInputMinMax(float *input_min, float *input_max, const floa
 void CalFakeQuantize(const float *input, float *output, const int size, const float *nudge_min, const float *nudge_max,
                      const float *scale, bool symmetric, cudaStream_t cuda_stream) {
-  FakeQuantize<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, output, size, nudge_min, nudge_max, scale,
-                                                                  symmetric);
+  FakeQuantize<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(input, output, size, nudge_min, nudge_max, scale);
   return;
 }
```
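The substantive change here is the rounding step: floor(v + 0.5f) is replaced by round(). A minimal NumPy sketch (illustrative values chosen to show the failure mode, not taken from the commit) of why the two can disagree even for the non-negative values the kernel sees after clamping:

```python
import numpy as np

# floor(v + 0.5) can misround values just below a half: the float add
# v + 0.5 itself rounds up to 1.0 before floor() ever sees it.
v = np.float32(0.49999997)            # largest float32 below 0.5
print(np.floor(v + np.float32(0.5)))  # 1.0 -- wrong nearest integer
print(np.round(v))                    # 0.0 -- correct

# The half-way semantics also differ: floor(v + 0.5) rounds halves up,
# while CUDA's round() rounds halves away from zero (and NumPy's round
# rounds halves to even). After clamping, v here is always non-negative.
```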
mindspore/nn/layer/quant.py

```diff
@@ -22,7 +22,7 @@ from mindspore.common.parameter import Parameter
 from mindspore.common.initializer import initializer
 from mindspore.common.tensor import Tensor
 from mindspore._checkparam import check_int_positive, check_bool, twice
-from mindspore._checkparam import Validator as validator
+from mindspore._checkparam import Validator as validator, Rel
 from mindspore.nn.cell import Cell
 from mindspore.nn.layer.activation import get_activation
 import mindspore.context as context
@@ -207,7 +207,7 @@ class FakeQuantWithMinMaxD(Cell):
 class FakeQuantWithMinMax(Cell):
     r"""
-    Aware Quantization training op. This OP provide Fake quantization observer function on data with min and max.
+    Aware Quantization op. This OP provide Fake quantization observer function on data with min and max.

     Args:
         min_init (int, list): The dimension of channel or 1(layer). Default: -6.
@@ -243,8 +243,7 @@ class FakeQuantWithMinMax(Cell):
                  out_channels=1,
                  quant_delay=0,
                  symmetric=False,
-                 narrow_range=False,
-                 training=True):
+                 narrow_range=False):
         """init FakeQuantWithMinMax layer"""
         super(FakeQuantWithMinMax, self).__init__()
@@ -258,7 +257,6 @@ class FakeQuantWithMinMax(Cell):
         self.quant_delay = quant_delay
         self.symmetric = symmetric
         self.narrow_range = narrow_range
-        self.training = training
         if per_channel:
             min_array = np.array([self.min_init for i in range(0, self.out_channels)]).astype(np.float32)
@@ -422,11 +420,13 @@ class Conv2dBatchNormQuant(Cell):
         self.per_channel = per_channel
         self.symmetric = symmetric
         self.narrow_range = narrow_range
+        self.channel_axis = int(group > 1)
+        self.is_gpu = context.get_context('device_target') == "GPU"
         # initialize convolution op and Parameter
         if context.get_context('device_target') == "Ascend" and group > 1:
-            validator.check_integer('group', group, 'in_channels', in_channels, 'Conv2dBatchNormQuant')
-            validator.check_integer('group', group, 'in_channels', out_channels, 'Conv2dBatchNormQuant')
+            validator.check_integer('group', group, in_channels, Rel.EQ, 'Conv2dBatchNormQuant')
+            validator.check_integer('group', group, out_channels, Rel.EQ, 'Conv2dBatchNormQuant')
             self.conv = P.DepthwiseConv2dNative(channel_multiplier=1,
                                                 kernel_size=self.kernel_size,
                                                 pad_mode=pad_mode,
@@ -472,7 +472,7 @@ class Conv2dBatchNormQuant(Cell):
                                                  symmetric=symmetric,
                                                  narrow_range=narrow_range)
         self.batchnorm_fold = BatchNormFoldCell(epsilon=eps, momentum=momentum, freeze_bn=freeze_bn)
-        self.correct_mul = P.CorrectionMul()
+        self.correct_mul = P.CorrectionMul(self.channel_axis)
         if context.get_context('device_target') == "Ascend":
             self.batchnorm_fold2_train = P.BatchNormFold2_D(freeze_bn=freeze_bn)
             self.batchnorm_fold2_infer = P.BatchNormFold2_D(freeze_bn=0)
```
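For context on the quant.py change: channel_axis = int(group > 1) selects which weight axis carries the channels (axis 0 for regular Conv2d weights, axis 1 for DepthwiseConv2dNative weights when group > 1), and that axis is now passed to P.CorrectionMul. A hedged NumPy sketch of the per-channel scaling CorrectionMul is understood to perform during batch-norm folding (illustrative, not MindSpore's kernel):

```python
import numpy as np

def correction_mul(weight, batch_std, running_std, channel_axis):
    """Scale weight by batch_std / running_std, broadcast along channel_axis."""
    shape = [1] * weight.ndim
    shape[channel_axis] = -1
    return weight * (batch_std / running_std).reshape(shape)

# Regular Conv2d weight (out_channels, in_channels, kh, kw): channels on axis 0.
# DepthwiseConv2dNative weight (multiplier, in_channels, kh, kw): channels on
# axis 1 -- presumably why the constructor above sets channel_axis = int(group > 1).
w = np.ones((1, 4, 3, 3), dtype=np.float32)                    # depthwise-style
out = correction_mul(w, np.full(4, 2.0, np.float32), np.ones(4, np.float32), 1)
print(out[0, :, 0, 0])   # [2. 2. 2. 2.]
```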
mindspore/ops/_op_impl/_custom_op/correction_mul_grad.py

```diff
@@ -93,8 +93,8 @@ def correction_mul_grad(dout, x, batch_std, running_std, dx, d_batch_std, channe
     util.check_dtype_rule(inp_dtype_dout, ("float16", "float32"))
     util.check_dtype_rule(inp_dtype_x, ("float16", "float32"))
-    util.check_dtype_rule(inp_dtype_batch_std, ("float32",))
-    util.check_dtype_rule(inp_dtype_running_std, ("float32",))
+    util.check_dtype_rule(inp_dtype_batch_std, ("float16", "float32"))
+    util.check_dtype_rule(inp_dtype_running_std, ("float16", "float32"))
     util.compare_tensor_dict_key(dout, x, "dtype")
     util.compare_tensor_dict_key(dout, x, "shape")
     util.compare_tensor_dict_key(dx, x, "shape")
```
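This relaxes the Ascend dtype rule so batch_std and running_std may be float16 as well as float32, matching what dout and x already accept. For orientation, a hedged NumPy sketch of the backward rule implied by the forward y = x * (batch_std / running_std); the function body and reduction here are illustrative assumptions, not the TBE implementation:

```python
import numpy as np

def correction_mul_grad(dout, x, batch_std, running_std, channel_axis):
    shape = [1] * x.ndim
    shape[channel_axis] = -1
    factor = (batch_std / running_std).reshape(shape)
    dx = dout * factor                                # dy/dx = factor
    # Reduce over every axis except the channel axis for d_batch_std.
    reduce_axes = tuple(a for a in range(x.ndim) if a != channel_axis)
    d_batch_std = (dout * x).sum(axis=reduce_axes) / running_std
    return dx, d_batch_std

# float16 inputs are accepted by the relaxed rule; NumPy mirrors that here.
dout = np.ones((2, 3, 4, 4), dtype=np.float16)
dx, d_bs = correction_mul_grad(dout, dout, np.ones(3, np.float16),
                               np.ones(3, np.float16), channel_axis=1)
print(dx.shape, d_bs.shape)   # (2, 3, 4, 4) (3,)
```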
mindspore/ops/_op_impl/_custom_op/fake_quant_with_min_max.py

```diff
@@ -80,8 +80,7 @@ def fake_quant_with_min_max_vars_ema_compute(x, min_val, max_val, y, quant_min,
     # FakeQuant
     input_x = te.lang.cce.vmin(nudge_max, te.lang.cce.vmax(nudge_min, x))
-    nudge_input = te.lang.cce.floor(te.lang.cce.vadds(te.lang.cce.vdiv(te.lang.cce.vsub(input_x, nudge_min), scale),
-                                                      0.5))
+    nudge_input = te.lang.cce.round(te.lang.cce.vdiv(te.lang.cce.vsub(input_x, nudge_min), scale))
     res = te.lang.cce.vadd(te.lang.cce.vmul(nudge_input, scale), nudge_min)
     return res
```
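With this change the Ascend TBE expression computes the same clamp → round → rescale formula as the GPU kernel above. A NumPy rendering of the new expression (illustrative; te.lang.cce.round's exact half-way behavior may differ from NumPy's round-half-to-even):

```python
import numpy as np

def fake_quant_compute(x, nudge_min, nudge_max, scale):
    # vmin(nudge_max, vmax(nudge_min, x)): clamp into the nudged range
    input_x = np.minimum(nudge_max, np.maximum(nudge_min, x))
    # round(vdiv(vsub(input_x, nudge_min), scale)): snap to the grid
    nudge_input = np.round((input_x - nudge_min) / scale)
    # vadd(vmul(nudge_input, scale), nudge_min): map back to real values
    return nudge_input * scale + nudge_min

x = np.array([-1.2, -0.13, 0.4, 2.0], dtype=np.float32)
print(fake_quant_compute(x, np.float32(-1.0), np.float32(1.0), np.float32(0.25)))
# [-1.   -0.25  0.5   1.  ]
```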