Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
0ae8a2d6
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0ae8a2d6
编写于
5月 31, 2022
作者:
L
Leo Chen
提交者:
GitHub
5月 31, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix the underflow of fp16 fake quantize operators (#43088)
Co-authored-by:
N
Ryan Jeng
<
rjeng@nvidia.com
>
上级
4700a08e
变更
2
展开全部
隐藏空白更改
内联
并排
Showing
2 changed file
with
263 addition
and
385 deletion
+263
-385
paddle/fluid/operators/fake_quantize_op.cu.h
paddle/fluid/operators/fake_quantize_op.cu.h
+33
-28
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
+230
-357
未找到文件。
paddle/fluid/operators/fake_quantize_op.cu.h
浏览文件 @
0ae8a2d6
...
...
@@ -217,16 +217,18 @@ __global__ void ClipAndQuantKernel(const T* in, const T* scale,
int
bid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
tid
=
threadIdx
.
x
;
T
s
=
scale
[
0
];
T
inv_s
=
inverse
(
s
);
T
bin_cnt_t
=
static_cast
<
T
>
(
bin_cnt
);
using
ComputeDataType
=
typename
QuantizeDataType
<
T
>::
type
;
ComputeDataType
s
=
static_cast
<
ComputeDataType
>
(
scale
[
0
]);
ComputeDataType
inv_s
=
inverse
(
s
);
ComputeDataType
bin_cnt_t
=
static_cast
<
ComputeDataType
>
(
bin_cnt
);
for
(
int
i
=
bid
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
x
=
in
[
i
]
;
T
v
=
x
>
s
?
s
:
x
;
ComputeDataType
x
=
static_cast
<
ComputeDataType
>
(
in
[
i
])
;
ComputeDataType
v
=
x
>
s
?
s
:
x
;
v
=
v
<
-
s
?
-
s
:
v
;
v
=
bin_cnt_t
*
inv_s
*
v
;
out
[
i
]
=
static_cast
<
T
>
(
round
(
static_cast
<
typename
QuantizeDataType
<
T
>::
type
>
(
v
)));
out
[
i
]
=
static_cast
<
T
>
(
round
(
v
));
}
}
...
...
@@ -237,18 +239,19 @@ __global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
int
bid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
tid
=
threadIdx
.
x
;
T
s
=
scale
[
0
];
T
inv_s
=
inverse
(
s
);
T
bin_cnt_t
=
static_cast
<
T
>
(
bin_cnt
);
using
ComputeDataType
=
typename
QuantizeDataType
<
T
>::
type
;
ComputeDataType
s
=
static_cast
<
ComputeDataType
>
(
scale
[
0
]);
ComputeDataType
inv_s
=
inverse
(
s
);
ComputeDataType
bin_cnt_t
=
static_cast
<
ComputeDataType
>
(
bin_cnt
);
for
(
int
i
=
bid
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
x
=
in
[
i
]
;
ComputeDataType
x
=
static_cast
<
ComputeDataType
>
(
in
[
i
])
;
x
=
x
>
s
?
s
:
x
;
x
=
x
<
-
s
?
-
s
:
x
;
x
=
bin_cnt_t
*
inv_s
*
x
;
x
=
static_cast
<
T
>
(
round
(
static_cast
<
typename
QuantizeDataType
<
T
>::
type
>
(
x
)));
out
[
i
]
=
(
x
*
s
)
/
bin_cnt_t
;
x
=
round
(
x
);
out
[
i
]
=
static_cast
<
T
>
((
x
*
s
)
/
bin_cnt_t
);
}
}
...
...
@@ -302,17 +305,18 @@ __global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
const
T
*
in_c
=
in
+
blockIdx
.
x
*
channel_size
;
T
*
out_c
=
out
+
blockIdx
.
x
*
channel_size
;
T
s
=
scale
[
blockIdx
.
x
];
T
inv_s
=
inverse
(
s
);
T
bin_cnt_t
=
static_cast
<
T
>
(
bin_cnt
);
using
ComputeDataType
=
typename
QuantizeDataType
<
T
>::
type
;
ComputeDataType
s
=
static_cast
<
ComputeDataType
>
(
scale
[
blockIdx
.
x
]);
ComputeDataType
inv_s
=
inverse
(
s
);
ComputeDataType
bin_cnt_t
=
static_cast
<
ComputeDataType
>
(
bin_cnt
);
for
(
int64_t
i
=
tid
;
i
<
channel_size
;
i
+=
blockDim
.
x
)
{
T
x
=
in_c
[
i
]
;
T
v
=
x
>
s
?
s
:
x
;
ComputeDataType
x
=
static_cast
<
ComputeDataType
>
(
in_c
[
i
])
;
ComputeDataType
v
=
x
>
s
?
s
:
x
;
v
=
v
<
-
s
?
-
s
:
v
;
v
=
bin_cnt_t
*
inv_s
*
v
;
out_c
[
i
]
=
static_cast
<
T
>
(
round
(
static_cast
<
typename
QuantizeDataType
<
T
>::
type
>
(
v
)));
out_c
[
i
]
=
static_cast
<
T
>
(
round
(
v
));
}
}
...
...
@@ -322,16 +326,17 @@ __global__ void ChannelClipAndQuantKernelQuantAxisN(
const
T
*
in
,
const
T
*
scale
,
const
int
bin_cnt
,
const
int64_t
n
,
const
int
nScale
,
const
int
quant_stride
,
T
*
out
)
{
int64_t
idx
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
T
bin_cnt_t
=
static_cast
<
T
>
(
bin_cnt
);
using
ComputeDataType
=
typename
QuantizeDataType
<
T
>::
type
;
ComputeDataType
bin_cnt_t
=
static_cast
<
ComputeDataType
>
(
bin_cnt
);
for
(
int64_t
i
=
idx
;
i
<
n
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
s
=
scale
[(
i
/
quant_stride
)
%
nScale
];
T
inv_s
=
inverse
(
s
);
T
x
=
in
[
i
];
T
v
=
x
>
s
?
s
:
x
;
ComputeDataType
s
=
static_cast
<
ComputeDataType
>
(
scale
[(
i
/
quant_stride
)
%
nScale
]);
ComputeDataType
inv_s
=
inverse
(
s
);
ComputeDataType
x
=
static_cast
<
ComputeDataType
>
(
in
[
i
]);
ComputeDataType
v
=
x
>
s
?
s
:
x
;
v
=
v
<
-
s
?
-
s
:
v
;
v
=
bin_cnt_t
*
inv_s
*
v
;
out
[
i
]
=
static_cast
<
T
>
(
round
(
static_cast
<
typename
QuantizeDataType
<
T
>::
type
>
(
v
)));
out
[
i
]
=
static_cast
<
T
>
(
round
(
v
));
}
}
...
...
python/paddle/fluid/tests/unittests/test_fake_quantize_op.py
浏览文件 @
0ae8a2d6
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录