Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
5ba9dfc1
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5ba9dfc1
编写于
3月 09, 2020
作者:
M
mapingshuo
提交者:
GitHub
3月 09, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add lookup_table_dequant_op (#22900)
add lookup_table_dequant_op
上级
a020a257
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
292 addition
and
0 deletion
+292
-0
paddle/fluid/operators/lookup_table_dequant_op.cc
paddle/fluid/operators/lookup_table_dequant_op.cc
+128
-0
paddle/fluid/operators/lookup_table_dequant_op.h
paddle/fluid/operators/lookup_table_dequant_op.h
+109
-0
python/paddle/fluid/tests/unittests/test_lookup_table_dequant_op.py
...dle/fluid/tests/unittests/test_lookup_table_dequant_op.py
+55
-0
未找到文件。
paddle/fluid/operators/lookup_table_dequant_op.cc
0 → 100644
浏览文件 @
5ba9dfc1
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/lookup_table_dequant_op.h"
#include <memory>
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
#include "paddle/fluid/framework/var_type_inference.h"
namespace
paddle
{
namespace
operators
{
class
LookupTableDequantOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"W"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Input(W) of LookupTableDequantOp should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Ids"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Ids) of LookupTableDequantOp should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"Out"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Output(Out) of LookupTableDequantOp should not be null."
));
auto
table_dims
=
ctx
->
GetInputDim
(
"W"
);
auto
ids_dims
=
ctx
->
GetInputDim
(
"Ids"
);
int
ids_rank
=
ids_dims
.
size
();
VLOG
(
5
)
<<
"ids rank is "
<<
ids_rank
<<
std
::
endl
;
PADDLE_ENFORCE_EQ
(
table_dims
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"ShapeError: The dimensions of the 'lookup table' must be 2. "
"But received lookup table's dimensions = %d, "
"lookup table's shape = [%s]."
,
table_dims
.
size
(),
table_dims
));
PADDLE_ENFORCE_EQ
(
ids_dims
[
ids_rank
-
1
],
1
,
platform
::
errors
::
InvalidArgument
(
"ShapeError: The last dimensions of the 'Ids' tensor must be 1. "
"But received Ids's last dimensions = %d, Ids's shape = [%s]."
,
ids_dims
[
ids_rank
-
1
],
ids_dims
));
auto
output_dims
=
framework
::
vectorize
(
framework
::
slice_ddim
(
ids_dims
,
0
,
ids_rank
-
1
));
PADDLE_ENFORCE_GE
(
table_dims
[
1
],
2
,
platform
::
errors
::
InvalidArgument
(
"the second dim of table_dims should be "
"greater or equal to 2, but the actual shape "
"is [%s]"
,
table_dims
));
output_dims
.
push_back
((
table_dims
[
1
]
-
2
)
*
4
);
ctx
->
SetOutputDim
(
"Out"
,
framework
::
make_ddim
(
output_dims
));
if
(
ctx
->
GetOutputsVarType
(
"Out"
)[
0
]
==
framework
::
proto
::
VarType
::
LOD_TENSOR
)
{
ctx
->
ShareLoD
(
"Ids"
,
/*->*/
"Out"
);
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
data_type
=
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"W"
);
return
framework
::
OpKernelType
(
data_type
,
ctx
.
device_context
());
}
};
class
LookupTableDequantOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"W"
,
"(Tensor) The input represents embedding tensors, "
"This tensor is a quantized tensor"
);
AddInput
(
"Ids"
,
"An input with type int64 "
"contains the ids to be looked up in W. "
"The last dimension size must be 1."
);
AddOutput
(
"Out"
,
"The lookup results, which have the same type as W."
);
AddAttr
<
int64_t
>
(
"padding_idx"
,
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids."
)
.
SetDefault
(
kNoPadding
);
AddComment
(
R"DOC(
Lookup Table Dequant Operator.
The `W` input is a quantized parameter for the sake of saving memories.
This operator first index embeddings with `Ids`,
then dequantizes them and contact them as output (`Out`).
The input Ids can carry the LoD (Level of Details) information,
or not. And the output only shares the LoD information with input Ids.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
lookup_table_dequant
,
ops
::
LookupTableDequantOp
,
ops
::
LookupTableDequantOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
paddle
::
framework
::
EmptyGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OP_CPU_KERNEL
(
lookup_table_dequant
,
ops
::
LookupTableDequantKernel
<
float
>
);
paddle/fluid/operators/lookup_table_dequant_op.h
0 → 100644
浏览文件 @
5ba9dfc1
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/var_type_traits.h"
#include "paddle/fluid/operators/math/blas.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#endif
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
using
SelectedRows
=
framework
::
SelectedRows
;
using
DDim
=
framework
::
DDim
;
template
<
typename
T
>
void
dequant
(
const
unsigned
char
*
in
,
T
*
out
,
float
min
,
float
max
,
int
emb_size
,
int
pow_2_bits
)
{
float
scale
=
(
max
-
min
)
/
pow_2_bits
;
for
(
int
i
=
0
;
i
<
emb_size
;
++
i
)
{
T
x
=
scale
*
static_cast
<
int
>
(
in
[
i
])
+
min
;
out
[
i
]
=
x
;
}
}
constexpr
int64_t
kNoPadding
=
-
1
;
template
<
typename
T
>
class
LookupTableDequantKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
// int tensor
auto
*
output_t
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
// float tensor
auto
*
table_var
=
context
.
InputVar
(
"W"
);
auto
id_name
=
context
.
InputNames
(
"Ids"
).
front
();
auto
embedding_name
=
context
.
InputNames
(
"W"
).
front
();
auto
out_name
=
context
.
OutputNames
(
"Out"
).
front
();
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
auto
*
ids
=
ids_t
->
data
<
int64_t
>
();
int64_t
ids_numel
=
ids_t
->
numel
();
PADDLE_ENFORCE_GE
(
table_var
->
Type
(),
framework
::
VarTypeTrait
<
LoDTensor
>::
kId
,
platform
::
errors
::
InvalidArgument
(
"lookup table must be LodTensor"
));
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
int64_t
row_number
=
table_t
->
dims
()[
0
];
int64_t
quant_number
=
table_t
->
dims
()[
1
];
int64_t
row_width
=
(
quant_number
-
2
)
*
4
;
auto
*
table
=
table_t
->
data
<
float
>
();
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
pow_2_bits
=
static_cast
<
int
>
(
pow
(
2
,
8
));
for
(
int64_t
i
=
0
;
i
<
ids_numel
;
++
i
)
{
if
(
padding_idx
!=
kNoPadding
&&
ids
[
i
]
==
padding_idx
)
{
memset
(
output
+
i
*
row_width
,
0
,
row_width
*
sizeof
(
T
));
}
else
{
PADDLE_ENFORCE_LT
(
ids
[
i
],
row_number
,
platform
::
errors
::
InvalidArgument
(
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value."
,
row_number
,
ids
[
i
]));
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
,
platform
::
errors
::
InvalidArgument
(
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value."
,
row_number
,
ids
[
i
]));
float
min
=
*
(
table
+
ids
[
i
]
*
quant_number
);
float
max
=
*
(
table
+
ids
[
i
]
*
quant_number
+
1
);
int
offset
=
ids
[
i
]
*
quant_number
+
2
;
const
unsigned
char
*
tensor_buf
=
reinterpret_cast
<
const
unsigned
char
*>
(
table
+
offset
);
dequant
(
tensor_buf
,
output
+
i
*
row_width
,
min
,
max
,
row_width
,
pow_2_bits
);
}
}
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/tests/unittests/test_lookup_table_dequant_op.py
0 → 100644
浏览文件 @
5ba9dfc1
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
,
skip_check_grad_ci
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
import
paddle.compat
as
cpt
import
paddle.fluid
as
fluid
from
paddle.fluid
import
Program
,
program_guard
import
struct
class
TestLookupTableDequantOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"lookup_table_dequant"
table
=
np
.
random
.
random
((
17
,
32
)).
astype
(
"float32"
)
ids
=
np
.
random
.
randint
(
0
,
17
,
4
).
astype
(
"int64"
)
ids_expand
=
np
.
expand_dims
(
ids
,
axis
=
1
)
self
.
inputs
=
{
'W'
:
table
,
'Ids'
:
ids_expand
}
# calculate output
output
=
[]
for
id
in
ids
:
tmp
=
[]
min
,
max
=
table
[
id
][
0
],
table
[
id
][
1
]
for
val
in
table
[
id
][
2
:]:
tmp
+=
[
int
(
x
)
*
(
max
-
min
)
/
pow
(
2
,
8
)
+
min
for
x
in
bytearray
(
struct
.
pack
(
"f"
,
val
))
]
output
.
append
(
tmp
)
self
.
outputs
=
{
'Out'
:
np
.
asarray
(
output
,
dtype
=
"float32"
)}
def
test_check_output
(
self
):
self
.
check_output
()
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录