Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
ae676a60
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ae676a60
编写于
1月 22, 2018
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Enhance lookup_table_op to support padding_idx
上级
9247aee7
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
119 addition
and
34 deletion
+119
-34
paddle/framework/attribute.cc
paddle/framework/attribute.cc
+3
-0
paddle/framework/attribute.h
paddle/framework/attribute.h
+26
-0
paddle/framework/framework.proto
paddle/framework/framework.proto
+2
-0
paddle/framework/op_desc.cc
paddle/framework/op_desc.cc
+1
-0
paddle/framework/type_defs.h
paddle/framework/type_defs.h
+1
-1
paddle/operators/lookup_table_op.cc
paddle/operators/lookup_table_op.cc
+5
-5
paddle/operators/lookup_table_op.cu
paddle/operators/lookup_table_op.cu
+26
-7
paddle/operators/lookup_table_op.h
paddle/operators/lookup_table_op.h
+15
-12
python/paddle/v2/fluid/framework.py
python/paddle/v2/fluid/framework.py
+1
-0
python/paddle/v2/fluid/layers/nn.py
python/paddle/v2/fluid/layers/nn.py
+25
-9
python/paddle/v2/fluid/tests/test_lookup_table_op.py
python/paddle/v2/fluid/tests/test_lookup_table_op.py
+14
-0
未找到文件。
paddle/framework/attribute.cc
浏览文件 @
ae676a60
...
...
@@ -61,6 +61,9 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
}
return
val
;
}
case
proto
::
AttrType
::
LONG
:
{
return
attr_desc
.
l
();
}
default:
PADDLE_THROW
(
"Unsupport attr type %d"
,
attr_desc
.
type
());
}
...
...
paddle/framework/attribute.h
浏览文件 @
ae676a60
...
...
@@ -168,6 +168,32 @@ struct ExtractAttribute<bool> {
const
std
::
string
&
attr_name_
;
};
template
<
>
struct
ExtractAttribute
<
int64_t
>
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
int64_t
*
operator
()(
Attribute
&
attr
)
const
{
if
(
attr
.
type
()
==
typeid
(
int
))
{
// NOLINT
int
val
=
boost
::
get
<
int
>
(
attr
);
attr
=
static_cast
<
int64_t
>
(
val
);
}
else
if
(
attr
.
type
()
==
typeid
(
float
))
{
// NOLINT
int
val
=
boost
::
get
<
float
>
(
attr
);
attr
=
static_cast
<
int64_t
>
(
val
);
}
int64_t
*
attr_value
=
nullptr
;
try
{
attr_value
=
&
boost
::
get
<
int64_t
>
(
attr
);
}
catch
(
boost
::
bad_get
&
bad_get
)
{
PADDLE_THROW
(
"Cannot get attribute %s by type int64_t, its type is %s"
,
attr_name_
,
attr
.
type
().
name
());
}
return
attr_value
;
}
const
std
::
string
&
attr_name_
;
};
// check whether a certain attribute fit its limits
// an attribute can have more than one limits
template
<
typename
T
>
...
...
paddle/framework/framework.proto
浏览文件 @
ae676a60
...
...
@@ -26,6 +26,7 @@ enum AttrType {
BOOLEAN
=
6
;
BOOLEANS
=
7
;
BLOCK
=
8
;
LONG
=
9
;
}
// OpDesc describes an instance of a C++ framework::OperatorBase
...
...
@@ -44,6 +45,7 @@ message OpDesc {
optional
bool
b
=
10
;
repeated
bool
bools
=
11
;
optional
int32
block_idx
=
12
;
optional
int64
l
=
13
;
};
message
Var
{
...
...
paddle/framework/op_desc.cc
浏览文件 @
ae676a60
...
...
@@ -282,6 +282,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
VectorToRepeated
(
v
,
attr_
->
mutable_bools
());
}
void
operator
()(
BlockDesc
*
desc
)
const
{
attr_
->
set_block_idx
(
desc
->
ID
());
}
void
operator
()(
int64_t
v
)
const
{
attr_
->
set_l
(
v
);
}
void
operator
()(
boost
::
blank
)
const
{
PADDLE_THROW
(
"Unexpected branch"
);
}
};
...
...
paddle/framework/type_defs.h
浏览文件 @
ae676a60
...
...
@@ -35,7 +35,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
using
Attribute
=
boost
::
variant
<
boost
::
blank
,
int
,
float
,
std
::
string
,
std
::
vector
<
int
>
,
std
::
vector
<
float
>
,
std
::
vector
<
std
::
string
>
,
bool
,
std
::
vector
<
bool
>
,
BlockDesc
*>
;
std
::
vector
<
bool
>
,
BlockDesc
*
,
int64_t
>
;
using
AttributeMap
=
std
::
unordered_map
<
std
::
string
,
Attribute
>
;
...
...
paddle/operators/lookup_table_op.cc
浏览文件 @
ae676a60
...
...
@@ -66,11 +66,11 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"(boolean, default false) "
"Sparse update"
)
.
SetDefault
(
false
);
AddAttr
<
int64_t
>
(
"padding_idx"
,
"(int64_t, default -1)
"
" If given, pads the output with zeros whenever it encounters
"
"the index
."
)
AddAttr
<
int64_t
>
(
"padding_idx"
,
"(int64, default -1) "
"If the value is -1, it makes no effect to lookup.
"
"Otherwise the given value indicates padding the output
"
"with zeros whenever lookup encounters it in Ids
."
)
.
SetDefault
(
-
1
);
AddComment
(
R"DOC(
Lookup Table Operator.
...
...
paddle/operators/lookup_table_op.cu
浏览文件 @
ae676a60
...
...
@@ -21,9 +21,11 @@ limitations under the License. */
namespace
paddle
{
namespace
operators
{
template
<
typename
T
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
>
template
<
typename
T
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
,
bool
PaddingFlag
>
__global__
void
LookupTable
(
T
*
output
,
const
T
*
table
,
const
int64_t
*
ids
,
const
int64_t
N
,
const
int64_t
K
,
const
int64_t
D
)
{
const
int64_t
N
,
const
int64_t
K
,
const
int64_t
D
,
const
int64_t
padding_idx
)
{
int
idx
=
threadIdx
.
x
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
GridDimX
;
...
...
@@ -34,7 +36,14 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
T
*
out
=
output
+
idy
*
D
;
const
T
*
tab
=
table
+
id
*
D
;
for
(
int
i
=
idx
;
i
<
D
;
i
+=
BlockDimX
)
{
out
[
i
]
=
tab
[
i
];
if
(
PaddingFlag
)
{
if
(
idx
==
padding_idx
)
out
[
i
]
=
static_cast
<
T
>
(
0
);
else
out
[
i
]
=
tab
[
i
];
}
else
{
out
[
i
]
=
tab
[
i
];
}
}
idy
+=
BlockDimY
*
GridDimX
;
}
...
...
@@ -67,6 +76,7 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
output_t
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
size_t
N
=
table_t
->
dims
()[
0
];
size_t
D
=
table_t
->
dims
()[
1
];
...
...
@@ -77,10 +87,17 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
dim3
threads
(
128
,
8
);
dim3
grids
(
8
,
1
);
LookupTable
<
T
,
128
,
8
,
8
><<<
grids
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
);
if
(
padding_idx
==
-
1
)
LookupTable
<
T
,
128
,
8
,
8
,
false
><<<
grids
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
,
padding_idx
);
else
LookupTable
<
T
,
128
,
8
,
8
,
true
><<<
grids
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
,
padding_idx
);
}
};
...
...
@@ -91,6 +108,8 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
auto
&
dev_ctx
=
context
.
template
device_context
<
platform
::
CUDADeviceContext
>();
bool
is_sparse
=
context
.
Attr
<
bool
>
(
"is_sparse"
);
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if
(
is_sparse
)
{
auto
*
ids
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
table
=
context
.
Input
<
LoDTensor
>
(
"W"
);
...
...
paddle/operators/lookup_table_op.h
浏览文件 @
ae676a60
...
...
@@ -39,14 +39,23 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto
*
ids
=
ids_t
->
data
<
int64_t
>
();
auto
*
table
=
table_t
->
data
<
T
>
();
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int64_t
i
=
0
;
i
<
ids_t
->
numel
();
++
i
)
{
if
(
ids
[
i
]
==
padding_idx
)
{
memset
(
output
+
i
*
D
,
0
,
D
*
sizeof
(
T
));
}
else
{
if
(
padding_idx
==
-
1
)
{
for
(
int64_t
i
=
0
;
i
<
ids_t
->
numel
();
++
i
)
{
PADDLE_ENFORCE_LT
(
ids
[
i
],
N
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
);
memcpy
(
output
+
i
*
D
,
table
+
ids
[
i
]
*
D
,
D
*
sizeof
(
T
));
}
}
else
{
for
(
int64_t
i
=
0
;
i
<
ids_t
->
numel
();
++
i
)
{
if
(
ids
[
i
]
==
padding_idx
)
{
memset
(
output
+
i
*
D
,
0
,
D
*
sizeof
(
T
));
}
else
{
PADDLE_ENFORCE_LT
(
ids
[
i
],
N
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
);
memcpy
(
output
+
i
*
D
,
table
+
ids
[
i
]
*
D
,
D
*
sizeof
(
T
));
}
}
}
}
};
...
...
@@ -56,8 +65,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
bool
is_sparse
=
context
.
Attr
<
bool
>
(
"is_sparse"
);
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if
(
is_sparse
)
{
auto
*
ids
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
table
=
context
.
Input
<
LoDTensor
>
(
"W"
);
...
...
@@ -70,9 +79,6 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
framework
::
Vector
<
int64_t
>
new_rows
;
new_rows
.
reserve
(
ids_dim
[
0
]);
for
(
int64_t
i
=
0
;
i
<
ids_dim
[
0
];
i
++
)
{
if
(
ids_data
[
i
]
==
padding_idx
)
continue
;
// Paddings are not trainable and the gradient are not
// necessary.
new_rows
.
push_back
(
ids_data
[
i
]);
}
d_table
->
set_rows
(
new_rows
);
...
...
@@ -106,9 +112,6 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
memset
(
d_table_data
,
0
,
d_table
->
numel
()
*
sizeof
(
T
));
for
(
int64_t
i
=
0
;
i
<
ids
->
numel
();
++
i
)
{
if
(
ids_data
[
i
]
==
padding_idx
)
continue
;
// Paddings are not trainable and the gradient are not
// necessary.
PADDLE_ENFORCE_LT
(
ids_data
[
i
],
N
);
PADDLE_ENFORCE_GE
(
ids_data
[
i
],
0
);
for
(
int
j
=
0
;
j
<
D
;
++
j
)
{
...
...
python/paddle/v2/fluid/framework.py
浏览文件 @
ae676a60
...
...
@@ -471,6 +471,7 @@ class Operator(object):
self
.
desc
.
set_serialized_attr
(
attr_name
,
attrs
[
attr_name
].
serialize_to_string
())
else
:
# print 'haha', attrs[attr_name], type(attrs[attr_name])
self
.
desc
.
set_attr
(
attr_name
,
attrs
[
attr_name
])
self
.
desc
.
check_attrs
()
...
...
python/paddle/v2/fluid/layers/nn.py
浏览文件 @
ae676a60
...
...
@@ -176,22 +176,35 @@ def fc(input,
return
helper
.
append_activation
(
pre_activation
)
def
embedding
(
input
,
size
,
is_sparse
=
False
,
param_attr
=
None
,
dtype
=
'float32'
):
def
embedding
(
input
,
size
,
is_sparse
=
False
,
padding_idx
=
None
,
param_attr
=
None
,
dtype
=
'float32'
):
"""
**Embedding Layer**
This layer is used to lookup a vector of IDs, provided by *input*, in a lookup table.
The result of this lookup is the embedding of each ID in the *input*.
This layer is used to lookup embeddings of IDs, provided by :attr:`input`, in
a lookup table. The result of this lookup is the embedding of each ID in the
:attr:`input`.
All the input variables are passed in as local variables to the LayerHelper
constructor.
Args:
input(Variable): Input to the function
size(tuple|list|None): Shape of the look up table parameter
is_sparse(bool): Boolean flag that specifying whether the input is sparse
param_attr(ParamAttr): Parameters for this layer
dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc
input(Variable): The tensor variable containing the IDs.
size(tuple|list): The shape of the look up table parameter. It should
have two elements which indicate the size of the dictionary of
embeddings and the size of each embedding vector respectively.
is_sparse(bool): The flag indicating whether to use sparse update.
padding_idx(int|long|None): If :attr:`None`, it makes no effect to lookup.
Otherwise the given :attr:`padding_idx` indicates padding the output
with zeros whenever lookup encounters it in :attr:`input`. If
:math:`padding_idx < 0`, the padding_idx to use in lookup is
:math:`size[0] + dim`.
param_attr(ParamAttr): Parameters for this layer
dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc
Returns:
Variable: The tensor variable storing the embeddings of the
\
...
...
@@ -209,12 +222,15 @@ def embedding(input, size, is_sparse=False, param_attr=None, dtype='float32'):
w
=
helper
.
create_parameter
(
attr
=
helper
.
param_attr
,
shape
=
size
,
dtype
=
dtype
,
is_bias
=
False
)
tmp
=
helper
.
create_tmp_variable
(
dtype
)
padding_idx
=
-
1
if
padding_idx
is
None
else
padding_idx
if
padding_idx
>=
0
else
(
size
[
0
]
+
padding_idx
)
helper
.
append_op
(
type
=
'lookup_table'
,
inputs
=
{
'Ids'
:
input
,
'W'
:
w
},
outputs
=
{
'Out'
:
tmp
},
attrs
=
{
'is_sparse'
:
is_sparse
})
attrs
=
{
'is_sparse'
:
is_sparse
,
'padding_idx'
:
padding_idx
})
return
tmp
...
...
python/paddle/v2/fluid/tests/test_lookup_table_op.py
浏览文件 @
ae676a60
...
...
@@ -32,5 +32,19 @@ class TestLookupTableOp(OpTest):
self
.
check_grad
([
'W'
],
'Out'
,
no_grad_set
=
set
(
'Ids'
))
class
TestLookupTableOpWithPadding
(
TestLookupTableOp
):
def
test_check_output
(
self
):
ids
=
np
.
squeeze
(
self
.
inputs
[
'Ids'
])
padding_idx
=
np
.
random
.
choice
(
ids
,
1
)[
0
]
self
.
outputs
[
'Out'
][
ids
==
padding_idx
]
=
np
.
zeros
(
31
)
self
.
attrs
=
{
'padding_idx'
:
long
(
padding_idx
)}
self
.
check_output
()
def
test_check_grad
(
self
):
# Since paddings are not trainable and fixed in forward, the gradient of
# paddings makes no sense and we don't test the gradient here.
pass
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录