Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8991e9ae
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
8991e9ae
编写于
3月 23, 2022
作者:
W
whs
提交者:
GitHub
3月 23, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix quant and dequant cuda kernels when quant_axis==1 (#40772)
上级
319f95d0
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
62 addition
and
46 deletion
+62
-46
paddle/fluid/operators/fake_dequantize_op.cu
paddle/fluid/operators/fake_dequantize_op.cu
+29
-21
paddle/fluid/operators/fake_quantize_op.cu
paddle/fluid/operators/fake_quantize_op.cu
+33
-25
未找到文件。
paddle/fluid/operators/fake_dequantize_op.cu
浏览文件 @
8991e9ae
...
...
@@ -58,19 +58,15 @@ __global__ void DequantizeOneScaleQuantAxis0(const T* in, const T* scale,
}
template
<
typename
T
>
__global__
void
DequantizeOneScaleQuantAxis1
(
const
T
*
in
,
const
T
*
scale
,
T
max_range
,
const
int
num
,
const
int
cin
,
const
int
cout
,
T
*
out
)
{
int
bid
=
blockIdx
.
x
;
T
s
=
scale
[
bid
%
cout
];
int
wh_size
=
num
/
(
cin
*
cout
);
const
T
*
in_current
=
in
+
bid
*
wh_size
;
T
*
out_current
=
out
+
bid
*
wh_size
;
for
(
int
i
=
threadIdx
.
x
;
i
<
wh_size
;
i
+=
blockDim
.
x
)
{
out_current
[
i
]
=
in_current
[
i
]
*
s
/
max_range
;
__global__
void
DequantizeOneScaleQuantAxisN
(
const
T
*
in
,
const
T
*
scale
,
const
T
max_range
,
const
int64_t
num
,
const
int
n_scales
,
const
int
quant_stride
,
T
*
out
)
{
int64_t
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
for
(
int64_t
i
=
idx
;
i
<
num
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
s
=
scale
[(
i
/
quant_stride
)
%
n_scales
];
out
[
i
]
=
in
[
i
]
*
s
/
max_range
;
}
}
...
...
@@ -98,20 +94,32 @@ struct ChannelDequantizeFunctor<platform::CUDADeviceContext, T> {
const
T
*
in_data
=
in
->
data
<
T
>
();
T
*
out_data
=
out
->
mutable_data
<
T
>
(
dev_ctx
.
GetPlace
());
if
(
scale_num
==
1
)
{
int
num
=
in
->
numel
();
int
64_t
num
=
in
->
numel
();
const
T
*
scale_factor
=
scales
[
0
]
->
data
<
T
>
();
if
(
quant_axis
==
0
)
{
int
grid
=
in_dims
[
0
];
int
block
=
1024
;
DequantizeOneScaleQuantAxis0
<
T
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
in_data
,
scale_factor
,
max_range
,
num
,
in_dims
[
0
],
out_data
);
}
else
if
(
quant_axis
==
1
)
{
// Dequantize weight of Cin * Cout * W * H
int
grid
=
in_dims
[
0
]
*
in_dims
[
1
];
int
block
=
1024
;
DequantizeOneScaleQuantAxis1
<
T
><<<
grid
,
block
,
0
,
dev_ctx
.
stream
()
>>>
(
in_data
,
scale_factor
,
max_range
,
num
,
in_dims
[
0
],
in_dims
[
1
],
out_data
);
}
else
{
int
quant_stride
=
1
;
for
(
int
i
=
quant_axis
+
1
;
i
<
in_dims
.
size
();
i
++
)
{
quant_stride
*=
in_dims
[
i
];
}
int64_t
block_size
=
std
::
min
(
num
,
static_cast
<
int64_t
>
(
dev_ctx
.
GetMaxThreadsPerBlock
()
/
4
));
int64_t
max_threads
=
dev_ctx
.
GetMaxPhysicalThreadCount
();
// SM * block_per_SM
const
int64_t
max_blocks
=
std
::
max
(
((
max_threads
-
1
)
/
block_size
+
1
),
static_cast
<
int64_t
>
(
1
));
const
int64_t
grid_size
=
std
::
min
(
max_blocks
,
(
num
+
block_size
-
1
)
/
block_size
);
DequantizeOneScaleQuantAxisN
<
T
><<<
grid_size
,
block_size
,
0
,
dev_ctx
.
stream
()
>>>
(
in_data
,
scale_factor
,
max_range
,
num
,
in_dims
[
quant_axis
],
quant_stride
,
out_data
);
}
}
else
if
(
scale_num
==
2
)
{
// Not need to consider quant_axis
...
...
paddle/fluid/operators/fake_quantize_op.cu
浏览文件 @
8991e9ae
...
...
@@ -273,18 +273,18 @@ struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
template
<
typename
T
>
__global__
void
ChannelClipAndQuantKernelQuantAxis0
(
const
T
*
in
,
const
T
*
scale
,
const
int
bin_cnt
,
const
int
n
,
const
int
c
,
T
*
out
)
{
const
int
64_t
n
,
const
int
c
,
T
*
out
)
{
int
tid
=
threadIdx
.
x
;
int
channel_size
=
n
/
c
;
int
64_t
channel_size
=
n
/
c
;
const
T
*
in_c
=
in
+
blockIdx
.
x
*
channel_size
;
T
*
out_c
=
out
+
blockIdx
.
x
*
channel_size
;
T
s
=
scale
[
blockIdx
.
x
];
T
inv_s
=
inverse
(
s
);
for
(
int
i
=
tid
;
i
<
channel_size
;
i
+=
blockDim
.
x
)
{
for
(
int
64_t
i
=
tid
;
i
<
channel_size
;
i
+=
blockDim
.
x
)
{
T
x
=
in_c
[
i
];
T
v
=
x
>
s
?
s
:
x
;
v
=
v
<
-
s
?
-
s
:
v
;
...
...
@@ -293,25 +293,20 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
}
}
// ChannelClipAndQuantKernel for quant_axis is
1
// ChannelClipAndQuantKernel for quant_axis is
N
template
<
typename
T
>
__global__
void
ChannelClipAndQuantKernelQuantAxis1
(
const
T
*
in
,
const
T
*
scale
,
const
int
bin_cnt
,
const
int
n
,
const
int
cin
,
const
int
cout
,
T
*
out
)
{
T
s
=
scale
[
blockIdx
.
x
%
cout
];
T
inv_s
=
inverse
(
s
);
int
wh_size
=
n
/
(
cin
*
cout
);
const
T
*
in_c
=
in
+
blockIdx
.
x
*
wh_size
;
T
*
out_c
=
out
+
blockIdx
.
x
*
wh_size
;
for
(
int
i
=
threadIdx
.
x
;
i
<
wh_size
;
i
+=
blockDim
.
x
)
{
T
x
=
in_c
[
i
];
__global__
void
ChannelClipAndQuantKernelQuantAxisN
(
const
T
*
in
,
const
T
*
scale
,
const
int
bin_cnt
,
const
int64_t
n
,
const
int
nScale
,
const
int
quant_stride
,
T
*
out
)
{
int64_t
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
for
(
int64_t
i
=
idx
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
s
=
scale
[(
i
/
quant_stride
)
%
nScale
];
T
inv_s
=
1.0
/
s
;
T
x
=
in
[
i
];
T
v
=
x
>
s
?
s
:
x
;
v
=
v
<
-
s
?
-
s
:
v
;
v
=
bin_cnt
*
inv_s
*
v
;
out
_c
[
i
]
=
round
(
v
);
out
[
i
]
=
round
(
v
);
}
}
...
...
@@ -327,7 +322,7 @@ struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
"the received is %d"
,
quant_axis
));
int
num
=
in
.
numel
();
int
64_t
num
=
in
.
numel
();
auto
in_dims
=
in
.
dims
();
const
T
*
in_data
=
in
.
data
<
T
>
();
const
T
*
scale_data
=
scale
.
data
<
T
>
();
...
...
@@ -338,11 +333,24 @@ struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
int
block
=
1024
;
ChannelClipAndQuantKernelQuantAxis0
<
T
><<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
in_data
,
scale_data
,
bin_cnt
,
num
,
in_dims
[
0
],
out_data
);
}
else
if
(
quant_axis
==
1
)
{
int
grid
=
in_dims
[
0
]
*
in_dims
[
1
];
int
block
=
1024
;
ChannelClipAndQuantKernelQuantAxis1
<
T
><<<
grid
,
block
,
0
,
ctx
.
stream
()
>>>
(
in_data
,
scale_data
,
bin_cnt
,
num
,
in_dims
[
0
],
in_dims
[
1
],
out_data
);
}
else
{
int
quant_stride
=
1
;
for
(
int
i
=
quant_axis
+
1
;
i
<
in_dims
.
size
();
i
++
)
{
quant_stride
*=
in_dims
[
i
];
}
int64_t
block_size
=
std
::
min
(
num
,
static_cast
<
int64_t
>
(
ctx
.
GetMaxThreadsPerBlock
()
/
4
));
int64_t
max_threads
=
ctx
.
GetMaxPhysicalThreadCount
();
// SM * block_per_SM
const
int64_t
max_blocks
=
std
::
max
(((
max_threads
-
1
)
/
block_size
+
1
),
static_cast
<
int64_t
>
(
1
));
const
int64_t
grid_size
=
std
::
min
(
max_blocks
,
(
num
+
block_size
-
1
)
/
block_size
);
ChannelClipAndQuantKernelQuantAxisN
<
T
><<<
grid_size
,
block_size
>>>
(
in_data
,
scale_data
,
bin_cnt
,
num
,
in_dims
[
quant_axis
],
quant_stride
,
out_data
);
}
}
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录