Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
wmsofts
mindspore
提交
684ecac9
M
mindspore
项目概览
wmsofts
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
684ecac9
编写于
6月 30, 2020
作者:
C
chenzomi
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
rebase master to r0.5 for quantizaiton aware training
上级
412e4580
变更
26
展开全部
隐藏空白更改
内联
并排
Showing
26 changed file
with
2743 addition
and
185 deletion
+2743
-185
build.sh
build.sh
+1
-1
mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cu
.../ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cu
+138
-0
mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh
...ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh
+34
-0
mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cu
...re/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cu
+111
-0
mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh
...e/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh
+31
-0
mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cu
mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cu
+87
-0
mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cuh
mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cuh
+29
-0
mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.cc
...csrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.cc
+42
-75
mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h
...ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h
+7
-15
mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc
...kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc
+23
-35
mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h
.../kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h
+1
-5
mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.cc
.../ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.cc
+143
-0
mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.h
...e/ccsrc/kernel/gpu/quant/fake_quant_perlayer_gpu_kernel.h
+8
-14
mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc
...c/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.cc
+133
-0
mindspore/ccsrc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h
...rc/kernel/gpu/quant/fake_quant_perlayer_grad_gpu_kernel.h
+7
-10
mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.cc
...c/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.cc
+96
-0
mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h
...rc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h
+55
-0
mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.cc
...src/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.cc
+93
-0
mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h
...csrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h
+54
-0
mindspore/nn/layer/quant.py
mindspore/nn/layer/quant.py
+1
-0
mindspore/ops/operations/_quant_ops.py
mindspore/ops/operations/_quant_ops.py
+24
-15
mindspore/train/quant/quant.py
mindspore/train/quant/quant.py
+20
-15
tests/st/ops/gpu/test_fake_quant_perchannel.py
tests/st/ops/gpu/test_fake_quant_perchannel.py
+625
-0
tests/st/ops/gpu/test_fake_quant_perchannel_grad.py
tests/st/ops/gpu/test_fake_quant_perchannel_grad.py
+373
-0
tests/st/ops/gpu/test_fake_quant_perlayer.py
tests/st/ops/gpu/test_fake_quant_perlayer.py
+386
-0
tests/st/ops/gpu/test_fake_quant_perlayer_grad.py
tests/st/ops/gpu/test_fake_quant_perlayer_grad.py
+221
-0
未找到文件。
build.sh
浏览文件 @
684ecac9
...
...
@@ -252,7 +252,7 @@ checkopts()
done
}
checkopts
"
$@
"
echo
"---------------- mind
s
pore: build start ----------------"
echo
"---------------- mind
S
pore: build start ----------------"
mkdir
-pv
"
${
BUILD_PATH
}
/package/mindspore/lib"
git submodule update
--init
graphengine
if
[[
"X
$ENABLE_AKG
"
=
"Xon"
]]
&&
[[
"X
$ENABLE_D
"
=
"Xon"
]]
;
then
...
...
mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cu
0 → 100644
浏览文件 @
684ecac9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrust/extrema.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include <thrust/pair.h>
#include "fake_quant_perchannel_impl.cuh"
/**
* Find the nudge min, max and scale value as output.
* @param input_min array
* @param input_max array
* @param quant_min 1 << bit -1
* @param quant_max 0
* @param nudge_min array
* @param nudge_max array
* @param scale array
* @param channel_num
* @return
*/
__global__
void
NudgeMinMaxPerChannel
(
float
*
input_min
,
float
*
input_max
,
const
float
quant_min
,
const
float
quant_max
,
float
*
nudge_min
,
float
*
nudge_max
,
float
*
scale
,
int
channel_num
,
const
bool
symmetric
)
{
float
zp_from_min
=
0.
f
;
float
nudge_zp
=
0.
f
;
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
channel_num
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
if
(
symmetric
)
{
input_max
[
i
]
=
abs
(
input_min
[
0
])
<
input_max
[
i
]
?
input_max
[
i
]
:
-
input_min
[
i
];
input_min
[
i
]
=
abs
(
input_min
[
i
])
<
input_max
[
i
]
?
-
input_max
[
i
]
:
input_min
[
i
];
}
if
((
quant_max
-
quant_min
)
==
0
||
(
input_max
[
i
]
-
input_min
[
i
])
==
0
)
{
scale
[
i
]
=
0.
f
;
zp_from_min
=
0.
f
;
}
else
{
scale
[
i
]
=
(
input_max
[
i
]
-
input_min
[
i
])
/
(
quant_max
-
quant_min
);
zp_from_min
=
quant_min
-
input_min
[
i
]
/
scale
[
i
];
}
if
(
zp_from_min
<=
quant_min
)
{
nudge_zp
=
quant_min
;
}
else
if
(
zp_from_min
>=
quant_max
)
{
nudge_zp
=
quant_max
;
}
else
{
nudge_zp
=
round
(
zp_from_min
);
}
nudge_min
[
i
]
=
(
quant_min
-
nudge_zp
)
*
(
scale
[
i
]);
nudge_max
[
i
]
=
(
quant_max
-
nudge_zp
)
*
(
scale
[
i
]);
}
}
void
CalNudgePerChannel
(
float
*
input_min
,
float
*
input_max
,
const
float
quant_min
,
const
float
quant_max
,
float
*
nudge_min
,
float
*
nudge_max
,
float
*
scale
,
const
int
channel_num
,
const
bool
symmetric
,
cudaStream_t
cuda_stream
)
{
NudgeMinMaxPerChannel
<<<
GET_BLOCKS
(
channel_num
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
input_min
,
input_max
,
quant_min
,
quant_max
,
nudge_min
,
nudge_max
,
scale
,
channel_num
,
symmetric
);
}
/**
* Calulate fake quant output accroding by nudge min, nudge max, nudge scale.
* @param input - array
* @param output - array
* @param total_size - int, purpose for cal the per chanel number in filters
* @param channel_size - int, purpose for cal the per channel number in filters
* @param nudge_min - array
* @param nudge_max - array
* @param scale - array
* @return
*/
__global__
void
FakeQuantPerChannel
(
const
float
*
input
,
float
*
output
,
const
int
total_size
,
const
int
channel_size
,
const
float
*
nudge_min
,
const
float
*
nudge_max
,
const
float
*
scale
)
{
float
input_x
=
0.
f
;
int
nudge_input
=
0
;
int
channel_idx
=
0
;
int
per_channel_num
=
total_size
/
channel_size
;
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
total_size
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
input_x
=
input
[
i
];
channel_idx
=
floor
(
static_cast
<
double
>
(
i
)
/
static_cast
<
double
>
(
per_channel_num
));
// clamp input x
if
(
input_x
<
nudge_min
[
channel_idx
])
{
input_x
=
nudge_min
[
channel_idx
];
}
if
(
input_x
>
nudge_max
[
channel_idx
])
{
input_x
=
nudge_max
[
channel_idx
];
}
// clamp shift
nudge_input
=
floor
((
input_x
-
nudge_min
[
channel_idx
])
/
scale
[
channel_idx
]
+
0.5
f
);
// quantize
output
[
i
]
=
nudge_input
*
scale
[
channel_idx
]
+
nudge_min
[
channel_idx
];
}
}
void
CalFakeQuantPerChannel
(
const
float
*
input
,
float
*
output
,
const
int
total_size
,
const
int
channel_size
,
const
float
*
nudge_min
,
const
float
*
nudge_max
,
const
float
*
scale
,
cudaStream_t
cuda_stream
)
{
FakeQuantPerChannel
<<<
GET_BLOCKS
(
total_size
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
input
,
output
,
total_size
,
channel_size
,
nudge_min
,
nudge_max
,
scale
);
}
__global__
void
FakeQuantPerChannelGrad
(
const
float
*
input
,
const
float
*
gradient
,
float
*
output
,
const
int
total_size
,
const
int
channel_size
,
const
float
*
nudge_min
,
const
float
*
nudge_max
)
{
int
channel_idx
=
0
;
int
per_channel_num
=
total_size
/
channel_size
;
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
total_size
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
channel_idx
=
floor
(
static_cast
<
double
>
(
i
)
/
static_cast
<
double
>
(
per_channel_num
));
if
(
input
[
i
]
<
nudge_min
[
channel_idx
]
||
input
[
i
]
>
nudge_max
[
channel_idx
])
{
output
[
i
]
=
0
;
}
else
{
output
[
i
]
=
gradient
[
i
];
}
}
}
void
CalFakeQuantPerChannelGrad
(
const
float
*
input
,
const
float
*
gradient
,
float
*
output
,
const
int
total_num
,
const
int
channel_num
,
const
float
*
nudge_min
,
const
float
*
nudge_max
,
cudaStream_t
cuda_stream
)
{
FakeQuantPerChannelGrad
<<<
GET_BLOCKS
(
channel_num
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
input
,
gradient
,
output
,
total_num
,
channel_num
,
nudge_min
,
nudge_max
);
}
mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh
0 → 100644
浏览文件 @
684ecac9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_
#include "device/gpu/cuda_common.h"
void
CalNudgePerChannel
(
float
*
input_min
,
float
*
input_max
,
const
float
quant_min
,
const
float
quant_max
,
float
*
nudge_min
,
float
*
nudge_max
,
float
*
scale
,
const
int
channel_num
,
const
bool
symmetric
,
cudaStream_t
cuda_stream
);
void
CalFakeQuantPerChannel
(
const
float
*
input
,
float
*
output
,
const
int
total_num
,
const
int
channel_num
,
const
float
*
nudge_min
,
const
float
*
nudge_max
,
const
float
*
scale
,
cudaStream_t
cuda_stream
);
void
CalFakeQuantPerChannelGrad
(
const
float
*
input
,
const
float
*
gradient
,
float
*
output
,
const
int
total_num
,
const
int
channel_num
,
const
float
*
nudge_min
,
const
float
*
nudge_max
,
cudaStream_t
cuda_stream
);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERCHANNEL_H_
mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cu
0 → 100644
浏览文件 @
684ecac9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrust/extrema.h>
#include <thrust/device_vector.h>
#include <thrust/pair.h>
#include "fake_quant_perlayer_impl.cuh"
__global__
void
FakeQuantPerLayer
(
const
float
*
input
,
float
*
output
,
const
int
size
,
const
float
*
nudge_min
,
const
float
*
nudge_max
,
const
float
*
scale
)
{
float
input_x
=
0.
f
;
int
nudge_input
=
0
;
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
size
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
input_x
=
input
[
i
];
// clamp input x
if
(
input_x
<
nudge_min
[
0
])
{
input_x
=
nudge_min
[
0
];
}
if
(
input_x
>
nudge_max
[
0
])
{
input_x
=
nudge_max
[
0
];
}
// clamp shift
nudge_input
=
round
((
input_x
-
nudge_min
[
0
])
/
scale
[
0
]);
// quantize
output
[
i
]
=
nudge_input
*
scale
[
0
]
+
nudge_min
[
0
];
}
return
;
}
__global__
void
FakeQuantPerLayerGrad
(
const
float
*
input
,
const
float
*
gradient
,
float
*
output
,
const
int
size
,
const
float
*
nudge_min
,
const
float
*
nudge_max
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
size
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
if
(
input
[
i
]
<
nudge_min
[
0
]
||
input
[
i
]
>
nudge_max
[
0
])
{
output
[
i
]
=
0
;
}
else
{
output
[
i
]
=
gradient
[
i
];
}
}
return
;
}
__global__
void
NudgeMinMaxPerLayer
(
float
*
input_min
,
float
*
input_max
,
const
float
quant_min
,
const
float
quant_max
,
float
*
nudge_min
,
float
*
nudge_max
,
float
*
scale
,
const
bool
symmetric
)
{
float
zp_from_min
=
0.
f
;
scale
[
0
]
=
0.
f
;
nudge_max
[
0
]
=
0.
f
;
nudge_min
[
0
]
=
0.
f
;
if
(
symmetric
)
{
input_max
[
0
]
=
abs
(
input_min
[
0
])
<
input_max
[
0
]
?
input_max
[
0
]
:
-
input_min
[
0
];
input_min
[
0
]
=
abs
(
input_min
[
0
])
<
input_max
[
0
]
?
-
input_max
[
0
]
:
input_min
[
0
];
}
if
((
quant_max
-
quant_min
)
==
0
||
(
input_max
[
0
]
-
input_min
[
0
])
==
0
)
{
scale
[
0
]
=
0.
f
;
zp_from_min
=
0.
f
;
}
else
{
scale
[
0
]
=
(
input_max
[
0
]
-
input_min
[
0
])
/
(
quant_max
-
quant_min
);
zp_from_min
=
quant_min
-
input_min
[
0
]
/
scale
[
0
];
}
float
nudge_zp
=
0.
f
;
if
(
zp_from_min
<=
quant_min
)
{
nudge_zp
=
quant_min
;
}
else
if
(
zp_from_min
>=
quant_max
)
{
nudge_zp
=
quant_max
;
}
else
{
nudge_zp
=
round
(
zp_from_min
);
}
nudge_min
[
0
]
=
(
quant_min
-
nudge_zp
)
*
(
scale
[
0
]);
nudge_max
[
0
]
=
(
quant_max
-
nudge_zp
)
*
(
scale
[
0
]);
return
;
}
void
CalFakeQuantPerLayer
(
const
float
*
input
,
float
*
output
,
const
int
size
,
const
float
*
nudge_min
,
const
float
*
nudge_max
,
const
float
*
scale
,
cudaStream_t
cuda_stream
)
{
FakeQuantPerLayer
<<<
GET_BLOCKS
(
size
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
input
,
output
,
size
,
nudge_min
,
nudge_max
,
scale
);
return
;
}
void
CalFakeQuantPerLayerGrad
(
const
float
*
input
,
const
float
*
gradient
,
float
*
output
,
const
int
size
,
const
float
*
nudge_min
,
const
float
*
nudge_max
,
cudaStream_t
cuda_stream
)
{
FakeQuantPerLayerGrad
<<<
GET_BLOCKS
(
size
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
input
,
gradient
,
output
,
size
,
nudge_min
,
nudge_max
);
return
;
}
void
CalNudgePerLayer
(
float
*
input_min
,
float
*
input_max
,
const
float
quant_min
,
const
float
quant_max
,
float
*
nudge_min
,
float
*
nudge_max
,
float
*
scale
,
const
bool
symmetric
,
cudaStream_t
cuda_stream
)
{
NudgeMinMaxPerLayer
<<<
1
,
1
,
0
,
cuda_stream
>>>
(
input_min
,
input_max
,
quant_min
,
quant_max
,
nudge_min
,
nudge_max
,
scale
,
symmetric
);
return
;
}
mindspore/ccsrc/kernel/gpu/cuda_impl/fake_quant_perlayer_impl.cuh
0 → 100644
浏览文件 @
684ecac9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_
#include "device/gpu/cuda_common.h"
void
CalNudgePerLayer
(
float
*
input_min
,
float
*
input_max
,
const
float
quant_min
,
const
float
quant_max
,
float
*
nudge_min
,
float
*
nudge_max
,
float
*
scale
,
const
bool
symmetric
,
cudaStream_t
cuda_stream
);
void
CalFakeQuantPerLayer
(
const
float
*
input
,
float
*
output
,
const
int
size
,
const
float
*
nudge_min
,
const
float
*
nudge_max
,
const
float
*
scale
,
cudaStream_t
cuda_stream
);
void
CalFakeQuantPerLayerGrad
(
const
float
*
input
,
const
float
*
gradient
,
float
*
output
,
const
int
size
,
const
float
*
nudge_min
,
const
float
*
nudge_max
,
cudaStream_t
cuda_stream
);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_FAKE_QUANT_PERLAYER_H_
mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cu
0 → 100644
浏览文件 @
684ecac9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <thrust/extrema.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
#include <thrust/reduce.h>
#include <thrust/pair.h>
#include "minmax_update_impl.cuh"
#include "device/gpu/cuda_common.h"
__global__
void
UpdateInputMinMaxPerLayerWithEMA
(
const
float
*
input_min
,
const
float
*
input_max
,
float
*
output_min
,
float
*
output_max
,
const
float
min
,
const
float
max
,
const
float
decay
)
{
output_min
[
0
]
=
decay
*
(
min
)
+
(
1
-
decay
)
*
(
input_min
[
0
]);
output_min
[
0
]
=
input_min
[
0
]
>
0
?
0
:
input_min
[
0
];
output_max
[
0
]
=
decay
*
(
max
)
+
(
1
-
decay
)
*
(
input_max
[
0
]);
output_max
[
0
]
=
input_max
[
0
]
<
0
?
0
:
input_max
[
0
];
return
;
}
__global__
void
UpdateInputMinMaxPerLayer
(
float
*
output_min
,
float
*
output_max
,
const
float
min
,
const
float
max
)
{
output_min
[
0
]
=
min
>
0
?
0
:
min
;
output_max
[
0
]
=
max
<
0
?
0
:
max
;
return
;
}
__global__
void
UpdateInputMinMaxPerChannel
(
float
*
input
,
float
*
input_min
,
float
*
input_max
,
float
*
output_min
,
float
*
output_max
,
int
channels
,
int
per_channel_nums
,
bool
ema
,
float
ema_decay
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
channels
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
thrust
::
pair
<
float
*
,
float
*>
sum
=
thrust
::
minmax_element
(
thrust
::
device
,
input
+
i
*
per_channel_nums
,
input
+
per_channel_nums
*
(
i
+
1
));
if
(
ema
)
{
output_min
[
i
]
=
ema_decay
*
sum
.
first
[
0
]
+
(
1
-
ema_decay
)
*
input_min
[
i
];
output_max
[
i
]
=
ema_decay
*
sum
.
second
[
0
]
+
(
1
-
ema_decay
)
*
input_max
[
i
];
}
else
{
output_min
[
i
]
=
sum
.
first
[
0
];
output_max
[
i
]
=
sum
.
second
[
0
];
}
output_min
[
i
]
=
input_min
[
i
]
>
0
?
0
:
input_min
[
i
];
output_max
[
i
]
=
input_max
[
i
]
<
0
?
0
:
input_max
[
i
];
}
return
;
}
void
CalMinMaxPerChannel
(
float
*
input
,
float
*
input_min
,
float
*
input_max
,
float
*
output_min
,
float
*
output_max
,
const
int
total_num
,
const
int
channel_num
,
const
float
ema_decay
,
const
bool
ema
,
cudaStream_t
cuda_stream
)
{
int
per_channel_num
=
total_num
/
channel_num
;
UpdateInputMinMaxPerChannel
<<<
GET_BLOCKS
(
channel_num
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
input
,
input_min
,
input_max
,
output_min
,
output_max
,
channel_num
,
per_channel_num
,
ema
,
ema_decay
);
return
;
}
void
CalMinMaxPerLayer
(
float
*
input
,
float
*
input_min
,
float
*
input_max
,
float
*
output_min
,
float
*
output_max
,
const
int
total_num
,
const
float
ema_decay
,
const
bool
ema
,
cudaStream_t
cuda_stream
)
{
float
minel
=
0.
f
;
float
maxel
=
0.
f
;
auto
policy
=
thrust
::
cuda
::
par
.
on
(
cuda_stream
);
thrust
::
pair
<
thrust
::
device_ptr
<
float
>
,
thrust
::
device_ptr
<
float
>>
tuple
;
tuple
=
thrust
::
minmax_element
(
policy
,
thrust
::
device_pointer_cast
(
input
),
thrust
::
device_pointer_cast
(
input
)
+
total_num
);
minel
=
tuple
.
first
[
0
];
maxel
=
tuple
.
second
[
0
];
if
(
ema
)
{
UpdateInputMinMaxPerLayerWithEMA
<<<
1
,
1
,
0
,
cuda_stream
>>>
(
input_min
,
input_max
,
output_min
,
output_max
,
minel
,
maxel
,
ema_decay
);
}
else
{
UpdateInputMinMaxPerLayer
<<<
1
,
1
,
0
,
cuda_stream
>>>
(
output_min
,
output_max
,
minel
,
maxel
);
}
return
;
}
mindspore/ccsrc/kernel/gpu/cuda_impl/minmax_update_impl.cuh
0 → 100644
浏览文件 @
684ecac9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_
#include "device/gpu/cuda_common.h"
void
CalMinMaxPerChannel
(
float
*
input
,
float
*
input_min
,
float
*
input_max
,
float
*
output_min
,
float
*
output_max
,
const
int
total_num
,
const
int
channel_num
,
const
float
ema_decay
,
const
bool
ema
,
cudaStream_t
cuda_stream
);
void
CalMinMaxPerLayer
(
float
*
input
,
float
*
input_min
,
float
*
input_max
,
float
*
output_min
,
float
*
output_max
,
const
int
size
,
const
float
ema_decay
,
const
bool
ema
,
cudaStream_t
cuda_stream
);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_MIN_MAX_UPDATE_IMPL_H_
mindspore/ccsrc/kernel/gpu/quant/fake_quant_per
_
channel_gpu_kernel.cc
→
mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.cc
浏览文件 @
684ecac9
...
...
@@ -14,8 +14,8 @@
* limitations under the License.
*/
#include "kernel/gpu/quant/fake_quant_per
_
channel_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_per
_
channel_impl.cuh"
#include "kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh"
#include <thrust/extrema.h>
#include <thrust/pair.h>
#include <thrust/device_vector.h>
...
...
@@ -25,21 +25,15 @@ namespace mindspore {
namespace
kernel
{
FakeQuantPerChannelGpuKernel
::
FakeQuantPerChannelGpuKernel
()
:
input_size_
(
0
),
min_size_
(
0
),
max_size_
(
0
),
output_size_
(
0
),
workspace_size_
(
0
),
num_channels_
(
0
),
num_bits_
(
0
),
quant_min_
(
0
),
quant_max_
(
0
),
quant_delay_
(
0
),
ema_
(
false
),
ema_decay_
(
0
),
global_step_
(
0
),
training_
(
false
),
channel_out_
(
0
),
symmetric_
(
false
),
narrow_range_
(
false
),
symmetric_
(
false
)
{}
quant_delay_
(
0
),
quant_min_
(
0
),
quant_max_
(
0
),
global_step_
(
0
)
{}
const
std
::
vector
<
size_t
>
&
FakeQuantPerChannelGpuKernel
::
GetInputSizeList
()
const
{
return
input_size_list_
;
}
...
...
@@ -60,91 +54,57 @@ bool FakeQuantPerChannelGpuKernel::Init(const CNodePtr &kernel_node) {
return
false
;
}
// get attribute
num_bits_
=
GetValue
<
int
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"num_bits"
));
ema_
=
GetValue
<
bool
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"ema"
));
ema_decay_
=
1.0
-
GetValue
<
float
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"ema_decay"
));
training_
=
GetValue
<
bool
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"training"
));
symmetric_
=
GetValue
<
bool
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"symmetric"
));
narrow_range_
=
GetValue
<
bool
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"narrow_range"
));
quant_delay_
=
GetValue
<
int
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"quant_delay"
));
if
(
num_bits_
<=
2
||
num_bits_
>=
16
)
{
MS_LOG
(
EXCEPTION
)
<<
"Attr
\'
num_bits
\'
"
<<
num_bits_
<<
"is out of range, expected between 2 and 16."
;
return
false
;
}
quant_delay_
=
GetValue
<
int
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"quant_delay"
));
if
(
quant_delay_
<
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Attr
\'
quant_delay
\'
"
<<
num_bits_
<<
" is less then 0, require larger than 0."
;
return
false
;
}
training_
=
GetValue
<
bool
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"training"
));
symmetric_
=
GetValue
<
bool
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"symmetric"
));
if
(
symmetric_
)
{
quant_min_
=
0
-
(
1
<<
(
num_bits_
-
1
));
quant_max_
=
(
1
<<
(
num_bits_
-
1
))
-
1
;
}
else
{
quant_min_
=
0
;
quant_max_
=
(
1
<<
num_bits_
)
-
1
;
}
narrow_range_
=
GetValue
<
bool
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"narrow_range"
));
// quant min and max value
quant_min_
=
0
;
quant_max_
=
(
1
<<
num_bits_
)
-
1
;
if
(
narrow_range_
)
{
quant_min_
++
;
}
// shape info for gpu
auto
input_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
channel_out_
=
SizeToInt
(
input_shape
[
0
]);
min_size_
=
sizeof
(
float
)
*
channel_out_
;
max_size_
=
sizeof
(
float
)
*
channel_out_
;
num_channels_
=
SizeToInt
(
input_shape
[
0
]);
input_size_
=
sizeof
(
float
);
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
i
++
)
{
input_size_
*=
input_shape
[
i
];
}
output_size_
=
input_size_
;
InitSizeLists
();
return
true
;
}
void
FakeQuantPerChannelGpuKernel
::
InitSizeLists
()
{
input_size_list_
.
push_back
(
input_size_
);
// input in tensor
input_size_list_
.
push_back
(
min_size_
);
// min one scalar
input_size_list_
.
push_back
(
max_size_
);
// max on scalar
output_size_list_
.
push_back
(
output_size_
);
// output in tensor
workspace_size_list_
.
push_back
(
sizeof
(
float
)
*
channel_out_
);
// scale in channel
workspace_size_list_
.
push_back
(
sizeof
(
float
)
*
channel_out_
);
// min in channel
workspace_size_list_
.
push_back
(
sizeof
(
float
)
*
channel_out_
);
// max in channel
}
void
FakeQuantPerChannelGpuKernel
::
CalFakeQuantizeForTraining
(
float
*
input
,
float
*
output
,
float
*
input_min
,
float
*
input_max
,
float
*
d_nudge_min
,
float
*
d_nudge_max
,
float
*
d_scale
,
void
*
stream_ptr
)
{
// calculate the input min and max according by the parameter ema and ema_decay.
CalMinMaxPerChannel
(
input
,
input_min
,
input_max
,
input_size_
/
sizeof
(
float
),
channel_out_
,
ema_decay_
,
ema_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
// control flow for quant_delay
if
(
global_step_
>=
quant_delay_
)
{
// real launch
CalNudgePerChannel
(
input_min
,
input_max
,
quant_min_
,
quant_max_
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
channel_out_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalFakeQuantizePerChannel
(
input
,
output
,
input_size_
/
sizeof
(
float
),
channel_out_
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
symmetric_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
}
else
{
CHECK_CUDA_RET_WITH_ERROR
(
cudaMemcpyAsync
(
output
,
input
,
input_size_
,
cudaMemcpyDeviceToDevice
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
)),
"Copy gpu memory failed."
);
}
global_step_
++
;
input_size_list_
.
push_back
(
input_size_
);
// input in tensor
input_size_list_
.
push_back
(
sizeof
(
float
)
*
num_channels_
);
// min one scalar
input_size_list_
.
push_back
(
sizeof
(
float
)
*
num_channels_
);
// max on scalar
output_size_list_
.
push_back
(
input_size_
);
// output in tensor
workspace_size_list_
.
push_back
(
sizeof
(
float
)
*
num_channels_
);
// scale in channel
workspace_size_list_
.
push_back
(
sizeof
(
float
)
*
num_channels_
);
// min in channel
workspace_size_list_
.
push_back
(
sizeof
(
float
)
*
num_channels_
);
// max in channel
}
void
FakeQuantPerChannelGpuKernel
::
CalFakeQuantizeForInfer
(
float
*
input
,
float
*
output
,
float
*
input_min
,
float
*
input_max
,
float
*
d_nudge_min
,
float
*
d_nudge_max
,
float
*
d_scale
,
void
*
stream_ptr
)
{
// real launch
CalNudgePerChannel
(
input_min
,
input_max
,
quant_min_
,
quant_max_
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
channel_out_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalFakeQuantizePerChannel
(
input
,
output
,
input_size_
/
sizeof
(
float
),
channel_out_
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
symmetric_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
void
FakeQuantPerChannelGpuKernel
::
CalFakeQuantize
(
float
*
input
,
float
*
output
,
float
*
input_min
,
float
*
input_max
,
float
*
nudge_min
,
float
*
nudge_max
,
float
*
scale
,
void
*
stream_ptr
)
{
CalNudgePerChannel
(
input_min
,
input_max
,
quant_min_
,
quant_max_
,
nudge_min
,
nudge_max
,
scale
,
num_channels_
,
symmetric_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalFakeQuantPerChannel
(
input
,
output
,
input_size_
/
sizeof
(
float
),
num_channels_
,
nudge_min
,
nudge_max
,
scale
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
}
bool
FakeQuantPerChannelGpuKernel
::
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
...
...
@@ -155,9 +115,9 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
float
*
input
=
GetDeviceAddress
<
float
>
(
inputs
,
0
);
float
*
input_min
=
GetDeviceAddress
<
float
>
(
inputs
,
1
);
float
*
input_max
=
GetDeviceAddress
<
float
>
(
inputs
,
2
);
float
*
d_
scale
=
GetDeviceAddress
<
float
>
(
workspace
,
0
);
float
*
d_
nudge_min
=
GetDeviceAddress
<
float
>
(
workspace
,
1
);
float
*
d_
nudge_max
=
GetDeviceAddress
<
float
>
(
workspace
,
2
);
float
*
scale
=
GetDeviceAddress
<
float
>
(
workspace
,
0
);
float
*
nudge_min
=
GetDeviceAddress
<
float
>
(
workspace
,
1
);
float
*
nudge_max
=
GetDeviceAddress
<
float
>
(
workspace
,
2
);
if
(
input
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"FakeQuantPerChannelGpuKernel input is null."
;
...
...
@@ -167,9 +127,16 @@ bool FakeQuantPerChannelGpuKernel::Launch(const std::vector<AddressPtr> &inputs,
}
if
(
training_
)
{
CalFakeQuantizeForTraining
(
input
,
output
,
input_min
,
input_max
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
stream_ptr
);
if
(
global_step_
>=
quant_delay_
)
{
CalFakeQuantize
(
input
,
output
,
input_min
,
input_max
,
nudge_min
,
nudge_max
,
scale
,
stream_ptr
);
}
else
{
CHECK_CUDA_RET_WITH_ERROR
(
cudaMemcpyAsync
(
output
,
input
,
input_size_
,
cudaMemcpyDeviceToDevice
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
)),
"Copy gpu memory failed."
);
}
global_step_
++
;
}
else
{
CalFakeQuantize
ForInfer
(
input
,
output
,
input_min
,
input_max
,
d_nudge_min
,
d_nudge_max
,
d_
scale
,
stream_ptr
);
CalFakeQuantize
(
input
,
output
,
input_min
,
input_max
,
nudge_min
,
nudge_max
,
scale
,
stream_ptr
);
}
return
true
;
...
...
mindspore/ccsrc/kernel/gpu/quant/fake_quant_per
_
channel_gpu_kernel.h
→
mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_gpu_kernel.h
浏览文件 @
684ecac9
...
...
@@ -39,31 +39,23 @@ class FakeQuantPerChannelGpuKernel : public GpuKernel {
void
InitSizeLists
()
override
;
private:
void
CalFakeQuantizeForTraining
(
float
*
input
,
float
*
output
,
float
*
input_min
,
float
*
input_max
,
float
*
d_nudge_min
,
float
*
d_nudge_max
,
float
*
d_scale
,
void
*
stream_ptr
);
void
CalFakeQuantizeForInfer
(
float
*
input
,
float
*
output
,
float
*
input_min
,
float
*
input_max
,
float
*
d_nudge_min
,
float
*
d_nudge_max
,
float
*
d_scale
,
void
*
stream_ptr
);
void
CalFakeQuantize
(
float
*
input
,
float
*
output
,
float
*
input_min
,
float
*
input_max
,
float
*
nudge_min
,
float
*
nudge_max
,
float
*
scale
,
void
*
stream_ptr
);
size_t
input_size_
;
size_t
min_size_
;
size_t
max_size_
;
size_t
output_size_
;
size_t
workspace_size_
;
std
::
vector
<
size_t
>
input_size_list_
;
std
::
vector
<
size_t
>
output_size_list_
;
std
::
vector
<
size_t
>
workspace_size_list_
;
int
num_channels_
;
int
num_bits_
;
bool
training_
;
bool
symmetric_
;
bool
narrow_range_
;
int
quant_delay_
;
float
quant_min_
;
float
quant_max_
;
int
quant_delay_
;
bool
ema_
;
float
ema_decay_
;
int
global_step_
;
bool
training_
;
int
channel_out_
;
bool
narrow_range_
;
bool
symmetric_
;
};
}
// namespace kernel
}
// namespace mindspore
...
...
mindspore/ccsrc/kernel/gpu/quant/fake_quant_per
_
channel_grad_gpu_kernel.cc
→
mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.cc
浏览文件 @
684ecac9
...
...
@@ -14,21 +14,17 @@
* limitations under the License.
*/
#include "kernel/gpu/quant/fake_quant_per
_
channel_grad_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_per
_
channel_impl.cuh"
#include "kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_perchannel_impl.cuh"
namespace
mindspore
{
namespace
kernel
{
FakeQuantPerChannelGradGpuKernel
::
FakeQuantPerChannelGradGpuKernel
()
:
input_size_
(
0
),
min_size_
(
0
),
max_size_
(
0
),
output_size_
(
0
),
workspace_size_
(
0
),
num_bits_
(
0
),
quant_min_
(
0
),
quant_max_
(
0
),
channel_out
_
(
0
),
num_channels
_
(
0
),
quant_delay_
(
0
),
global_step_
(
0
),
narrow_range_
(
false
),
...
...
@@ -64,42 +60,34 @@ bool FakeQuantPerChannelGradGpuKernel::Init(const CNodePtr &kernel_node) {
}
symmetric_
=
GetValue
<
bool
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"symmetric"
));
if
(
symmetric_
)
{
quant_min_
=
0
-
(
1
<<
(
num_bits_
-
1
));
quant_max_
=
(
1
<<
(
num_bits_
-
1
))
-
1
;
}
else
{
quant_min_
=
0
;
quant_max_
=
(
1
<<
num_bits_
)
-
1
;
}
narrow_range_
=
GetValue
<
bool
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"narrow_range"
));
// quant min and max value
quant_min_
=
0
;
quant_max_
=
(
1
<<
num_bits_
)
-
1
;
if
(
narrow_range_
)
{
quant_min_
++
;
}
auto
input_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
channel_out_
=
SizeToInt
(
input_shape
[
0
]);
min_size_
=
sizeof
(
float
)
*
channel_out_
;
max_size_
=
sizeof
(
float
)
*
channel_out_
;
num_channels_
=
SizeToInt
(
input_shape
[
0
]);
input_size_
=
sizeof
(
float
);
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
i
++
)
{
input_size_
*=
input_shape
[
i
];
}
output_size_
=
input_size_
;
InitSizeLists
();
return
true
;
}
void
FakeQuantPerChannelGradGpuKernel
::
InitSizeLists
()
{
input_size_list_
.
push_back
(
input_size_
);
// gradient
input_size_list_
.
push_back
(
input_size_
);
// input
input_size_list_
.
push_back
(
min_size_
);
// min
input_size_list_
.
push_back
(
max_size_
);
// max
output_size_list_
.
push_back
(
output_size_
);
workspace_size_list_
.
push_back
(
sizeof
(
float
)
*
channel_out
_
);
// scale in channel
workspace_size_list_
.
push_back
(
sizeof
(
float
)
*
channel_out
_
);
// min in channel
workspace_size_list_
.
push_back
(
sizeof
(
float
)
*
channel_out
_
);
// max in channel
input_size_list_
.
push_back
(
input_size_
);
// gradient
input_size_list_
.
push_back
(
input_size_
);
// input
input_size_list_
.
push_back
(
sizeof
(
float
)
*
num_channels_
);
// min
input_size_list_
.
push_back
(
sizeof
(
float
)
*
num_channels_
);
// max
output_size_list_
.
push_back
(
input_size_
);
// output
workspace_size_list_
.
push_back
(
sizeof
(
float
)
*
num_channels
_
);
// scale in channel
workspace_size_list_
.
push_back
(
sizeof
(
float
)
*
num_channels
_
);
// min in channel
workspace_size_list_
.
push_back
(
sizeof
(
float
)
*
num_channels
_
);
// max in channel
}
bool
FakeQuantPerChannelGradGpuKernel
::
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
...
...
@@ -111,9 +99,9 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
float
*
input
=
GetDeviceAddress
<
float
>
(
inputs
,
1
);
float
*
input_min
=
GetDeviceAddress
<
float
>
(
inputs
,
2
);
float
*
input_max
=
GetDeviceAddress
<
float
>
(
inputs
,
3
);
float
*
d_
scale
=
GetDeviceAddress
<
float
>
(
workspace
,
0
);
float
*
d_
nudge_min
=
GetDeviceAddress
<
float
>
(
workspace
,
1
);
float
*
d_
nudge_max
=
GetDeviceAddress
<
float
>
(
workspace
,
2
);
float
*
scale
=
GetDeviceAddress
<
float
>
(
workspace
,
0
);
float
*
nudge_min
=
GetDeviceAddress
<
float
>
(
workspace
,
1
);
float
*
nudge_max
=
GetDeviceAddress
<
float
>
(
workspace
,
2
);
if
(
gradient
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"FakeQuantPerChannelGradGpuKernel gradient is null"
;
...
...
@@ -130,10 +118,10 @@ bool FakeQuantPerChannelGradGpuKernel::Launch(const std::vector<AddressPtr> &inp
int
total_size
=
input_size_
/
sizeof
(
float
);
if
(
global_step_
>=
quant_delay_
)
{
CalNudgePerChannel
(
input_min
,
input_max
,
quant_min_
,
quant_max_
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
channel_out
_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalFakeQuant
izePerChannelGrad
(
input
,
gradient
,
output
,
total_size
,
channel_out_
,
d_nudge_min
,
d_
nudge_max
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalNudgePerChannel
(
input_min
,
input_max
,
quant_min_
,
quant_max_
,
nudge_min
,
nudge_max
,
scale
,
num_channels
_
,
symmetric_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalFakeQuant
PerChannelGrad
(
input
,
gradient
,
output
,
total_size
,
num_channels_
,
nudge_min
,
nudge_max
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
}
else
{
CHECK_CUDA_RET_WITH_ERROR
(
cudaMemcpyAsync
(
output
,
gradient
,
input_size_
,
cudaMemcpyDeviceToDevice
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
)),
...
...
mindspore/ccsrc/kernel/gpu/quant/fake_quant_per
_
channel_grad_gpu_kernel.h
→
mindspore/ccsrc/kernel/gpu/quant/fake_quant_perchannel_grad_gpu_kernel.h
浏览文件 @
684ecac9
...
...
@@ -40,10 +40,6 @@ class FakeQuantPerChannelGradGpuKernel : public GpuKernel {
private:
size_t
input_size_
;
size_t
min_size_
;
size_t
max_size_
;
size_t
output_size_
;
size_t
workspace_size_
;
std
::
vector
<
size_t
>
input_size_list_
;
std
::
vector
<
size_t
>
output_size_list_
;
std
::
vector
<
size_t
>
workspace_size_list_
;
...
...
@@ -51,7 +47,7 @@ class FakeQuantPerChannelGradGpuKernel : public GpuKernel {
int
num_bits_
;
float
quant_min_
;
float
quant_max_
;
int
channel_out
_
;
int
num_channels
_
;
int
quant_delay_
;
int
global_step_
;
bool
narrow_range_
;
...
...
mindspore/ccsrc/kernel/gpu/quant/fake_quant_gpu_kernel.cc
→
mindspore/ccsrc/kernel/gpu/quant/fake_quant_
perlayer_
gpu_kernel.cc
浏览文件 @
684ecac9
...
...
@@ -14,8 +14,8 @@
* limitations under the License.
*/
#include "kernel/gpu/quant/fake_quant_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_impl.cuh"
#include "kernel/gpu/quant/fake_quant_
perlayer_
gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_
perlayer_
impl.cuh"
#include <thrust/extrema.h>
#include <thrust/pair.h>
#include <thrust/device_vector.h>
...
...
@@ -23,31 +23,25 @@
namespace
mindspore
{
namespace
kernel
{
FakeQuant
GpuKernel
::
FakeQuant
GpuKernel
()
FakeQuant
PerLayerGpuKernel
::
FakeQuantPerLayer
GpuKernel
()
:
input_size_
(
0
),
min_size_
(
0
),
max_size_
(
0
),
output_size_
(
0
),
workspace_size_
(
0
),
num_bits_
(
0
),
quant_min_
(
0
),
quant_max_
(
0
),
quant_num_
(
0
),
quant_delay_
(
0
),
ema_
(
false
),
ema_decay_
(
0
),
quant_num_
(
1
),
global_step_
(
0
),
num_bits_
(
0
),
quant_delay_
(
0
),
training_
(
false
),
narrow_range_
(
false
),
symmetric_
(
false
)
{}
const
std
::
vector
<
size_t
>
&
FakeQuantGpuKernel
::
GetInputSizeList
()
const
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
FakeQuant
PerLayer
GpuKernel
::
GetInputSizeList
()
const
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
FakeQuantGpuKernel
::
GetOutputSizeList
()
const
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
FakeQuant
PerLayer
GpuKernel
::
GetOutputSizeList
()
const
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
FakeQuantGpuKernel
::
GetWorkspaceSizeList
()
const
{
return
workspace_size_list_
;
}
const
std
::
vector
<
size_t
>
&
FakeQuant
PerLayer
GpuKernel
::
GetWorkspaceSizeList
()
const
{
return
workspace_size_list_
;
}
bool
FakeQuantGpuKernel
::
Init
(
const
CNodePtr
&
kernel_node
)
{
bool
FakeQuant
PerLayer
GpuKernel
::
Init
(
const
CNodePtr
&
kernel_node
)
{
size_t
input_num
=
AnfAlgo
::
GetInputTensorNum
(
kernel_node
);
if
(
input_num
!=
3
)
{
MS_LOG
(
EXCEPTION
)
<<
"Input number is "
<<
input_num
<<
", but FakeQuant GpuKernel OP needs 3 output."
;
...
...
@@ -59,96 +53,74 @@ bool FakeQuantGpuKernel::Init(const CNodePtr &kernel_node) {
}
num_bits_
=
GetValue
<
int
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"num_bits"
));
ema_
=
GetValue
<
bool
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"ema"
));
ema_decay_
=
GetValue
<
float
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"ema_decay"
));
quant_delay_
=
GetValue
<
int
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"quant_delay"
));
training_
=
GetValue
<
bool
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"training"
));
symmetric_
=
GetValue
<
bool
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"symmetric"
));
narrow_range_
=
GetValue
<
bool
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"narrow_range"
));
if
(
num_bits_
<=
2
||
num_bits_
>=
16
)
{
MS_LOG
(
EXCEPTION
)
<<
"Attr
\'
num_bits
\'
"
<<
num_bits_
<<
" is out of range, expected between 2 and 16."
;
}
quant_delay_
=
GetValue
<
int
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"quant_delay"
));
if
(
quant_delay_
<
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Attr
\'
quant_delay
\'
"
<<
num_bits_
<<
"is less then 0, require larger than 0."
;
}
symmetric_
=
GetValue
<
bool
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"symmetric"
));
if
(
symmetric_
)
{
quant_min_
=
0
-
(
1
<<
(
num_bits_
-
1
));
quant_max_
=
(
1
<<
(
num_bits_
-
1
))
-
1
;
}
else
{
quant_min_
=
0
;
quant_max_
=
(
1
<<
num_bits_
)
-
1
;
}
narrow_range_
=
GetValue
<
bool
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"narrow_range"
));
// quant min and max value
quant_min_
=
0
;
quant_max_
=
(
1
<<
num_bits_
)
-
1
;
if
(
narrow_range_
)
{
quant_min_
++
;
}
if
(
quant_num_
==
0
)
{
quant_num_
=
1
;
}
// init size
auto
input_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
++
i
)
{
quant_num_
*=
SizeToInt
(
input_shape
[
i
]);
}
input_size_
=
sizeof
(
float
);
min_size_
=
sizeof
(
float
);
max_size_
=
sizeof
(
float
);
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
i
++
)
{
input_size_
*=
input_shape
[
i
];
}
output_size_
=
input_size_
;
InitSizeLists
();
return
true
;
}
void
FakeQuantGpuKernel
::
InitSizeLists
()
{
input_size_list_
.
push_back
(
input_size_
);
// input
input_size_list_
.
push_back
(
min_size_
);
// min
input_size_list_
.
push_back
(
max_size_
);
// max
output_size_list_
.
push_back
(
output_size_
);
workspace_size_list_
.
push_back
(
workspace_size_
);
void
FakeQuantPerLayerGpuKernel
::
InitSizeLists
()
{
input_size_list_
.
push_back
(
input_size_
);
// x
input_size_list_
.
push_back
(
sizeof
(
float
));
// min
input_size_list_
.
push_back
(
sizeof
(
float
));
// max
output_size_list_
.
push_back
(
input_size_
);
// y
workspace_size_list_
.
push_back
(
sizeof
(
float
));
// scale
workspace_size_list_
.
push_back
(
sizeof
(
float
));
// nudge_min
workspace_size_list_
.
push_back
(
sizeof
(
float
));
// nudge_max
}
bool
FakeQuantGpuKernel
::
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
workspace
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
void
*
stream_ptr
)
{
bool
FakeQuant
PerLayer
GpuKernel
::
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
workspace
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
void
*
stream_ptr
)
{
float
*
output
=
GetDeviceAddress
<
float
>
(
outputs
,
0
);
float
*
input
=
GetDeviceAddress
<
float
>
(
inputs
,
0
);
float
*
input_min
=
GetDeviceAddress
<
float
>
(
inputs
,
1
);
float
*
input_max
=
GetDeviceAddress
<
float
>
(
inputs
,
2
);
float
*
scale
=
GetDeviceAddress
<
float
>
(
workspace
,
0
);
float
*
nudge_min
=
GetDeviceAddress
<
float
>
(
workspace
,
1
);
float
*
nudge_max
=
GetDeviceAddress
<
float
>
(
workspace
,
2
);
if
(
input
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"FakeQuantGpuKernel input x is null."
;
}
if
(
input_min
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"FakeQuantGpuKernel input min is null."
;
MS_LOG
(
EXCEPTION
)
<<
"FakeQuantPerLayerGpuKernel input x is null."
;
}
if
(
input_max
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"FakeQuant
GpuKernel
input max is null."
;
if
(
input_m
in
==
nullptr
||
input_m
ax
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"FakeQuant
PerLayerGpuKernel input min or
input max is null."
;
}
// Allocate space for device copies
int
size
=
sizeof
(
float
);
float
*
d_scale
=
nullptr
;
float
*
d_nudge_min
=
nullptr
;
float
*
d_nudge_max
=
nullptr
;
CHECK_CUDA_RET_WITH_ERROR
(
cudaMalloc
(
reinterpret_cast
<
void
**>
(
&
d_scale
),
size
),
"Malloc gpu memory failed"
);
CHECK_CUDA_RET_WITH_ERROR
(
cudaMalloc
(
reinterpret_cast
<
void
**>
(
&
d_nudge_min
),
size
),
"Malloc gpu memory failed"
);
CHECK_CUDA_RET_WITH_ERROR
(
cudaMalloc
(
reinterpret_cast
<
void
**>
(
&
d_nudge_max
),
size
),
"Malloc gpu memory failed"
);
if
(
training_
)
{
// calculate the input min and max according by the parameter ema and ema_decay.
CalMinMax
(
input
,
input_min
,
input_max
,
quant_num_
,
ema_decay_
,
ema_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
// control flow for quant_delay
if
(
global_step_
>=
quant_delay_
)
{
// real launch
CalNudge
(
input_min
,
input_max
,
quant_min_
,
quant_max_
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalFakeQuant
ize
(
input
,
output
,
quant_num_
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
symmetric_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalNudge
PerLayer
(
input_min
,
input_max
,
quant_min_
,
quant_max_
,
nudge_min
,
nudge_max
,
scale
,
symmetric_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalFakeQuant
PerLayer
(
input
,
output
,
quant_num_
,
nudge_min
,
nudge_max
,
scale
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
}
else
{
CHECK_CUDA_RET_WITH_ERROR
(
cudaMemcpyAsync
(
output
,
input
,
input_size_
,
cudaMemcpyDeviceToDevice
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
)),
...
...
@@ -157,20 +129,15 @@ bool FakeQuantGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const std
global_step_
++
;
}
else
{
// real launch
CalNudge
(
input_min
,
input_max
,
quant_min_
,
quant_max_
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalFakeQuant
ize
(
input
,
output
,
quant_num_
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
symmetric_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalNudge
PerLayer
(
input_min
,
input_max
,
quant_min_
,
quant_max_
,
nudge_min
,
nudge_max
,
scale
,
symmetric_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalFakeQuant
PerLayer
(
input
,
output
,
quant_num_
,
nudge_min
,
nudge_max
,
scale
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
}
// Cleanup
CHECK_CUDA_RET_WITH_ERROR
(
cudaFree
(
d_scale
),
"Free gpu memory failed"
);
CHECK_CUDA_RET_WITH_ERROR
(
cudaFree
(
d_nudge_min
),
"Free gpu memory failed"
);
CHECK_CUDA_RET_WITH_ERROR
(
cudaFree
(
d_nudge_max
),
"Free gpu memory failed"
);
return
true
;
}
MS_REG_GPU_KERNEL
(
FakeQuantPerLayer
,
FakeQuantGpuKernel
)
MS_REG_GPU_KERNEL
(
FakeQuantPerLayer
,
FakeQuant
PerLayer
GpuKernel
)
}
// namespace kernel
}
// namespace mindspore
mindspore/ccsrc/kernel/gpu/quant/fake_quant_gpu_kernel.h
→
mindspore/ccsrc/kernel/gpu/quant/fake_quant_
perlayer_
gpu_kernel.h
浏览文件 @
684ecac9
...
...
@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_
PERLAYER_
GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_
PERLAYER_
GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
...
...
@@ -23,10 +23,10 @@
namespace
mindspore
{
namespace
kernel
{
class
FakeQuantGpuKernel
:
public
GpuKernel
{
class
FakeQuant
PerLayer
GpuKernel
:
public
GpuKernel
{
public:
FakeQuantGpuKernel
();
~
FakeQuantGpuKernel
()
=
default
;
FakeQuant
PerLayer
GpuKernel
();
~
FakeQuant
PerLayer
GpuKernel
()
=
default
;
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
override
;
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
override
;
...
...
@@ -40,22 +40,16 @@ class FakeQuantGpuKernel : public GpuKernel {
private:
size_t
input_size_
;
size_t
min_size_
;
size_t
max_size_
;
size_t
output_size_
;
size_t
workspace_size_
;
std
::
vector
<
size_t
>
input_size_list_
;
std
::
vector
<
size_t
>
output_size_list_
;
std
::
vector
<
size_t
>
workspace_size_list_
;
int
num_bits_
;
float
quant_min_
;
float
quant_max_
;
int
quant_num_
;
int
quant_delay_
;
bool
ema_
;
float
ema_decay_
;
int
global_step_
;
int
num_bits_
;
int
quant_delay_
;
bool
training_
;
bool
narrow_range_
;
bool
symmetric_
;
...
...
@@ -63,4 +57,4 @@ class FakeQuantGpuKernel : public GpuKernel {
}
// namespace kernel
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GPUKERNEL_H_
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_
PERLAYER_
GPUKERNEL_H_
mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.cc
→
mindspore/ccsrc/kernel/gpu/quant/fake_quant_
perlayer_
grad_gpu_kernel.cc
浏览文件 @
684ecac9
...
...
@@ -14,33 +14,30 @@
* limitations under the License.
*/
#include "kernel/gpu/quant/fake_quant_grad_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_impl.cuh"
#include "kernel/gpu/quant/fake_quant_
perlayer_
grad_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/fake_quant_
perlayer_
impl.cuh"
namespace
mindspore
{
namespace
kernel
{
FakeQuant
GradGpuKernel
::
FakeQuant
GradGpuKernel
()
FakeQuant
PerLayerGradGpuKernel
::
FakeQuantPerLayer
GradGpuKernel
()
:
input_size_
(
0
),
min_size_
(
0
),
max_size_
(
0
),
output_size_
(
0
),
workspace_size_
(
0
),
num_bits_
(
0
),
quant_min_
(
0
),
quant_max_
(
0
),
quant_
size_
(
0
),
quant_
num_
(
1
),
quant_delay_
(
0
),
global_step_
(
0
),
narrow_range_
(
false
),
symmetric_
(
false
)
{}
const
std
::
vector
<
size_t
>
&
FakeQuantGradGpuKernel
::
GetInputSizeList
()
const
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
FakeQuant
PerLayer
GradGpuKernel
::
GetInputSizeList
()
const
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
FakeQuantGradGpuKernel
::
GetOutputSizeList
()
const
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
FakeQuant
PerLayer
GradGpuKernel
::
GetOutputSizeList
()
const
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
FakeQuantGradGpuKernel
::
GetWorkspaceSizeList
()
const
{
return
workspace_size_list_
;
}
const
std
::
vector
<
size_t
>
&
FakeQuant
PerLayer
GradGpuKernel
::
GetWorkspaceSizeList
()
const
{
return
workspace_size_list_
;
}
bool
FakeQuantGradGpuKernel
::
Init
(
const
CNodePtr
&
kernel_node
)
{
bool
FakeQuant
PerLayer
GradGpuKernel
::
Init
(
const
CNodePtr
&
kernel_node
)
{
size_t
input_num
=
AnfAlgo
::
GetInputTensorNum
(
kernel_node
);
if
(
input_num
!=
4
)
{
MS_LOG
(
EXCEPTION
)
<<
"Input number is "
<<
input_num
<<
", but FakeQuantGrad GpuKernel OP needs 4 output."
;
...
...
@@ -62,87 +59,66 @@ bool FakeQuantGradGpuKernel::Init(const CNodePtr &kernel_node) {
}
symmetric_
=
GetValue
<
bool
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"symmetric"
));
if
(
symmetric_
)
{
quant_min_
=
0
-
(
1
<<
(
num_bits_
-
1
));
quant_max_
=
(
1
<<
(
num_bits_
-
1
))
-
1
;
}
else
{
quant_min_
=
0
;
quant_max_
=
(
1
<<
num_bits_
)
-
1
;
}
narrow_range_
=
GetValue
<
bool
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"narrow_range"
));
// quant min and max value
quant_min_
=
0
;
quant_max_
=
(
1
<<
num_bits_
)
-
1
;
if
(
narrow_range_
)
{
quant_min_
++
;
}
if
(
quant_size_
==
0
)
{
quant_size_
=
1
;
}
// init size
auto
input_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
++
i
)
{
quant_
size
_
*=
SizeToInt
(
input_shape
[
i
]);
quant_
num
_
*=
SizeToInt
(
input_shape
[
i
]);
}
input_size_
=
sizeof
(
float
);
min_size_
=
sizeof
(
float
);
max_size_
=
sizeof
(
float
);
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
i
++
)
{
input_size_
*=
input_shape
[
i
];
}
output_size_
=
input_size_
;
InitSizeLists
();
return
true
;
}
void
FakeQuantGradGpuKernel
::
InitSizeLists
()
{
input_size_list_
.
push_back
(
input_size_
);
// gradient
input_size_list_
.
push_back
(
input_size_
);
// input
input_size_list_
.
push_back
(
min_size_
);
// min
input_size_list_
.
push_back
(
max_size_
);
// max
output_size_list_
.
push_back
(
output_size_
);
void
FakeQuantPerLayerGradGpuKernel
::
InitSizeLists
()
{
input_size_list_
.
push_back
(
input_size_
);
// gradient
input_size_list_
.
push_back
(
input_size_
);
// input
input_size_list_
.
push_back
(
sizeof
(
float
));
// min
input_size_list_
.
push_back
(
sizeof
(
float
));
// max
output_size_list_
.
push_back
(
input_size_
);
// output
workspace_size_list_
.
push_back
(
sizeof
(
float
));
// scale
workspace_size_list_
.
push_back
(
sizeof
(
float
));
// nudge_min
workspace_size_list_
.
push_back
(
sizeof
(
float
));
// nudge_max
}
bool
FakeQuantGradGpuKernel
::
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
void
*
stream_ptr
)
{
bool
FakeQuantPerLayerGradGpuKernel
::
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
workspace
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
void
*
stream_ptr
)
{
float
*
output
=
GetDeviceAddress
<
float
>
(
outputs
,
0
);
float
*
gradient
=
GetDeviceAddress
<
float
>
(
inputs
,
0
);
float
*
input
=
GetDeviceAddress
<
float
>
(
inputs
,
1
);
float
*
input_min
=
GetDeviceAddress
<
float
>
(
inputs
,
2
);
float
*
input_max
=
GetDeviceAddress
<
float
>
(
inputs
,
3
);
float
*
scale
=
GetDeviceAddress
<
float
>
(
workspace
,
0
);
float
*
nudge_min
=
GetDeviceAddress
<
float
>
(
workspace
,
1
);
float
*
nudge_max
=
GetDeviceAddress
<
float
>
(
workspace
,
2
);
if
(
gradient
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"FakeQuantGradGpuKernel gradient is null"
;
MS_LOG
(
EXCEPTION
)
<<
"FakeQuant
PerLayer
GradGpuKernel gradient is null"
;
}
if
(
input
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"FakeQuantGradGpuKernel input is null."
;
}
if
(
input_min
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"FakeQuantGradGpuKernel input min is null."
;
MS_LOG
(
EXCEPTION
)
<<
"FakeQuantPerLayerGradGpuKernel input is null."
;
}
if
(
input_max
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"FakeQuant
GradGpuKernel input
max is null."
;
if
(
input_m
in
==
nullptr
||
input_m
ax
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"FakeQuant
PerLayerGradGpuKernel input min or
max is null."
;
}
if
(
global_step_
>=
quant_delay_
)
{
float
*
d_scale
=
nullptr
;
float
*
d_nudge_min
=
nullptr
;
float
*
d_nudge_max
=
nullptr
;
int
size
=
sizeof
(
float
);
// Allocate space for device copies
CHECK_CUDA_RET_WITH_ERROR
(
cudaMalloc
(
reinterpret_cast
<
void
**>
(
&
d_scale
),
size
),
"Malloc gpu memory failed"
);
CHECK_CUDA_RET_WITH_ERROR
(
cudaMalloc
(
reinterpret_cast
<
void
**>
(
&
d_nudge_min
),
size
),
"Malloc gpu memory failed"
);
CHECK_CUDA_RET_WITH_ERROR
(
cudaMalloc
(
reinterpret_cast
<
void
**>
(
&
d_nudge_max
),
size
),
"Malloc gpu memory failed"
);
CalNudge
(
input_min
,
input_max
,
quant_min_
,
quant_max_
,
d_nudge_min
,
d_nudge_max
,
d_scale
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalFakeQuantizeGrad
(
input
,
gradient
,
output
,
quant_size_
,
d_nudge_min
,
d_nudge_max
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
// Cleanup
CHECK_CUDA_RET_WITH_ERROR
(
cudaFree
(
d_scale
),
"Free gpu memory failed"
);
CHECK_CUDA_RET_WITH_ERROR
(
cudaFree
(
d_nudge_min
),
"Free gpu memory failed"
);
CHECK_CUDA_RET_WITH_ERROR
(
cudaFree
(
d_nudge_max
),
"Free gpu memory failed"
);
CalNudgePerLayer
(
input_min
,
input_max
,
quant_min_
,
quant_max_
,
nudge_min
,
nudge_max
,
scale
,
symmetric_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
CalFakeQuantPerLayerGrad
(
input
,
gradient
,
output
,
quant_num_
,
nudge_min
,
nudge_max
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
}
else
{
CHECK_CUDA_RET_WITH_ERROR
(
cudaMemcpyAsync
(
output
,
gradient
,
input_size_
,
cudaMemcpyDeviceToDevice
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
)),
...
...
@@ -152,6 +128,6 @@ bool FakeQuantGradGpuKernel::Launch(const std::vector<AddressPtr> &inputs, const
return
true
;
}
MS_REG_GPU_KERNEL
(
FakeQuantPerLayerGrad
,
FakeQuantGradGpuKernel
)
MS_REG_GPU_KERNEL
(
FakeQuantPerLayerGrad
,
FakeQuant
PerLayer
GradGpuKernel
)
}
// namespace kernel
}
// namespace mindspore
mindspore/ccsrc/kernel/gpu/quant/fake_quant_grad_gpu_kernel.h
→
mindspore/ccsrc/kernel/gpu/quant/fake_quant_
perlayer_
grad_gpu_kernel.h
浏览文件 @
684ecac9
...
...
@@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_
PERLAYER_
GRAD_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_
PERLAYER_
GRAD_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
...
...
@@ -23,10 +23,10 @@
namespace
mindspore
{
namespace
kernel
{
class
FakeQuantGradGpuKernel
:
public
GpuKernel
{
class
FakeQuant
PerLayer
GradGpuKernel
:
public
GpuKernel
{
public:
FakeQuantGradGpuKernel
();
~
FakeQuantGradGpuKernel
()
=
default
;
FakeQuant
PerLayer
GradGpuKernel
();
~
FakeQuant
PerLayer
GradGpuKernel
()
=
default
;
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
override
;
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
override
;
...
...
@@ -40,9 +40,6 @@ class FakeQuantGradGpuKernel : public GpuKernel {
private:
size_t
input_size_
;
size_t
min_size_
;
size_t
max_size_
;
size_t
output_size_
;
size_t
workspace_size_
;
std
::
vector
<
size_t
>
input_size_list_
;
std
::
vector
<
size_t
>
output_size_list_
;
...
...
@@ -51,7 +48,7 @@ class FakeQuantGradGpuKernel : public GpuKernel {
int
num_bits_
;
float
quant_min_
;
float
quant_max_
;
int
quant_
size
_
;
int
quant_
num
_
;
int
quant_delay_
;
int
global_step_
;
bool
narrow_range_
;
...
...
@@ -60,4 +57,4 @@ class FakeQuantGradGpuKernel : public GpuKernel {
}
// namespace kernel
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_GRAD_GPUKERNEL_H_
#endif // MINDSPORE_CCSRC_KERNEL_GPU_FAKEQUANT_
PERLAYER_
GRAD_GPUKERNEL_H_
mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.cc
0 → 100644
浏览文件 @
684ecac9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/minmax_update_impl.cuh"
#include <thrust/extrema.h>
#include <thrust/pair.h>
#include <thrust/device_vector.h>
#include <cuda_runtime_api.h>
namespace
mindspore
{
namespace
kernel
{
MinMaxUpdatePerChannelGpuKernel
::
MinMaxUpdatePerChannelGpuKernel
()
:
input_size_
(
0
),
quant_num_
(
1
),
ema_
(
false
),
ema_decay_
(
0
),
num_channels_
(
0
)
{}
const
std
::
vector
<
size_t
>
&
MinMaxUpdatePerChannelGpuKernel
::
GetInputSizeList
()
const
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
MinMaxUpdatePerChannelGpuKernel
::
GetOutputSizeList
()
const
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
MinMaxUpdatePerChannelGpuKernel
::
GetWorkspaceSizeList
()
const
{
return
workspace_size_list_
;
}
bool
MinMaxUpdatePerChannelGpuKernel
::
Init
(
const
CNodePtr
&
kernel_node
)
{
size_t
input_num
=
AnfAlgo
::
GetInputTensorNum
(
kernel_node
);
if
(
input_num
!=
3
)
{
MS_LOG
(
EXCEPTION
)
<<
"Input number is "
<<
input_num
<<
", but FakeQuant GpuKernel OP needs 3 output."
;
}
size_t
output_num
=
AnfAlgo
::
GetOutputTensorNum
(
kernel_node
);
if
(
output_num
!=
2
)
{
MS_LOG
(
EXCEPTION
)
<<
"Output number is "
<<
output_num
<<
", but FakeQuant GpuKernel OP needs 1 output."
;
}
ema_
=
GetValue
<
bool
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"ema"
));
ema_decay_
=
GetValue
<
float
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"ema_decay"
));
// init size
auto
input_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
num_channels_
=
SizeToInt
(
input_shape
[
0
]);
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
++
i
)
{
quant_num_
*=
SizeToInt
(
input_shape
[
i
]);
}
input_size_
=
sizeof
(
float
);
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
i
++
)
{
input_size_
*=
input_shape
[
i
];
}
InitSizeLists
();
return
true
;
}
void
MinMaxUpdatePerChannelGpuKernel
::
InitSizeLists
()
{
input_size_list_
.
push_back
(
input_size_
);
// input
input_size_list_
.
push_back
(
sizeof
(
float
)
*
num_channels_
);
// min
input_size_list_
.
push_back
(
sizeof
(
float
)
*
num_channels_
);
// max
output_size_list_
.
push_back
(
sizeof
(
float
)
*
num_channels_
);
// output min
output_size_list_
.
push_back
(
sizeof
(
float
)
*
num_channels_
);
// output max
}
bool
MinMaxUpdatePerChannelGpuKernel
::
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
void
*
stream_ptr
)
{
float
*
output_min
=
GetDeviceAddress
<
float
>
(
outputs
,
0
);
float
*
output_max
=
GetDeviceAddress
<
float
>
(
outputs
,
1
);
float
*
input
=
GetDeviceAddress
<
float
>
(
inputs
,
0
);
float
*
input_min
=
GetDeviceAddress
<
float
>
(
inputs
,
1
);
float
*
input_max
=
GetDeviceAddress
<
float
>
(
inputs
,
2
);
if
(
input
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"MinMaxUpdatePerChannelGpuKernel input x is null."
;
}
if
(
input_min
==
nullptr
||
input_max
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"MinMaxUpdatePerChannelGpuKernel input min or input max is null."
;
}
// calculate the input min and max according by the parameter ema and ema_decay.
CalMinMaxPerChannel
(
input
,
input_min
,
input_max
,
output_min
,
output_max
,
input_size_
/
sizeof
(
float
),
num_channels_
,
ema_decay_
,
ema_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
return
true
;
}
MS_REG_GPU_KERNEL
(
MinMaxUpdatePerChannel
,
MinMaxUpdatePerChannelGpuKernel
)
}
// namespace kernel
}
// namespace mindspore
mindspore/ccsrc/kernel/gpu/quant/minmax_update_perchannel_gpu_kernel.h
0 → 100644
浏览文件 @
684ecac9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
namespace
mindspore
{
namespace
kernel
{
class
MinMaxUpdatePerChannelGpuKernel
:
public
GpuKernel
{
public:
MinMaxUpdatePerChannelGpuKernel
();
~
MinMaxUpdatePerChannelGpuKernel
()
=
default
;
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
override
;
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
override
;
const
std
::
vector
<
size_t
>
&
GetWorkspaceSizeList
()
const
override
;
bool
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
workspace
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
void
*
stream_ptr
)
override
;
bool
Init
(
const
CNodePtr
&
kernel
)
override
;
protected:
void
InitSizeLists
()
override
;
private:
size_t
input_size_
;
std
::
vector
<
size_t
>
input_size_list_
;
std
::
vector
<
size_t
>
output_size_list_
;
std
::
vector
<
size_t
>
workspace_size_list_
;
int
quant_num_
;
bool
ema_
;
float
ema_decay_
;
int
num_channels_
;
};
}
// namespace kernel
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERCHANNEL_GPUKERNEL_H_
mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.cc
0 → 100644
浏览文件 @
684ecac9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h"
#include "kernel/gpu/cuda_impl/minmax_update_impl.cuh"
#include <thrust/extrema.h>
#include <thrust/pair.h>
#include <thrust/device_vector.h>
#include <cuda_runtime_api.h>
namespace
mindspore
{
namespace
kernel
{
MinMaxUpdatePerLayerGpuKernel
::
MinMaxUpdatePerLayerGpuKernel
()
:
input_size_
(
0
),
quant_num_
(
1
),
ema_
(
false
),
ema_decay_
(
0
)
{}
const
std
::
vector
<
size_t
>
&
MinMaxUpdatePerLayerGpuKernel
::
GetInputSizeList
()
const
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
MinMaxUpdatePerLayerGpuKernel
::
GetOutputSizeList
()
const
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
MinMaxUpdatePerLayerGpuKernel
::
GetWorkspaceSizeList
()
const
{
return
workspace_size_list_
;
}
bool
MinMaxUpdatePerLayerGpuKernel
::
Init
(
const
CNodePtr
&
kernel_node
)
{
size_t
input_num
=
AnfAlgo
::
GetInputTensorNum
(
kernel_node
);
if
(
input_num
!=
3
)
{
MS_LOG
(
EXCEPTION
)
<<
"Input number is "
<<
input_num
<<
", but FakeQuant GpuKernel OP needs 3 output."
;
}
size_t
output_num
=
AnfAlgo
::
GetOutputTensorNum
(
kernel_node
);
if
(
output_num
!=
2
)
{
MS_LOG
(
EXCEPTION
)
<<
"Output number is "
<<
output_num
<<
", but FakeQuant GpuKernel OP needs 1 output."
;
}
ema_
=
GetValue
<
bool
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"ema"
));
ema_decay_
=
GetValue
<
float
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"ema_decay"
));
// init size
auto
input_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
++
i
)
{
quant_num_
*=
SizeToInt
(
input_shape
[
i
]);
}
input_size_
=
sizeof
(
float
);
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
i
++
)
{
input_size_
*=
input_shape
[
i
];
}
InitSizeLists
();
return
true
;
}
void
MinMaxUpdatePerLayerGpuKernel
::
InitSizeLists
()
{
input_size_list_
.
push_back
(
input_size_
);
// input
input_size_list_
.
push_back
(
sizeof
(
float
));
// input min
input_size_list_
.
push_back
(
sizeof
(
float
));
// input max
output_size_list_
.
push_back
(
sizeof
(
float
));
// output min
output_size_list_
.
push_back
(
sizeof
(
float
));
// output max
}
bool
MinMaxUpdatePerLayerGpuKernel
::
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
void
*
stream_ptr
)
{
float
*
output_min
=
GetDeviceAddress
<
float
>
(
outputs
,
0
);
float
*
output_max
=
GetDeviceAddress
<
float
>
(
outputs
,
1
);
float
*
input
=
GetDeviceAddress
<
float
>
(
inputs
,
0
);
float
*
input_min
=
GetDeviceAddress
<
float
>
(
inputs
,
1
);
float
*
input_max
=
GetDeviceAddress
<
float
>
(
inputs
,
2
);
if
(
input
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"MinMaxUpdatePerLayerGpuKernel input x is null."
;
}
if
(
input_min
==
nullptr
||
input_max
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"MinMaxUpdatePerLayerGpuKernel input min or input max is null."
;
}
CalMinMaxPerLayer
(
input
,
input_min
,
input_max
,
output_min
,
output_max
,
quant_num_
,
ema_decay_
,
ema_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
return
true
;
}
MS_REG_GPU_KERNEL
(
MinMaxUpdatePerLayer
,
MinMaxUpdatePerLayerGpuKernel
)
}
// namespace kernel
}
// namespace mindspore
mindspore/ccsrc/kernel/gpu/quant/minmax_update_perlayer_gpu_kernel.h
0 → 100644
浏览文件 @
684ecac9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
namespace
mindspore
{
namespace
kernel
{
class
MinMaxUpdatePerLayerGpuKernel
:
public
GpuKernel
{
public:
MinMaxUpdatePerLayerGpuKernel
();
~
MinMaxUpdatePerLayerGpuKernel
()
=
default
;
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
override
;
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
override
;
const
std
::
vector
<
size_t
>
&
GetWorkspaceSizeList
()
const
override
;
bool
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
workspace
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
void
*
stream_ptr
)
override
;
bool
Init
(
const
CNodePtr
&
kernel
)
override
;
protected:
void
InitSizeLists
()
override
;
private:
size_t
input_size_
;
std
::
vector
<
size_t
>
input_size_list_
;
std
::
vector
<
size_t
>
output_size_list_
;
std
::
vector
<
size_t
>
workspace_size_list_
;
int
quant_num_
;
bool
ema_
;
float
ema_decay_
;
};
}
// namespace kernel
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_MINMAX_UPDATE_PERLAYER_GPUKERNEL_H_
mindspore/nn/layer/quant.py
浏览文件 @
684ecac9
...
...
@@ -324,6 +324,7 @@ class FakeQuantWithMinMax(Cell):
validator
.
check_type
(
"min_init"
,
min_init
,
[
int
,
float
])
validator
.
check_type
(
"max_init"
,
max_init
,
[
int
,
float
])
validator
.
check
(
"min_init"
,
min_init
,
"max_init"
,
max_init
,
rel
=
Rel
.
LT
)
validator
.
check_integer
(
'quant_delay'
,
quant_delay
,
0
,
Rel
.
GE
)
self
.
min_init
=
min_init
self
.
max_init
=
max_init
self
.
num_bits
=
num_bits
...
...
mindspore/ops/operations/_quant_ops.py
浏览文件 @
684ecac9
...
...
@@ -106,7 +106,7 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
Args:
ema (bool): Use EMA algorithm update value min and max. Default: False.
ema_decay (int) : EMA algorithm decay parameter. Default: 0.999.
channel_axis (int): Quantization by channel axis
, support 0 and
1. Default: 1.
channel_axis (int): Quantization by channel axis
. Ascend backend only supports 0 or
1. Default: 1.
Inputs:
- **x** (Tensor) : float32 Tensor representing the shape of the output tensor.
...
...
@@ -123,12 +123,13 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
>>> output_tensor = MinMaxUpdatePerChannel(num_bits=8)(x, min, max)
"""
support_quant_bit
=
[
4
,
7
,
8
]
support_x_rank
=
[
2
,
4
]
ascend_
support_x_rank
=
[
2
,
4
]
@
prim_attr_register
def
__init__
(
self
,
ema
=
False
,
ema_decay
=
0.999
,
channel_axis
=
1
):
"""init FakeQuantPerChannelUpdate OP for Ascend"""
if
context
.
get_context
(
'device_target'
)
==
"Ascend"
:
self
.
is_ascend
=
context
.
get_context
(
'device_target'
)
==
"Ascend"
if
self
.
is_ascend
:
from
mindspore.ops._op_impl._custom_op
import
minmax_update_perchannel
if
ema
and
not
ema_decay
:
raise
ValueError
(
...
...
@@ -137,15 +138,18 @@ class MinMaxUpdatePerChannel(PrimitiveWithInfer):
self
.
ema
=
validator
.
check_value_type
(
'ema'
,
ema
,
(
bool
,),
self
.
name
)
self
.
ema_decay
=
validator
.
check_number_range
(
'ema_decay'
,
ema_decay
,
0
,
1
,
Rel
.
INC_BOTH
,
self
.
name
)
self
.
channel_axis
=
validator
.
check_int_range
(
'channel_axis'
,
channel_axis
,
0
,
1
,
Rel
.
INC_BOTH
,
self
.
name
)
if
self
.
is_ascend
:
self
.
channel_axis
=
validator
.
check_int_range
(
'channel_axis'
,
channel_axis
,
0
,
1
,
Rel
.
INC_BOTH
,
self
.
name
)
else
:
self
.
channel_axis
=
validator
.
check_integer
(
'channel_axis'
,
channel_axis
,
0
,
Rel
.
GE
,
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'min'
,
'max'
],
outputs
=
[
'min_up'
,
'max_up'
])
def
infer_shape
(
self
,
x_shape
,
min_shape
,
max_shape
):
if
len
(
x_shape
)
not
in
self
.
support_x_rank
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' x rank should be in '
{
self
.
support_x_rank
}
'"
)
validator
.
check_integer
(
"x rank"
,
len
(
x_shape
),
1
,
Rel
.
GT
,
self
.
name
)
if
self
.
is_ascend
and
len
(
x_shape
)
not
in
self
.
ascend_support_x_rank
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' x rank should be in '
{
self
.
ascend_support_x_rank
}
'"
)
if
not
self
.
is_ascend
:
validator
.
check_integer
(
"x rank"
,
len
(
x_shape
),
1
,
Rel
.
GE
,
self
.
name
)
validator
.
check
(
"min shape"
,
min_shape
,
"max shape"
,
max_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"min shape"
,
len
(
...
...
@@ -317,7 +321,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
symmetric (bool): Quantization algorithm use symmetric or not. Default: False.
narrow_range (bool): Quantization algorithm use narrow range or not. Default: False.
training (bool): Training the network or not. Default: True.
channel_axis (int): Quantization by channel axis
, support 0 and
1. Default: 1.
channel_axis (int): Quantization by channel axis
. Ascend backend only supports 0 or
1. Default: 1.
Inputs:
- **x** (Tensor) : 4-D float32 Tensor representing the shape of the output tensor.
...
...
@@ -335,7 +339,7 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
>>> result = fake_quant(input_x, _min, _max)
"""
support_quant_bit
=
[
4
,
7
,
8
]
support_x_rank
=
[
2
,
4
]
ascend_
support_x_rank
=
[
2
,
4
]
@
prim_attr_register
def
__init__
(
self
,
...
...
@@ -348,7 +352,8 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
training
=
True
,
channel_axis
=
1
):
"""init FakeQuantPerChannel OP"""
if
context
.
get_context
(
'device_target'
)
==
"Ascend"
:
self
.
is_ascend
=
context
.
get_context
(
'device_target'
)
==
"Ascend"
if
self
.
is_ascend
:
from
mindspore.ops._op_impl._custom_op
import
fake_quant_perchannel
if
num_bits
not
in
self
.
support_quant_bit
:
raise
ValueError
(
...
...
@@ -370,13 +375,17 @@ class FakeQuantPerChannel(PrimitiveWithInfer):
'num_bits'
,
num_bits
,
0
,
Rel
.
GT
,
self
.
name
)
self
.
quant_delay
=
validator
.
check_integer
(
'quant_delay'
,
quant_delay
,
0
,
Rel
.
GE
,
self
.
name
)
self
.
channel_axis
=
validator
.
check_int_range
(
'channel_axis'
,
channel_axis
,
0
,
1
,
Rel
.
INC_BOTH
,
self
.
name
)
if
self
.
is_ascend
:
self
.
channel_axis
=
validator
.
check_int_range
(
'channel_axis'
,
channel_axis
,
0
,
1
,
Rel
.
INC_BOTH
,
self
.
name
)
else
:
self
.
channel_axis
=
validator
.
check_integer
(
'channel_axis'
,
channel_axis
,
0
,
Rel
.
GE
,
self
.
name
)
self
.
init_prim_io_names
(
inputs
=
[
'x'
,
'min'
,
'max'
],
outputs
=
[
'out'
])
def
infer_shape
(
self
,
x_shape
,
min_shape
,
max_shape
):
if
len
(
x_shape
)
not
in
self
.
support_x_rank
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' x rank should be in '
{
self
.
support_x_rank
}
'"
)
if
self
.
is_ascend
and
len
(
x_shape
)
not
in
self
.
ascend_support_x_rank
:
raise
ValueError
(
f
"For '
{
self
.
name
}
' x rank should be in '
{
self
.
ascend_support_x_rank
}
'"
)
if
not
self
.
is_ascend
:
validator
.
check_integer
(
"x rank"
,
len
(
x_shape
),
1
,
Rel
.
GE
,
self
.
name
)
validator
.
check
(
"min shape"
,
min_shape
,
"max shape"
,
max_shape
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
"min shape"
,
min_shape
[
0
],
x_shape
[
self
.
channel_axis
],
Rel
.
EQ
,
self
.
name
)
...
...
mindspore/train/quant/quant.py
浏览文件 @
684ecac9
...
...
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
aware quantization
."""
"""
quantization aware
."""
import
copy
import
re
import
numpy
as
np
import
mindspore.context
as
context
from
...
import
log
as
logger
from
...
import
nn
,
ops
...
...
@@ -32,6 +33,7 @@ from ...ops.operations import _inner_ops as inner
from
...train
import
serialization
from
.
import
quant_utils
_ACTIVATION_MAP
=
{
nn
.
ReLU
:
quant
.
ReLUQuant
,
nn
.
ReLU6
:
quant
.
ReLU6Quant
,
nn
.
HSigmoid
:
quant
.
HSigmoidQuant
,
...
...
@@ -46,7 +48,7 @@ class _AddFakeQuantInput(nn.Cell):
def
__init__
(
self
,
network
,
quant_delay
=
0
):
super
(
_AddFakeQuantInput
,
self
).
__init__
(
auto_prefix
=
False
)
self
.
fake_quant_input
=
quant
.
FakeQuantWithMinMax
(
min_init
=-
6
,
max_init
=
6
,
quant_delay
=
quant_delay
,
ema
=
True
)
self
.
fake_quant_input
.
update_parameters_name
(
'fake_quant_input'
)
self
.
fake_quant_input
.
update_parameters_name
(
'fake_quant_input
.
'
)
self
.
network
=
network
def
construct
(
self
,
data
):
...
...
@@ -165,8 +167,8 @@ class ConvertToQuantNetwork:
convert Conv2d cell to quant cell
"""
conv_inner
=
subcell
.
conv
bn_inner
=
subcell
.
batchnorm
if
subcell
.
has_bn
and
self
.
bn_fold
:
bn_inner
=
subcell
.
batchnorm
conv_inner
=
quant
.
Conv2dBatchNormQuant
(
conv_inner
.
in_channels
,
conv_inner
.
out_channels
,
kernel_size
=
conv_inner
.
kernel_size
,
...
...
@@ -421,26 +423,26 @@ def convert_quant_network(network,
Args:
network (Cell): Obtain a pipeline through network for saving graph summary.
quant_delay (int): Number of steps after which weights and activations are quantized during
eval. The first element represent weights and second element represent data flow. Default: [0, 0]
bn_fold (bool): Flag to used bn fold ops for simulation inference operation. Default: False.
freeze_bn (int): Number of steps after which BatchNorm OP parameters used total mean and variance. Default: 0.
num_bits (list of int): Number of bits to use for quantizing weights and activations. The first
element represent weights and second element represent data flow. Default: [8, 8]
per_channel (list of bool): Quantization granularity based on layer or on channel. If `True`
quant_delay (int, list or tuple): Number of steps after which weights and activations are quantized during
eval. The first element represent weights and second element represent data flow. Default: (0, 0)
num_bits (int, list or tuple): Number of bits to use for quantizing weights and activations. The first
element represent weights and second element represent data flow. Default: (8, 8)
per_channel (bool, list or tuple): Quantization granularity based on layer or on channel. If `True`
then base on per channel otherwise base on per layer. The first element represent weights
and second element represent data flow. Default:
[False, False]
symmetric (
list of bool
): Quantization algorithm use symmetric or not. If `True` then base on
and second element represent data flow. Default:
(False, False)
symmetric (
bool, list or tuple
): Quantization algorithm use symmetric or not. If `True` then base on
symmetric otherwise base on asymmetric. The first element represent weights and second
element represent data flow. Default:
[False, False]
narrow_range (
list of bool
): Quantization algorithm use narrow range or not. If `True` then base
element represent data flow. Default:
(False, False)
narrow_range (
bool, list or tuple
): Quantization algorithm use narrow range or not. If `True` then base
on narrow range otherwise base on off narrow range. The first element represent weights and
second element represent data flow. Default:
[False, False]
second element represent data flow. Default:
(False, False)
Returns:
Cell, Network which has change to
aware quantization
training network cell.
Cell, Network which has change to
quantization aware
training network cell.
"""
support_device
=
[
"Ascend"
,
"GPU"
]
def
convert2list
(
name
,
value
):
if
not
isinstance
(
value
,
list
)
and
not
isinstance
(
value
,
tuple
):
value
=
[
value
]
...
...
@@ -454,6 +456,9 @@ def convert_quant_network(network,
symmetric
=
convert2list
(
"symmetric"
,
symmetric
)
narrow_range
=
convert2list
(
"narrow range"
,
narrow_range
)
if
context
.
get_context
(
'device_target'
)
not
in
support_device
:
raise
KeyError
(
"Not support {} backend."
.
format
(
context
.
get_context
(
'device_target'
)))
net
=
ConvertToQuantNetwork
(
network
=
network
,
quant_delay
=
quant_delay
,
bn_fold
=
bn_fold
,
...
...
tests/st/ops/gpu/test_fake_quant_perchannel.py
0 → 100644
浏览文件 @
684ecac9
此差异已折叠。
点击以展开。
tests/st/ops/gpu/test_fake_quant_perchannel_grad.py
0 → 100644
浏览文件 @
684ecac9
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
pytest
from
mindspore
import
Tensor
import
mindspore.nn
as
nn
import
mindspore.context
as
context
from
mindspore.ops.operations
import
_quant_ops
as
Q
context
.
set_context
(
device_target
=
'GPU'
,
device_id
=
0
)
class
Net
(
nn
.
Cell
):
def
__init__
(
self
,
num_bits
=
8
,
narrow_range
=
False
):
super
(
Net
,
self
).
__init__
()
self
.
op
=
Q
.
FakeQuantPerChannelGrad
(
num_bits
=
num_bits
,
narrow_range
=
narrow_range
)
def
construct
(
self
,
dout
,
x
,
minq
,
maxq
):
return
self
.
op
(
dout
,
x
,
minq
,
maxq
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad1
():
# WithVarsPerChannelDim1GradientNudgedDown_ZeroMinAndMax
dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
4
]).
astype
(
'float32'
)
x
=
np
.
array
([
0.0
,
0.0
,
0.0
,
0.0
]).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
0.0
,
0.0
,
0.0
,
0.0
]).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
0.0
,
0.0
,
0.0
,
0.0
]).
astype
(
np
.
float32
)
expect
=
dout
net
=
Net
(
num_bits
=
8
,
narrow_range
=
False
)
output
=
net
(
Tensor
(
dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"="
*
40
)
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad2
():
# WithVarsPerChannelDim1GradientNudgedDown_RegularRange
dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
4
]).
astype
(
'float32'
)
x
=
np
.
array
([
-
0.1
,
0.0
,
63.75
,
63.8
]).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.1
,
-
0.1
,
-
0.1
,
-
0.1
]).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
63.65
,
63.65
,
63.65
,
63.65
]).
astype
(
np
.
float32
)
expect
=
np
.
array
([
0.0
,
dout
[
1
],
dout
[
2
],
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
False
)
output
=
net
(
Tensor
(
dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"="
*
40
)
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad3
():
# WithVarsPerChannelDim1GradientNudgedDown_NarrowRange
dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
4
]).
astype
(
'float32'
)
x
=
np
.
array
([
-
0.1
,
0.0
,
63.5
,
63.6
]).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.1
,
-
0.1
,
-
0.1
,
-
0.1
]).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
63.4
,
63.4
,
63.4
,
63.4
]).
astype
(
np
.
float32
)
expect
=
np
.
array
([
0.0
,
dout
[
1
],
dout
[
2
],
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
True
)
output
=
net
(
Tensor
(
dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"="
*
40
)
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad4
():
# WithVarsPerChannelDim1GradientNudgedUp_RegularRange
dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
4
]).
astype
(
'float32'
)
x
=
np
.
array
([
-
0.3
,
-
0.25
,
63.5
,
63.6
]).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.125
,
-
0.125
,
-
0.125
,
-
0.125
]).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
63.625
,
63.625
,
63.625
,
63.625
]).
astype
(
np
.
float32
)
expect
=
np
.
array
([
0.0
,
dout
[
1
],
dout
[
2
],
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
False
)
output
=
net
(
Tensor
(
dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"="
*
40
)
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad5
():
# WithVarsPerChannelDim1GradientNudgedUp_NarrowRange
dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
4
]).
astype
(
'float32'
)
x
=
np
.
array
([
-
0.3
,
-
0.25
,
63.25
,
63.3
]).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.125
,
-
0.125
,
-
0.125
,
-
0.125
]).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
63.375
,
63.375
,
63.375
,
63.375
]).
astype
(
np
.
float32
)
expect
=
np
.
array
([
0.0
,
dout
[
1
],
dout
[
2
],
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
True
)
output
=
net
(
Tensor
(
dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"="
*
40
)
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad6
():
# WithVarsPerChannelDim2GradientNudgedDown_RegularRange
read_dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
3
,
2
]).
astype
(
'float32'
)
x
=
np
.
array
([
-
0.1
,
0.0
,
0.1
,
0.25
,
63.75
,
63.8
]
).
reshape
(
3
,
2
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.1
,
-
0.1
,
-
0.1
]).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
63.65
,
63.65
,
63.65
]).
astype
(
np
.
float32
)
dout
=
read_dout
.
flatten
()
expect
=
np
.
array
([
0.0
,
dout
[
1
],
dout
[
2
],
dout
[
3
],
dout
[
4
],
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
True
)
output
=
net
(
Tensor
(
read_dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"="
*
40
)
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad7
():
# WithVarsPerChannelDim2GradientNudgedDown_NarrowRange
read_dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
3
,
2
]).
astype
(
'float32'
)
x
=
np
.
array
([
-
0.1
,
0.0
,
0.1
,
0.25
,
63.5
,
63.6
]
).
reshape
(
3
,
2
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.1
,
-
0.1
,
-
0.1
]).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
63.4
,
63.4
,
63.4
]).
astype
(
np
.
float32
)
dout
=
read_dout
.
flatten
()
expect
=
np
.
array
([
0.0
,
dout
[
1
],
dout
[
2
],
dout
[
3
],
dout
[
4
],
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
True
)
output
=
net
(
Tensor
(
read_dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"="
*
40
)
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad8
():
# WithVarsPerChannelDim2GradientNudgedUp_RegularRange
read_dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
3
,
2
]).
astype
(
'float32'
)
x
=
np
.
array
([
-
0.3
,
-
0.25
,
-
0.2
,
0.0
,
63.5
,
63.6
]
).
reshape
(
3
,
2
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.125
,
-
0.125
,
-
0.125
]).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
63.625
,
63.625
,
63.625
]).
astype
(
np
.
float32
)
dout
=
read_dout
.
flatten
()
expect
=
np
.
array
([
0.0
,
dout
[
1
],
dout
[
2
],
dout
[
3
],
dout
[
4
],
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
False
)
output
=
net
(
Tensor
(
read_dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"="
*
40
)
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad9
():
# WithVarsPerChannelDim2GradientNudgedUp_NarrowRange
read_dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
3
,
2
]).
astype
(
'float32'
)
x
=
np
.
array
([
-
0.3
,
-
0.25
,
-
0.2
,
0.0
,
63.25
,
63.3
]
).
reshape
(
3
,
2
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.125
,
-
0.125
,
-
0.125
]).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
63.375
,
63.375
,
63.375
]).
astype
(
np
.
float32
)
dout
=
read_dout
.
flatten
()
expect
=
np
.
array
([
0.0
,
dout
[
1
],
dout
[
2
],
dout
[
3
],
dout
[
4
],
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
True
)
output
=
net
(
Tensor
(
read_dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"="
*
40
)
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad10
():
# WithVarsPerChannelDim4GradientNudgedDown_RegularRange
read_dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
4
,
3
,
2
,
1
]).
astype
(
'float32'
)
x
=
np
.
array
([
-
0.1
,
0.0
,
63.75
,
63.8
,
-
0.1
,
0.0
,
63.75
,
63.8
,
-
0.1
,
0.0
,
63.75
,
63.8
,
-
0.1
,
0.0
,
63.75
,
63.8
,
-
0.1
,
0.0
,
63.75
,
63.8
,
-
0.1
,
0.0
,
63.75
,
63.8
]).
reshape
(
4
,
3
,
2
,
1
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.1
,
-
0.1
,
-
0.1
,
-
0.1
]).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
63.65
,
63.65
,
63.65
,
63.65
]).
astype
(
np
.
float32
)
dout
=
read_dout
.
flatten
()
expect
=
np
.
array
([
0.0
,
dout
[
1
],
dout
[
2
],
0.0
,
0.0
,
dout
[
5
],
dout
[
6
],
0.0
,
0.0
,
dout
[
9
],
dout
[
10
],
0.0
,
0.0
,
dout
[
13
],
dout
[
14
],
0.0
,
0.0
,
dout
[
17
],
dout
[
18
],
0.0
,
0.0
,
dout
[
21
],
dout
[
22
],
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
False
)
output
=
net
(
Tensor
(
read_dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"="
*
40
)
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad11
():
# WithVarsPerChannelDim4GradientNudgedDown_NarrowRange
read_dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
4
,
3
,
2
,
1
]).
astype
(
'float32'
)
x
=
np
.
array
([
-
0.1
,
0.0
,
63.5
,
63.6
,
-
0.1
,
0.0
,
63.5
,
63.6
,
-
0.1
,
0.0
,
63.5
,
63.6
,
-
0.1
,
0.0
,
63.5
,
63.6
,
-
0.1
,
0.0
,
63.5
,
63.6
,
-
0.1
,
0.0
,
63.5
,
63.6
]).
reshape
(
4
,
3
,
2
,
1
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.1
,
-
0.1
,
-
0.1
,
-
0.1
]).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
63.4
,
63.4
,
63.4
,
63.4
]).
astype
(
np
.
float32
)
dout
=
read_dout
.
flatten
()
expect
=
np
.
array
([
0.0
,
dout
[
1
],
dout
[
2
],
0.0
,
0.0
,
dout
[
5
],
dout
[
6
],
0.0
,
0.0
,
dout
[
9
],
dout
[
10
],
0.0
,
0.0
,
dout
[
13
],
dout
[
14
],
0.0
,
0.0
,
dout
[
17
],
dout
[
18
],
0.0
,
0.0
,
dout
[
21
],
dout
[
22
],
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
True
)
output
=
net
(
Tensor
(
read_dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"="
*
40
)
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad12
():
# WithVarsPerChannelDim4GradientNudgedUp_RegularRange
read_dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
4
,
3
,
2
,
1
]).
astype
(
'float32'
)
x
=
np
.
array
([
-
0.3
,
-
0.25
,
63.5
,
63.6
,
-
0.3
,
-
0.25
,
63.5
,
63.6
,
-
0.3
,
-
0.25
,
63.5
,
63.6
,
-
0.3
,
-
0.25
,
63.5
,
63.6
,
-
0.3
,
-
0.25
,
63.5
,
63.6
,
-
0.3
,
-
0.25
,
63.5
,
63.6
]).
reshape
(
4
,
3
,
2
,
1
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.125
,
-
0.125
,
-
0.125
,
-
0.125
]).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
63.625
,
63.625
,
63.625
,
63.625
]).
astype
(
np
.
float32
)
dout
=
read_dout
.
flatten
()
expect
=
np
.
array
([
0.0
,
dout
[
1
],
dout
[
2
],
0.0
,
0.0
,
dout
[
5
],
dout
[
6
],
0.0
,
0.0
,
dout
[
9
],
dout
[
10
],
0.0
,
0.0
,
dout
[
13
],
dout
[
14
],
0.0
,
0.0
,
dout
[
17
],
dout
[
18
],
0.0
,
0.0
,
dout
[
21
],
dout
[
22
],
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
False
)
output
=
net
(
Tensor
(
read_dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"="
*
40
)
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad13
():
# WithVarsPerChannelDim4GradientNudgedUp_NarrowRange
read_dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
4
,
3
,
2
,
1
]).
astype
(
'float32'
)
x
=
np
.
array
([
-
0.3
,
-
0.25
,
63.25
,
63.3
,
-
0.3
,
-
0.25
,
63.25
,
63.3
,
-
0.3
,
-
0.25
,
63.25
,
63.3
,
-
0.3
,
-
0.25
,
63.25
,
63.3
,
-
0.3
,
-
0.25
,
63.25
,
63.3
,
-
0.3
,
-
0.25
,
63.25
,
63.3
]).
reshape
(
4
,
3
,
2
,
1
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.125
,
-
0.125
,
-
0.125
,
-
0.125
]).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
63.375
,
63.375
,
63.375
,
63.375
]).
astype
(
np
.
float32
)
dout
=
read_dout
.
flatten
()
expect
=
np
.
array
([
0.0
,
dout
[
1
],
dout
[
2
],
0.0
,
0.0
,
dout
[
5
],
dout
[
6
],
0.0
,
0.0
,
dout
[
9
],
dout
[
10
],
0.0
,
0.0
,
dout
[
13
],
dout
[
14
],
0.0
,
0.0
,
dout
[
17
],
dout
[
18
],
0.0
,
0.0
,
dout
[
21
],
dout
[
22
],
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
True
)
output
=
net
(
Tensor
(
read_dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"="
*
40
)
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
tests/st/ops/gpu/test_fake_quant_perlayer.py
0 → 100644
浏览文件 @
684ecac9
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
pytest
import
mindspore.context
as
context
from
mindspore.common.tensor
import
Tensor
import
mindspore.nn
as
nn
from
mindspore.ops.operations
import
_quant_ops
as
Q
context
.
set_context
(
device_target
=
'GPU'
,
device_id
=
0
)
class
Net
(
nn
.
Cell
):
def
__init__
(
self
,
num_bits
=
8
,
quant_delay
=
0
,
symmetric
=
False
,
narrow_range
=
False
,
training
=
True
):
super
(
Net
,
self
).
__init__
()
self
.
fake_quant
=
Q
.
FakeQuantPerLayer
(
num_bits
=
num_bits
,
quant_delay
=
quant_delay
,
symmetric
=
symmetric
,
narrow_range
=
narrow_range
,
training
=
training
)
def
construct
(
self
,
x
,
minq
,
maxq
):
return
self
.
fake_quant
(
x
,
minq
,
maxq
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant1
():
# (8, false, 0.0f, 0.0f, TensorShape({2, 3}),
# {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f},
# {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f});
x
=
np
.
array
([
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
]).
reshape
(
2
,
3
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
0
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
0
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
False
)
output
=
net
(
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant2
():
# 8, false, -10.0f, 53.75f, TensorShape({2, 3}),
# {-10.1f, -10.0f, -9.9f, -9.75f, 53.75f, 53.8f},
# {-10.0f, -10.0f, -10.0f, -9.75f, 53.75f, 53.75f});
x
=
np
.
array
([
-
10.1
,
-
10.0
,
-
9.9
,
-
9.75
,
53.75
,
53.8
]).
reshape
(
2
,
3
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
10.0
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
53.75
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
-
10.0
,
-
10.0
,
-
10.0
,
-
9.75
,
53.75
,
53.75
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
False
)
output
=
net
(
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant3
():
# WithVarsNoNudging_NarrowRange
x
=
np
.
array
([
-
10.1
,
-
10.0
,
-
9.90
,
-
9.75
,
53.5
,
53.6
]).
reshape
(
2
,
3
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
10.0
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
53.5
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
-
10.0
,
-
10.0
,
-
10.0
,
-
9.75
,
53.5
,
53.5
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
True
)
output
=
net
(
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant4
():
# WithVarsNudgedDown_RegularRange
x
=
np
.
array
([
-
0.1
,
0.0
,
0.1
,
0.25
,
63.75
,
63.8
]).
reshape
(
2
,
3
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.1
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
63.65
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
-
0.0
,
0.0
,
0.0
,
0.25
,
63.75
,
63.75
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
False
)
output
=
net
(
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant5
():
# WithVarsNudgedDown_NarrowRange
x
=
np
.
array
([
-
0.1
,
0.0
,
0.1
,
0.25
,
63.5
,
63.6
]).
reshape
(
2
,
3
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.1
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
63.4
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
-
0.0
,
0.0
,
0.0
,
0.25
,
63.5
,
63.5
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
True
)
output
=
net
(
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant6
():
# WithVarsNudgedUp_RegularRange
x
=
np
.
array
([
-
0.26
,
-
0.25
,
-
0.24
,
0.0
,
63.5
,
63.6
]).
reshape
(
2
,
3
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.125
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
63.625
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
-
0.25
,
-
0.25
,
-
0.25
,
0.0
,
63.5
,
63.5
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
False
)
output
=
net
(
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant7
():
# WithVarsNudgedUp_NarrowRange
x
=
np
.
array
([
-
0.26
,
-
0.25
,
-
0.24
,
0.0
,
63.25
,
63.3
]).
reshape
(
2
,
3
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.125
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
63.375
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
-
0.25
,
-
0.25
,
-
0.25
,
0.0
,
63.25
,
63.25
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
True
)
output
=
net
(
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant8
():
# WithVarsNudgedZeroIs255_RegularRange
x
=
np
.
array
([
-
63.80
,
-
63.75
,
-
63.70
,
-
63.5
,
0.0
,
0.1
]).
reshape
(
2
,
3
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
63.65
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
0.1
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
-
63.75
,
-
63.75
,
-
63.75
,
-
63.5
,
0.0
,
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
False
)
output
=
net
(
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant9
():
# WithVarsNudgedZeroIs255_NarrowRange
x
=
np
.
array
([
-
63.6
,
-
63.5
,
-
63.4
,
-
63.25
,
0.0
,
0.1
]).
reshape
(
2
,
3
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
63.4
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
0.1
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
-
63.5
,
-
63.5
,
-
63.5
,
-
63.25
,
0.0
,
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
True
)
output
=
net
(
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant10
():
# WithVarsNoNudging_4Bits_RegularRange
x
=
np
.
array
([
-
6.1
,
-
6.0
,
-
5.9
,
-
5.5
,
1.5
,
1.6
]).
reshape
(
2
,
3
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
6.0
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
1.5
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
-
6.0
,
-
6.0
,
-
6.0
,
-
5.5
,
1.5
,
1.5
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
4
,
narrow_range
=
False
)
output
=
net
(
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant11
():
# WithVarsNoNudging_4Bits_NarrowRange
x
=
np
.
array
([
-
6.1
,
-
6.0
,
-
5.9
,
-
5.5
,
1.0
,
1.1
]).
reshape
(
2
,
3
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
6.0
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
1.0
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
-
6.0
,
-
6.0
,
-
6.0
,
-
5.5
,
1.0
,
1.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
4
,
narrow_range
=
True
)
output
=
net
(
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant12
():
# WithVarsNudgedDown_4Bits_RegularRange
x
=
np
.
array
([
-
0.1
,
0.0
,
0.1
,
0.5
,
7.5
,
7.6
]).
reshape
(
2
,
3
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.1
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
7.4
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
-
0.0
,
0.0
,
0.0
,
0.5
,
7.5
,
7.5
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
4
,
narrow_range
=
False
)
output
=
net
(
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant13
():
# WithVarsNudgedDown_4Bits_NarrowRange
x
=
np
.
array
([
-
0.1
,
0.0
,
0.1
,
0.5
,
7.0
,
7.1
]).
reshape
(
2
,
3
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.1
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
6.9
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
-
0.0
,
0.0
,
0.0
,
0.5
,
7.0
,
7.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
4
,
narrow_range
=
True
)
output
=
net
(
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant14
():
# WithVarsNudgedUp_4Bits_RegularRange
x
=
np
.
array
([
-
0.6
,
-
0.5
,
-
0.24
,
0.0
,
7.0
,
7.1
]).
reshape
(
2
,
3
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.4
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
7.1
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
-
0.5
,
-
0.5
,
-
0.00
,
0.0
,
7.0
,
7.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
4
,
narrow_range
=
False
)
output
=
net
(
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant15
():
# WithVarsNudgedUp_4Bits_NarrowRange
x
=
np
.
array
([
-
0.6
,
-
0.5
,
-
0.24
,
0.0
,
6.5
,
6.6
]).
reshape
(
2
,
3
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.4
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
6.6
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
-
0.5
,
-
0.5
,
-
0.00
,
0.0
,
6.5
,
6.5
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
4
,
narrow_range
=
True
)
output
=
net
(
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant16
():
# WithVarsNudgedZero15_4Bits_RegularRange
x
=
np
.
array
([
-
7.6
,
-
7.5
,
-
7.4
,
-
7.2
,
0.0
,
0.1
]).
reshape
(
2
,
3
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
7.3
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
0.2
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
-
7.5
,
-
7.5
,
-
7.5
,
-
7.0
,
0.0
,
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
4
,
narrow_range
=
False
)
output
=
net
(
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant17
():
# WithVarsNudgedZero15_4Bits_NarrowRange
x
=
np
.
array
([
-
7.1
,
-
7.0
,
-
6.9
,
-
6.5
,
0.0
,
0.1
]).
reshape
(
2
,
3
).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
6.8
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
0.2
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
-
7.0
,
-
7.0
,
-
7.0
,
-
6.5
,
0.0
,
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
4
,
narrow_range
=
True
)
output
=
net
(
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
tests/st/ops/gpu/test_fake_quant_perlayer_grad.py
0 → 100644
浏览文件 @
684ecac9
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
pytest
from
mindspore
import
Tensor
import
mindspore.nn
as
nn
import
mindspore.context
as
context
from
mindspore.ops.operations
import
_quant_ops
as
Q
context
.
set_context
(
device_target
=
'GPU'
,
device_id
=
0
)
class
Net
(
nn
.
Cell
):
def
__init__
(
self
,
num_bits
=
8
,
narrow_range
=
False
):
super
(
Net
,
self
).
__init__
()
self
.
op
=
Q
.
FakeQuantPerLayerGrad
(
num_bits
=
num_bits
,
narrow_range
=
narrow_range
)
def
construct
(
self
,
dout
,
x
,
minq
,
maxq
):
return
self
.
op
(
dout
,
x
,
minq
,
maxq
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad1
():
# WithArgsGradient RegularRange
dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
6
]).
astype
(
'float32'
)
x
=
np
.
array
([
-
0.26
,
-
0.25
,
-
0.24
,
0.0
,
63.5
,
63.6
]).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.125
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
63.625
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
0.0
,
dout
[
1
],
dout
[
2
],
dout
[
3
],
dout
[
4
],
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
False
)
output
=
net
(
Tensor
(
dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad2
():
# WithArgsGradient NarrowRange
dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
6
]).
astype
(
'float32'
)
x
=
np
.
array
([
-
0.26
,
-
0.25
,
-
0.24
,
0.0
,
63.25
,
63.3
]).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.125
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
63.375
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
0.0
,
dout
[
1
],
dout
[
2
],
dout
[
3
],
dout
[
4
],
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
True
)
output
=
net
(
Tensor
(
dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad3
():
# WithArgsGradient_4Bits_RegularRange
dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
6
]).
astype
(
'float32'
)
x
=
np
.
array
([
-
0.6
,
-
0.5
,
-
0.4
,
0.0
,
7.0
,
7.1
]).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.4
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
7.1
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
0.0
,
dout
[
1
],
dout
[
2
],
dout
[
3
],
dout
[
4
],
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
4
,
narrow_range
=
False
)
output
=
net
(
Tensor
(
dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad4
():
# WithArgsGradient_4Bits_NarrowRange
dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
6
]).
astype
(
'float32'
)
x
=
np
.
array
([
-
0.6
,
-
0.5
,
-
0.4
,
0.0
,
6.5
,
6.6
]).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.4
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
6.6
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
0.0
,
dout
[
1
],
dout
[
2
],
dout
[
3
],
dout
[
4
],
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
4
,
narrow_range
=
True
)
output
=
net
(
Tensor
(
dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad5
():
# FakeQuantWithMinMaxVarsGradient
dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
6
]).
astype
(
'float32'
)
x
=
np
.
array
([
0.0
,
0.0
,
0.0
,
0.0
,
0.0
,
0.0
]).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
0.0
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
0.0
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
dout
net
=
Net
(
num_bits
=
8
,
narrow_range
=
True
)
output
=
net
(
Tensor
(
dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad6
():
# WithVarsGradient_RegularRange
dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
6
]).
astype
(
'float32'
)
x
=
np
.
array
([
-
0.26
,
-
0.25
,
-
0.24
,
0.0
,
63.5
,
63.6
]).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.125
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
63.625
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
0.0
,
dout
[
1
],
dout
[
2
],
dout
[
3
],
dout
[
4
],
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
False
)
output
=
net
(
Tensor
(
dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad7
():
# WithVarsGradient_NarrowRange
dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
6
]).
astype
(
'float32'
)
x
=
np
.
array
([
-
0.26
,
-
0.25
,
-
0.24
,
0.0
,
63.25
,
63.3
]).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.125
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
63.375
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
0.0
,
dout
[
1
],
dout
[
2
],
dout
[
3
],
dout
[
4
],
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
8
,
narrow_range
=
True
)
output
=
net
(
Tensor
(
dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad8
():
# WithVarsGradient_4Bits_RegularRange
dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
6
]).
astype
(
'float32'
)
x
=
np
.
array
([
-
0.6
,
-
0.5
,
-
0.4
,
0.0
,
7.0
,
7.1
]).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.4
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
7.1
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
0.0
,
dout
[
1
],
dout
[
2
],
dout
[
3
],
dout
[
4
],
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
4
,
narrow_range
=
False
)
output
=
net
(
Tensor
(
dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_fake_quant_grad9
():
# WithVarsGradient_4Bits_NarrowRange
dout
=
np
.
random
.
uniform
(
-
1
,
1
,
size
=
[
6
]).
astype
(
'float32'
)
x
=
np
.
array
([
-
0.6
,
-
0.5
,
-
0.4
,
0.0
,
6.5
,
6.6
]).
astype
(
np
.
float32
)
min_val
=
np
.
array
([
-
0.4
]).
reshape
(
1
).
astype
(
np
.
float32
)
max_val
=
np
.
array
([
6.6
]).
reshape
(
1
).
astype
(
np
.
float32
)
expect
=
np
.
array
([
0.0
,
dout
[
1
],
dout
[
2
],
dout
[
3
],
dout
[
4
],
0.0
]).
astype
(
np
.
float32
)
net
=
Net
(
num_bits
=
4
,
narrow_range
=
True
)
output
=
net
(
Tensor
(
dout
),
Tensor
(
x
),
Tensor
(
min_val
),
Tensor
(
max_val
))
error
=
np
.
ones
(
shape
=
expect
.
shape
)
*
1.0e-5
diff
=
output
.
asnumpy
().
flatten
()
-
expect
print
(
"output: "
,
output
)
print
(
"expect: "
,
expect
)
assert
np
.
all
(
np
.
abs
(
diff
)
<
error
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录