Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
85e531a9
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2310
Star
20933
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
85e531a9
编写于
7月 22, 2021
作者:
C
cc
提交者:
GitHub
7月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add int16 kernel for lookup_talbe and dequantize_abs_max op (#34275)
* add int16 kernel for lookup_talbe and dequantize_abs_max op
上级
5179853a
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
146 addition
and
4 deletion
+146
-4
paddle/fluid/operators/dequantize_abs_max_op.cc
paddle/fluid/operators/dequantize_abs_max_op.cc
+4
-2
paddle/fluid/operators/dequantize_abs_max_op.cu
paddle/fluid/operators/dequantize_abs_max_op.cu
+3
-1
paddle/fluid/operators/lookup_table_op.cc
paddle/fluid/operators/lookup_table_op.cc
+1
-0
paddle/fluid/operators/lookup_table_op.cu
paddle/fluid/operators/lookup_table_op.cu
+2
-1
paddle/fluid/operators/lookup_table_op.h
paddle/fluid/operators/lookup_table_op.h
+2
-0
paddle/fluid/operators/math/blas_impl.h
paddle/fluid/operators/math/blas_impl.h
+9
-0
python/paddle/fluid/tests/unittests/test_dequantize_abs_max_op.py
...addle/fluid/tests/unittests/test_dequantize_abs_max_op.py
+7
-0
python/paddle/fluid/tests/unittests/test_lookup_table_op.py
python/paddle/fluid/tests/unittests/test_lookup_table_op.py
+118
-0
未找到文件。
paddle/fluid/operators/dequantize_abs_max_op.cc
浏览文件 @
85e531a9
...
...
@@ -50,6 +50,7 @@ struct DequantizeFunctor<platform::CPUDeviceContext, T> {
};
template
struct
DequantizeFunctor
<
platform
::
CPUDeviceContext
,
int8_t
>;
template
struct
DequantizeFunctor
<
platform
::
CPUDeviceContext
,
int16_t
>;
class
DequantizeMaxAbsOp
:
public
framework
::
OperatorWithKernel
{
public:
...
...
@@ -79,7 +80,7 @@ class DequantizeMaxAbsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(
int8 Tensor) The input with int8
type is the "
"(
Int Tensor) The input with int8/16
type is the "
"low precision tensor."
);
AddInput
(
"Scale"
,
"(float) The scale in quantization stage."
);
AddOutput
(
"Out"
,
...
...
@@ -108,4 +109,5 @@ REGISTER_OPERATOR(
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OP_CPU_KERNEL
(
dequantize_abs_max
,
ops
::
DequantizeMaxAbsKernel
<
CPU
,
int8_t
>
);
ops
::
DequantizeMaxAbsKernel
<
CPU
,
int8_t
>
,
ops
::
DequantizeMaxAbsKernel
<
CPU
,
int16_t
>
);
paddle/fluid/operators/dequantize_abs_max_op.cu
浏览文件 @
85e531a9
...
...
@@ -45,6 +45,7 @@ struct DequantizeFunctor<platform::CUDADeviceContext, T> {
};
template
struct
DequantizeFunctor
<
platform
::
CUDADeviceContext
,
int8_t
>;
template
struct
DequantizeFunctor
<
platform
::
CUDADeviceContext
,
int16_t
>;
}
// namespace operators
}
// namespace paddle
...
...
@@ -52,4 +53,5 @@ template struct DequantizeFunctor<platform::CUDADeviceContext, int8_t>;
namespace
ops
=
paddle
::
operators
;
using
CUDA
=
paddle
::
platform
::
CUDADeviceContext
;
REGISTER_OP_CUDA_KERNEL
(
dequantize_abs_max
,
ops
::
DequantizeMaxAbsKernel
<
CUDA
,
int8_t
>
);
ops
::
DequantizeMaxAbsKernel
<
CUDA
,
int8_t
>
,
ops
::
DequantizeMaxAbsKernel
<
CUDA
,
int16_t
>
);
paddle/fluid/operators/lookup_table_op.cc
浏览文件 @
85e531a9
...
...
@@ -229,6 +229,7 @@ REGISTER_OPERATOR(lookup_table_grad, ops::LookupTableOpGrad,
REGISTER_OP_CPU_KERNEL
(
lookup_table
,
ops
::
LookupTableKernel
<
float
>
,
ops
::
LookupTableKernel
<
double
>
,
ops
::
LookupTableKernel
<
int8_t
>
,
ops
::
LookupTableKernel
<
int16_t
>
,
ops
::
LookupTableKernel
<
paddle
::
platform
::
bfloat16
>
);
REGISTER_OP_CPU_KERNEL
(
lookup_table_grad
,
ops
::
LookupTableGradKernel
<
float
>
,
ops
::
LookupTableGradKernel
<
double
>
,
...
...
paddle/fluid/operators/lookup_table_op.cu
浏览文件 @
85e531a9
...
...
@@ -227,7 +227,8 @@ namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL
(
lookup_table
,
ops
::
LookupTableCUDAKernel
<
float
>
,
ops
::
LookupTableCUDAKernel
<
double
>
,
ops
::
LookupTableCUDAKernel
<
plat
::
float16
>
,
ops
::
LookupTableCUDAKernel
<
int8_t
>
);
ops
::
LookupTableCUDAKernel
<
int8_t
>
,
ops
::
LookupTableCUDAKernel
<
int16_t
>
);
REGISTER_OP_CUDA_KERNEL
(
lookup_table_grad
,
ops
::
LookupTableGradCUDAKernel
<
float
>
,
ops
::
LookupTableGradCUDAKernel
<
double
>
,
...
...
paddle/fluid/operators/lookup_table_op.h
浏览文件 @
85e531a9
...
...
@@ -103,6 +103,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
if
(
id_index
!=
-
1
)
{
if
(
input_data_type
==
framework
::
proto
::
VarType
::
INT8
||
input_data_type
==
framework
::
proto
::
VarType
::
INT16
||
input_data_type
==
framework
::
proto
::
VarType
::
BF16
)
{
memcpy
(
output
+
i
*
row_width
,
table
+
id_index
*
row_width
,
row_width
*
sizeof
(
T
));
...
...
@@ -130,6 +131,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
id_index
));
if
(
input_data_type
==
framework
::
proto
::
VarType
::
INT8
||
input_data_type
==
framework
::
proto
::
VarType
::
INT16
||
input_data_type
==
framework
::
proto
::
VarType
::
BF16
)
{
memcpy
(
output
+
i
*
row_width
,
table
+
id_index
*
row_width
,
row_width
*
sizeof
(
T
));
...
...
paddle/fluid/operators/math/blas_impl.h
浏览文件 @
85e531a9
...
...
@@ -54,6 +54,15 @@ struct CBlas<int8_t> {
}
};
template
<
>
struct
CBlas
<
int16_t
>
{
template
<
typename
...
ARGS
>
static
void
VCOPY
(
ARGS
...
args
)
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Blas VCOPY do not supported on CPU, please check your code"
));
}
};
template
<
>
struct
CBlas
<
platform
::
bfloat16
>
{
template
<
typename
...
ARGS
>
...
...
python/paddle/fluid/tests/unittests/test_dequantize_abs_max_op.py
浏览文件 @
85e531a9
...
...
@@ -62,5 +62,12 @@ class TestDequantizeMaxAbsOp5Bits(TestDequantizeMaxAbsOp):
self
.
data_type
=
"int8"
class
TestDequantizeMaxAbsOpInt16
(
TestDequantizeMaxAbsOp
):
def
set_args
(
self
):
self
.
num_bits
=
16
self
.
max_range
=
math
.
pow
(
2
,
self
.
num_bits
-
1
)
-
1
self
.
data_type
=
"int16"
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_lookup_table_op.py
浏览文件 @
85e531a9
...
...
@@ -316,6 +316,124 @@ class TestLookupTableWithTensorIdsWIsSelectedRowsInt8(
assert
(
row
==
result_array
[
idx
]).
all
()
@
skip_check_grad_ci
(
reason
=
"Int16 type only be used in test and inference."
)
class
TestLookupTableOpInt16
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"lookup_table"
table
=
np
.
random
.
randint
(
low
=-
128
,
high
=
127
,
size
=
(
17
,
31
)).
astype
(
"int16"
)
ids
=
np
.
random
.
randint
(
0
,
17
,
4
).
astype
(
"int64"
)
ids_expand
=
np
.
expand_dims
(
ids
,
axis
=
1
)
self
.
inputs
=
{
'W'
:
table
,
'Ids'
:
ids_expand
}
self
.
outputs
=
{
'Out'
:
table
[
ids
]}
def
test_check_output
(
self
):
self
.
check_output
()
@
skip_check_grad_ci
(
reason
=
"Int16 type only be used in test and inference."
)
class
TestLookupTableOpWithTensorIdsInt16
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"lookup_table"
table
=
np
.
random
.
randint
(
low
=-
128
,
high
=
127
,
size
=
(
17
,
31
)).
astype
(
"int16"
)
ids
=
np
.
random
.
randint
(
low
=
0
,
high
=
17
,
size
=
(
2
,
4
,
5
,
1
)).
astype
(
"int64"
)
self
.
inputs
=
{
'W'
:
table
,
'Ids'
:
ids
}
self
.
outputs
=
{
'Out'
:
table
[
ids
.
flatten
()].
reshape
((
2
,
4
,
5
,
31
))}
def
test_check_output
(
self
):
self
.
check_output
()
@
skip_check_grad_ci
(
reason
=
"Int16 type only be used in test and inference."
)
class
TestLookupTableOpWithPaddingInt16
(
TestLookupTableOpInt16
):
def
test_check_output
(
self
):
ids
=
np
.
squeeze
(
self
.
inputs
[
'Ids'
])
padding_idx
=
np
.
random
.
choice
(
ids
,
1
)[
0
]
self
.
outputs
[
'Out'
][
ids
==
padding_idx
]
=
np
.
zeros
(
31
)
self
.
attrs
=
{
'padding_idx'
:
int
(
padding_idx
)}
self
.
check_output
()
@
skip_check_grad_ci
(
reason
=
"Int16 type only be used in test and inference."
)
class
TestLookupTableOpWithTensorIdsAndPaddingInt16
(
TestLookupTableOpWithTensorIdsInt16
):
def
test_check_output
(
self
):
ids
=
self
.
inputs
[
'Ids'
]
flatten_idx
=
ids
.
flatten
()
padding_idx
=
np
.
random
.
choice
(
flatten_idx
,
1
)[
0
]
self
.
outputs
[
'Out'
][
np
.
squeeze
(
ids
==
padding_idx
)]
=
np
.
zeros
(
31
)
self
.
attrs
=
{
'padding_idx'
:
cpt
.
long_type
(
padding_idx
)}
self
.
check_output
()
class
TestLookupTableWIsSelectedRowsInt16
(
unittest
.
TestCase
):
def
prepare_ids
(
self
,
scope
,
place
):
ids_tensor
=
scope
.
var
(
'Ids'
).
get_tensor
()
ids_array
=
np
.
array
([[
0
],
[
4
],
[
3
],
[
5
]]).
astype
(
"int64"
)
ids_tensor
.
set
(
ids_array
,
place
)
return
ids_array
def
prepare_w
(
self
,
scope
,
place
):
rows
=
[
0
,
1
,
2
,
3
,
4
,
5
,
6
]
row_numel
=
12
w_selected_rows
=
scope
.
var
(
'W'
).
get_selected_rows
()
w_selected_rows
.
set_height
(
len
(
rows
))
w_selected_rows
.
set_rows
(
rows
)
w_array
=
np
.
ones
((
len
(
rows
),
row_numel
)).
astype
(
"int16"
)
for
i
in
range
(
len
(
rows
)):
w_array
[
i
]
*=
i
w_tensor
=
w_selected_rows
.
get_tensor
()
w_tensor
.
set
(
w_array
,
place
)
def
create_out_tensor
(
self
,
scope
,
place
):
return
scope
.
var
(
'Out'
).
get_tensor
()
def
check_result
(
self
,
ids_array
,
result_array
):
for
idx
,
row
in
enumerate
(
ids_array
):
assert
(
row
[
0
]
==
result_array
[
idx
]).
all
()
def
check_with_place
(
self
,
place
):
scope
=
core
.
Scope
()
ids_array
=
self
.
prepare_ids
(
scope
,
place
)
self
.
prepare_w
(
scope
,
place
)
out_tensor
=
self
.
create_out_tensor
(
scope
,
place
)
# create and run lookup_table operator
lookup_table
=
Operator
(
"lookup_table"
,
W
=
'W'
,
Ids
=
'Ids'
,
Out
=
'Out'
)
lookup_table
.
run
(
scope
,
place
)
# get result from Out
result_array
=
np
.
array
(
out_tensor
)
self
.
check_result
(
ids_array
,
result_array
)
def
test_w_is_selected_rows
(
self
):
places
=
[
core
.
CPUPlace
()]
# currently only support CPU
for
place
in
places
:
self
.
check_with_place
(
place
)
class
TestLookupTableWithTensorIdsWIsSelectedRowsInt16
(
TestLookupTableWIsSelectedRowsInt16
):
def
prepare_ids
(
self
,
scope
,
place
):
ids_tensor
=
scope
.
var
(
'Ids'
).
get_tensor
()
ids_array
=
np
.
random
.
randint
(
low
=
0
,
high
=
6
,
size
=
(
2
,
4
,
3
,
1
)).
astype
(
"int64"
)
ids_tensor
.
set
(
ids_array
,
place
)
return
ids_array
def
check_result
(
self
,
ids_array
,
result_array
):
for
idx
,
row
in
np
.
ndenumerate
(
ids_array
):
assert
(
row
==
result_array
[
idx
]).
all
()
class
TestOutDtype
(
unittest
.
TestCase
):
def
test_dtype
(
self
):
api_fn
=
F
.
embedding
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录