Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
9ae1523e
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9ae1523e
编写于
1月 23, 2018
作者:
C
Cao Ying
提交者:
GitHub
1月 23, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #7719 from guoshengCS/enhance-lookup_table_op-padidx
Enhance lookup_table_op to support padding_idx.
上级
76429f4d
d5120442
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
126 addition
and
21 deletion
+126
-21
paddle/framework/attribute.cc
paddle/framework/attribute.cc
+3
-0
paddle/framework/attribute.h
paddle/framework/attribute.h
+26
-0
paddle/framework/framework.proto
paddle/framework/framework.proto
+2
-0
paddle/framework/op_desc.cc
paddle/framework/op_desc.cc
+1
-0
paddle/framework/type_defs.h
paddle/framework/type_defs.h
+1
-1
paddle/operators/lookup_table_op.cc
paddle/operators/lookup_table_op.cc
+6
-0
paddle/operators/lookup_table_op.cu
paddle/operators/lookup_table_op.cu
+26
-7
paddle/operators/lookup_table_op.h
paddle/operators/lookup_table_op.h
+20
-4
paddle/pybind/print_operators_doc.cc
paddle/pybind/print_operators_doc.cc
+2
-0
python/paddle/v2/fluid/layers/nn.py
python/paddle/v2/fluid/layers/nn.py
+25
-9
python/paddle/v2/fluid/tests/test_lookup_table_op.py
python/paddle/v2/fluid/tests/test_lookup_table_op.py
+14
-0
未找到文件。
paddle/framework/attribute.cc
浏览文件 @
9ae1523e
...
@@ -61,6 +61,9 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
...
@@ -61,6 +61,9 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
}
}
return
val
;
return
val
;
}
}
case
proto
::
AttrType
::
LONG
:
{
return
attr_desc
.
l
();
}
default:
default:
PADDLE_THROW
(
"Unsupport attr type %d"
,
attr_desc
.
type
());
PADDLE_THROW
(
"Unsupport attr type %d"
,
attr_desc
.
type
());
}
}
...
...
paddle/framework/attribute.h
浏览文件 @
9ae1523e
...
@@ -168,6 +168,32 @@ struct ExtractAttribute<bool> {
...
@@ -168,6 +168,32 @@ struct ExtractAttribute<bool> {
const
std
::
string
&
attr_name_
;
const
std
::
string
&
attr_name_
;
};
};
template
<
>
struct
ExtractAttribute
<
int64_t
>
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
int64_t
*
operator
()(
Attribute
&
attr
)
const
{
if
(
attr
.
type
()
==
typeid
(
int
))
{
// NOLINT
int
val
=
boost
::
get
<
int
>
(
attr
);
attr
=
static_cast
<
int64_t
>
(
val
);
}
else
if
(
attr
.
type
()
==
typeid
(
float
))
{
// NOLINT
int
val
=
boost
::
get
<
float
>
(
attr
);
attr
=
static_cast
<
int64_t
>
(
val
);
}
int64_t
*
attr_value
=
nullptr
;
try
{
attr_value
=
&
boost
::
get
<
int64_t
>
(
attr
);
}
catch
(
boost
::
bad_get
&
bad_get
)
{
PADDLE_THROW
(
"Cannot get attribute %s by type int64_t, its type is %s"
,
attr_name_
,
attr
.
type
().
name
());
}
return
attr_value
;
}
const
std
::
string
&
attr_name_
;
};
// check whether a certain attribute fit its limits
// check whether a certain attribute fit its limits
// an attribute can have more than one limits
// an attribute can have more than one limits
template
<
typename
T
>
template
<
typename
T
>
...
...
paddle/framework/framework.proto
浏览文件 @
9ae1523e
...
@@ -26,6 +26,7 @@ enum AttrType {
...
@@ -26,6 +26,7 @@ enum AttrType {
BOOLEAN
=
6
;
BOOLEAN
=
6
;
BOOLEANS
=
7
;
BOOLEANS
=
7
;
BLOCK
=
8
;
BLOCK
=
8
;
LONG
=
9
;
}
}
// OpDesc describes an instance of a C++ framework::OperatorBase
// OpDesc describes an instance of a C++ framework::OperatorBase
...
@@ -44,6 +45,7 @@ message OpDesc {
...
@@ -44,6 +45,7 @@ message OpDesc {
optional
bool
b
=
10
;
optional
bool
b
=
10
;
repeated
bool
bools
=
11
;
repeated
bool
bools
=
11
;
optional
int32
block_idx
=
12
;
optional
int32
block_idx
=
12
;
optional
int64
l
=
13
;
};
};
message
Var
{
message
Var
{
...
...
paddle/framework/op_desc.cc
浏览文件 @
9ae1523e
...
@@ -283,6 +283,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
...
@@ -283,6 +283,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
VectorToRepeated
(
v
,
attr_
->
mutable_bools
());
VectorToRepeated
(
v
,
attr_
->
mutable_bools
());
}
}
void
operator
()(
BlockDesc
*
desc
)
const
{
attr_
->
set_block_idx
(
desc
->
ID
());
}
void
operator
()(
BlockDesc
*
desc
)
const
{
attr_
->
set_block_idx
(
desc
->
ID
());
}
void
operator
()(
int64_t
v
)
const
{
attr_
->
set_l
(
v
);
}
void
operator
()(
boost
::
blank
)
const
{
PADDLE_THROW
(
"Unexpected branch"
);
}
void
operator
()(
boost
::
blank
)
const
{
PADDLE_THROW
(
"Unexpected branch"
);
}
};
};
...
...
paddle/framework/type_defs.h
浏览文件 @
9ae1523e
...
@@ -35,7 +35,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
...
@@ -35,7 +35,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
using
Attribute
=
using
Attribute
=
boost
::
variant
<
boost
::
blank
,
int
,
float
,
std
::
string
,
std
::
vector
<
int
>
,
boost
::
variant
<
boost
::
blank
,
int
,
float
,
std
::
string
,
std
::
vector
<
int
>
,
std
::
vector
<
float
>
,
std
::
vector
<
std
::
string
>
,
bool
,
std
::
vector
<
float
>
,
std
::
vector
<
std
::
string
>
,
bool
,
std
::
vector
<
bool
>
,
BlockDesc
*>
;
std
::
vector
<
bool
>
,
BlockDesc
*
,
int64_t
>
;
using
AttributeMap
=
std
::
unordered_map
<
std
::
string
,
Attribute
>
;
using
AttributeMap
=
std
::
unordered_map
<
std
::
string
,
Attribute
>
;
...
...
paddle/operators/lookup_table_op.cc
浏览文件 @
9ae1523e
...
@@ -66,6 +66,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -66,6 +66,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"(boolean, default false) "
"(boolean, default false) "
"Sparse update"
)
"Sparse update"
)
.
SetDefault
(
false
);
.
SetDefault
(
false
);
AddAttr
<
int64_t
>
(
"padding_idx"
,
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup. "
"Otherwise the given value indicates padding the output "
"with zeros whenever lookup encounters it in Ids."
)
.
SetDefault
(
-
1
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Lookup Table Operator.
Lookup Table Operator.
...
...
paddle/operators/lookup_table_op.cu
浏览文件 @
9ae1523e
...
@@ -21,9 +21,11 @@ limitations under the License. */
...
@@ -21,9 +21,11 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
T
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
>
template
<
typename
T
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
,
bool
PaddingFlag
>
__global__
void
LookupTable
(
T
*
output
,
const
T
*
table
,
const
int64_t
*
ids
,
__global__
void
LookupTable
(
T
*
output
,
const
T
*
table
,
const
int64_t
*
ids
,
const
int64_t
N
,
const
int64_t
K
,
const
int64_t
D
)
{
const
int64_t
N
,
const
int64_t
K
,
const
int64_t
D
,
const
int64_t
padding_idx
)
{
int
idx
=
threadIdx
.
x
;
int
idx
=
threadIdx
.
x
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
GridDimX
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
GridDimX
;
...
@@ -34,8 +36,15 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
...
@@ -34,8 +36,15 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
T
*
out
=
output
+
idy
*
D
;
T
*
out
=
output
+
idy
*
D
;
const
T
*
tab
=
table
+
id
*
D
;
const
T
*
tab
=
table
+
id
*
D
;
for
(
int
i
=
idx
;
i
<
D
;
i
+=
BlockDimX
)
{
for
(
int
i
=
idx
;
i
<
D
;
i
+=
BlockDimX
)
{
if
(
PaddingFlag
)
{
if
(
id
==
padding_idx
)
out
[
i
]
=
static_cast
<
T
>
(
0
);
else
out
[
i
]
=
tab
[
i
];
}
else
{
out
[
i
]
=
tab
[
i
];
out
[
i
]
=
tab
[
i
];
}
}
}
idy
+=
BlockDimY
*
GridDimX
;
idy
+=
BlockDimY
*
GridDimX
;
}
}
}
}
...
@@ -67,6 +76,7 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
...
@@ -67,6 +76,7 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
output_t
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
auto
*
output_t
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
size_t
N
=
table_t
->
dims
()[
0
];
size_t
N
=
table_t
->
dims
()[
0
];
size_t
D
=
table_t
->
dims
()[
1
];
size_t
D
=
table_t
->
dims
()[
1
];
...
@@ -77,10 +87,17 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
...
@@ -77,10 +87,17 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
dim3
threads
(
128
,
8
);
dim3
threads
(
128
,
8
);
dim3
grids
(
8
,
1
);
dim3
grids
(
8
,
1
);
if
(
padding_idx
==
-
1
)
LookupTable
<
T
,
128
,
8
,
8
,
false
><<<
grids
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
,
padding_idx
);
else
LookupTable
<
LookupTable
<
T
,
12
8
,
8
,
T
,
128
,
8
,
8
,
8
><<<
grids
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
true
><<<
grids
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
);
output
,
table
,
ids
,
N
,
K
,
D
,
padding_idx
);
}
}
};
};
...
@@ -91,6 +108,8 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
...
@@ -91,6 +108,8 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
auto
&
dev_ctx
=
auto
&
dev_ctx
=
context
.
template
device_context
<
platform
::
CUDADeviceContext
>();
context
.
template
device_context
<
platform
::
CUDADeviceContext
>();
bool
is_sparse
=
context
.
Attr
<
bool
>
(
"is_sparse"
);
bool
is_sparse
=
context
.
Attr
<
bool
>
(
"is_sparse"
);
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if
(
is_sparse
)
{
if
(
is_sparse
)
{
auto
*
ids
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
ids
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
table
=
context
.
Input
<
LoDTensor
>
(
"W"
);
auto
*
table
=
context
.
Input
<
LoDTensor
>
(
"W"
);
...
...
paddle/operators/lookup_table_op.h
浏览文件 @
9ae1523e
...
@@ -32,18 +32,32 @@ class LookupTableKernel : public framework::OpKernel<T> {
...
@@ -32,18 +32,32 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
// float tensor
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
// float tensor
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
// int tensor
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
// int tensor
auto
*
output_t
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
// float tensor
auto
*
output_t
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
// float tensor
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
int
N
=
table_t
->
dims
()[
0
];
int
N
=
table_t
->
dims
()[
0
];
int
D
=
table_t
->
dims
()[
1
];
int
D
=
table_t
->
dims
()[
1
];
auto
*
ids
=
ids_t
->
data
<
int64_t
>
();
auto
*
ids
=
ids_t
->
data
<
int64_t
>
();
auto
*
table
=
table_t
->
data
<
T
>
();
auto
*
table
=
table_t
->
data
<
T
>
();
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
if
(
padding_idx
==
-
1
)
{
for
(
int64_t
i
=
0
;
i
<
ids_t
->
numel
();
++
i
)
{
PADDLE_ENFORCE_LT
(
ids
[
i
],
N
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
);
memcpy
(
output
+
i
*
D
,
table
+
ids
[
i
]
*
D
,
D
*
sizeof
(
T
));
}
}
else
{
for
(
int64_t
i
=
0
;
i
<
ids_t
->
numel
();
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
ids_t
->
numel
();
++
i
)
{
if
(
ids
[
i
]
==
padding_idx
)
{
memset
(
output
+
i
*
D
,
0
,
D
*
sizeof
(
T
));
}
else
{
PADDLE_ENFORCE_LT
(
ids
[
i
],
N
);
PADDLE_ENFORCE_LT
(
ids
[
i
],
N
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
);
memcpy
(
output
+
i
*
D
,
table
+
ids
[
i
]
*
D
,
D
*
sizeof
(
T
));
memcpy
(
output
+
i
*
D
,
table
+
ids
[
i
]
*
D
,
D
*
sizeof
(
T
));
}
}
}
}
}
}
};
};
template
<
typename
T
>
template
<
typename
T
>
...
@@ -51,6 +65,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
...
@@ -51,6 +65,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
bool
is_sparse
=
context
.
Attr
<
bool
>
(
"is_sparse"
);
bool
is_sparse
=
context
.
Attr
<
bool
>
(
"is_sparse"
);
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if
(
is_sparse
)
{
if
(
is_sparse
)
{
auto
*
ids
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
ids
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
table
=
context
.
Input
<
LoDTensor
>
(
"W"
);
auto
*
table
=
context
.
Input
<
LoDTensor
>
(
"W"
);
...
...
paddle/pybind/print_operators_doc.cc
浏览文件 @
9ae1523e
...
@@ -64,6 +64,8 @@ std::string AttrType(paddle::framework::proto::AttrType at) {
...
@@ -64,6 +64,8 @@ std::string AttrType(paddle::framework::proto::AttrType at) {
return
"bool array"
;
return
"bool array"
;
case
paddle
::
framework
::
proto
::
BLOCK
:
case
paddle
::
framework
::
proto
::
BLOCK
:
return
"block id"
;
return
"block id"
;
case
paddle
::
framework
::
proto
::
LONG
:
return
"long"
;
}
}
return
"UNKNOWN"
;
// not possible
return
"UNKNOWN"
;
// not possible
}
}
...
...
python/paddle/v2/fluid/layers/nn.py
浏览文件 @
9ae1523e
...
@@ -185,20 +185,33 @@ def fc(input,
...
@@ -185,20 +185,33 @@ def fc(input,
return
helper
.
append_activation
(
pre_activation
)
return
helper
.
append_activation
(
pre_activation
)
def
embedding
(
input
,
size
,
is_sparse
=
False
,
param_attr
=
None
,
dtype
=
'float32'
):
def
embedding
(
input
,
size
,
is_sparse
=
False
,
padding_idx
=
None
,
param_attr
=
None
,
dtype
=
'float32'
):
"""
"""
**Embedding Layer**
**Embedding Layer**
This layer is used to lookup a vector of IDs, provided by *input*, in a lookup table.
This layer is used to lookup embeddings of IDs, provided by :attr:`input`, in
The result of this lookup is the embedding of each ID in the *input*.
a lookup table. The result of this lookup is the embedding of each ID in the
:attr:`input`.
All the input variables are passed in as local variables to the LayerHelper
All the input variables are passed in as local variables to the LayerHelper
constructor.
constructor.
Args:
Args:
input(Variable): Input to the function
input(Variable): The tensor variable containing the IDs.
size(tuple|list|None): Shape of the look up table parameter
size(tuple|list): The shape of the look up table parameter. It should
is_sparse(bool): Boolean flag that specifying whether the input is sparse
have two elements which indicate the size of the dictionary of
embeddings and the size of each embedding vector respectively.
is_sparse(bool): The flag indicating whether to use sparse update.
padding_idx(int|long|None): If :attr:`None`, it makes no effect to lookup.
Otherwise the given :attr:`padding_idx` indicates padding the output
with zeros whenever lookup encounters it in :attr:`input`. If
:math:`padding_idx < 0`, the padding_idx to use in lookup is
:math:`size[0] + dim`.
param_attr(ParamAttr): Parameters for this layer
param_attr(ParamAttr): Parameters for this layer
dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc
dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc
...
@@ -218,12 +231,15 @@ def embedding(input, size, is_sparse=False, param_attr=None, dtype='float32'):
...
@@ -218,12 +231,15 @@ def embedding(input, size, is_sparse=False, param_attr=None, dtype='float32'):
w
=
helper
.
create_parameter
(
w
=
helper
.
create_parameter
(
attr
=
helper
.
param_attr
,
shape
=
size
,
dtype
=
dtype
,
is_bias
=
False
)
attr
=
helper
.
param_attr
,
shape
=
size
,
dtype
=
dtype
,
is_bias
=
False
)
tmp
=
helper
.
create_tmp_variable
(
dtype
)
tmp
=
helper
.
create_tmp_variable
(
dtype
)
padding_idx
=
-
1
if
padding_idx
is
None
else
padding_idx
if
padding_idx
>=
0
else
(
size
[
0
]
+
padding_idx
)
helper
.
append_op
(
helper
.
append_op
(
type
=
'lookup_table'
,
type
=
'lookup_table'
,
inputs
=
{
'Ids'
:
input
,
inputs
=
{
'Ids'
:
input
,
'W'
:
w
},
'W'
:
w
},
outputs
=
{
'Out'
:
tmp
},
outputs
=
{
'Out'
:
tmp
},
attrs
=
{
'is_sparse'
:
is_sparse
})
attrs
=
{
'is_sparse'
:
is_sparse
,
'padding_idx'
:
padding_idx
})
return
tmp
return
tmp
...
...
python/paddle/v2/fluid/tests/test_lookup_table_op.py
浏览文件 @
9ae1523e
...
@@ -33,5 +33,19 @@ class TestLookupTableOp(OpTest):
...
@@ -33,5 +33,19 @@ class TestLookupTableOp(OpTest):
self
.
check_grad
([
'W'
],
'Out'
,
no_grad_set
=
set
(
'Ids'
))
self
.
check_grad
([
'W'
],
'Out'
,
no_grad_set
=
set
(
'Ids'
))
class
TestLookupTableOpWithPadding
(
TestLookupTableOp
):
def
test_check_output
(
self
):
ids
=
np
.
squeeze
(
self
.
inputs
[
'Ids'
])
padding_idx
=
np
.
random
.
choice
(
ids
,
1
)[
0
]
self
.
outputs
[
'Out'
][
ids
==
padding_idx
]
=
np
.
zeros
(
31
)
self
.
attrs
=
{
'padding_idx'
:
long
(
padding_idx
)}
self
.
check_output
()
def
test_check_grad
(
self
):
# Since paddings are not trainable and fixed in forward, the gradient of
# paddings makes no sense and we don't test the gradient here.
pass
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录