Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
ae676a60
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ae676a60
编写于
1月 22, 2018
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Enhance lookup_table_op to support padding_idx
上级
9247aee7
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
119 addition
and
34 deletion
+119
-34
paddle/framework/attribute.cc
paddle/framework/attribute.cc
+3
-0
paddle/framework/attribute.h
paddle/framework/attribute.h
+26
-0
paddle/framework/framework.proto
paddle/framework/framework.proto
+2
-0
paddle/framework/op_desc.cc
paddle/framework/op_desc.cc
+1
-0
paddle/framework/type_defs.h
paddle/framework/type_defs.h
+1
-1
paddle/operators/lookup_table_op.cc
paddle/operators/lookup_table_op.cc
+5
-5
paddle/operators/lookup_table_op.cu
paddle/operators/lookup_table_op.cu
+26
-7
paddle/operators/lookup_table_op.h
paddle/operators/lookup_table_op.h
+15
-12
python/paddle/v2/fluid/framework.py
python/paddle/v2/fluid/framework.py
+1
-0
python/paddle/v2/fluid/layers/nn.py
python/paddle/v2/fluid/layers/nn.py
+25
-9
python/paddle/v2/fluid/tests/test_lookup_table_op.py
python/paddle/v2/fluid/tests/test_lookup_table_op.py
+14
-0
未找到文件。
paddle/framework/attribute.cc
浏览文件 @
ae676a60
...
@@ -61,6 +61,9 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
...
@@ -61,6 +61,9 @@ Attribute GetAttrValue(const proto::OpDesc::Attr& attr_desc) {
}
}
return
val
;
return
val
;
}
}
case
proto
::
AttrType
::
LONG
:
{
return
attr_desc
.
l
();
}
default:
default:
PADDLE_THROW
(
"Unsupport attr type %d"
,
attr_desc
.
type
());
PADDLE_THROW
(
"Unsupport attr type %d"
,
attr_desc
.
type
());
}
}
...
...
paddle/framework/attribute.h
浏览文件 @
ae676a60
...
@@ -168,6 +168,32 @@ struct ExtractAttribute<bool> {
...
@@ -168,6 +168,32 @@ struct ExtractAttribute<bool> {
const
std
::
string
&
attr_name_
;
const
std
::
string
&
attr_name_
;
};
};
template
<
>
struct
ExtractAttribute
<
int64_t
>
{
explicit
ExtractAttribute
(
const
std
::
string
&
attr_name
)
:
attr_name_
(
attr_name
)
{}
int64_t
*
operator
()(
Attribute
&
attr
)
const
{
if
(
attr
.
type
()
==
typeid
(
int
))
{
// NOLINT
int
val
=
boost
::
get
<
int
>
(
attr
);
attr
=
static_cast
<
int64_t
>
(
val
);
}
else
if
(
attr
.
type
()
==
typeid
(
float
))
{
// NOLINT
int
val
=
boost
::
get
<
float
>
(
attr
);
attr
=
static_cast
<
int64_t
>
(
val
);
}
int64_t
*
attr_value
=
nullptr
;
try
{
attr_value
=
&
boost
::
get
<
int64_t
>
(
attr
);
}
catch
(
boost
::
bad_get
&
bad_get
)
{
PADDLE_THROW
(
"Cannot get attribute %s by type int64_t, its type is %s"
,
attr_name_
,
attr
.
type
().
name
());
}
return
attr_value
;
}
const
std
::
string
&
attr_name_
;
};
// check whether a certain attribute fit its limits
// check whether a certain attribute fit its limits
// an attribute can have more than one limits
// an attribute can have more than one limits
template
<
typename
T
>
template
<
typename
T
>
...
...
paddle/framework/framework.proto
浏览文件 @
ae676a60
...
@@ -26,6 +26,7 @@ enum AttrType {
...
@@ -26,6 +26,7 @@ enum AttrType {
BOOLEAN
=
6
;
BOOLEAN
=
6
;
BOOLEANS
=
7
;
BOOLEANS
=
7
;
BLOCK
=
8
;
BLOCK
=
8
;
LONG
=
9
;
}
}
// OpDesc describes an instance of a C++ framework::OperatorBase
// OpDesc describes an instance of a C++ framework::OperatorBase
...
@@ -44,6 +45,7 @@ message OpDesc {
...
@@ -44,6 +45,7 @@ message OpDesc {
optional
bool
b
=
10
;
optional
bool
b
=
10
;
repeated
bool
bools
=
11
;
repeated
bool
bools
=
11
;
optional
int32
block_idx
=
12
;
optional
int32
block_idx
=
12
;
optional
int64
l
=
13
;
};
};
message
Var
{
message
Var
{
...
...
paddle/framework/op_desc.cc
浏览文件 @
ae676a60
...
@@ -282,6 +282,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
...
@@ -282,6 +282,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
VectorToRepeated
(
v
,
attr_
->
mutable_bools
());
VectorToRepeated
(
v
,
attr_
->
mutable_bools
());
}
}
void
operator
()(
BlockDesc
*
desc
)
const
{
attr_
->
set_block_idx
(
desc
->
ID
());
}
void
operator
()(
BlockDesc
*
desc
)
const
{
attr_
->
set_block_idx
(
desc
->
ID
());
}
void
operator
()(
int64_t
v
)
const
{
attr_
->
set_l
(
v
);
}
void
operator
()(
boost
::
blank
)
const
{
PADDLE_THROW
(
"Unexpected branch"
);
}
void
operator
()(
boost
::
blank
)
const
{
PADDLE_THROW
(
"Unexpected branch"
);
}
};
};
...
...
paddle/framework/type_defs.h
浏览文件 @
ae676a60
...
@@ -35,7 +35,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
...
@@ -35,7 +35,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
using
Attribute
=
using
Attribute
=
boost
::
variant
<
boost
::
blank
,
int
,
float
,
std
::
string
,
std
::
vector
<
int
>
,
boost
::
variant
<
boost
::
blank
,
int
,
float
,
std
::
string
,
std
::
vector
<
int
>
,
std
::
vector
<
float
>
,
std
::
vector
<
std
::
string
>
,
bool
,
std
::
vector
<
float
>
,
std
::
vector
<
std
::
string
>
,
bool
,
std
::
vector
<
bool
>
,
BlockDesc
*>
;
std
::
vector
<
bool
>
,
BlockDesc
*
,
int64_t
>
;
using
AttributeMap
=
std
::
unordered_map
<
std
::
string
,
Attribute
>
;
using
AttributeMap
=
std
::
unordered_map
<
std
::
string
,
Attribute
>
;
...
...
paddle/operators/lookup_table_op.cc
浏览文件 @
ae676a60
...
@@ -66,11 +66,11 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -66,11 +66,11 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
"(boolean, default false) "
"(boolean, default false) "
"Sparse update"
)
"Sparse update"
)
.
SetDefault
(
false
);
.
SetDefault
(
false
);
AddAttr
<
int64_t
>
(
AddAttr
<
int64_t
>
(
"padding_idx"
,
"padding_idx"
,
"(int64, default -1) "
"(int64_t, default -1)
"
"If the value is -1, it makes no effect to lookup.
"
" If given, pads the output with zeros whenever it encounters
"
"Otherwise the given value indicates padding the output
"
"the index
."
)
"with zeros whenever lookup encounters it in Ids
."
)
.
SetDefault
(
-
1
);
.
SetDefault
(
-
1
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Lookup Table Operator.
Lookup Table Operator.
...
...
paddle/operators/lookup_table_op.cu
浏览文件 @
ae676a60
...
@@ -21,9 +21,11 @@ limitations under the License. */
...
@@ -21,9 +21,11 @@ limitations under the License. */
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
T
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
>
template
<
typename
T
,
int
BlockDimX
,
int
BlockDimY
,
int
GridDimX
,
bool
PaddingFlag
>
__global__
void
LookupTable
(
T
*
output
,
const
T
*
table
,
const
int64_t
*
ids
,
__global__
void
LookupTable
(
T
*
output
,
const
T
*
table
,
const
int64_t
*
ids
,
const
int64_t
N
,
const
int64_t
K
,
const
int64_t
D
)
{
const
int64_t
N
,
const
int64_t
K
,
const
int64_t
D
,
const
int64_t
padding_idx
)
{
int
idx
=
threadIdx
.
x
;
int
idx
=
threadIdx
.
x
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
GridDimX
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
GridDimX
;
...
@@ -34,7 +36,14 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
...
@@ -34,7 +36,14 @@ __global__ void LookupTable(T* output, const T* table, const int64_t* ids,
T
*
out
=
output
+
idy
*
D
;
T
*
out
=
output
+
idy
*
D
;
const
T
*
tab
=
table
+
id
*
D
;
const
T
*
tab
=
table
+
id
*
D
;
for
(
int
i
=
idx
;
i
<
D
;
i
+=
BlockDimX
)
{
for
(
int
i
=
idx
;
i
<
D
;
i
+=
BlockDimX
)
{
out
[
i
]
=
tab
[
i
];
if
(
PaddingFlag
)
{
if
(
idx
==
padding_idx
)
out
[
i
]
=
static_cast
<
T
>
(
0
);
else
out
[
i
]
=
tab
[
i
];
}
else
{
out
[
i
]
=
tab
[
i
];
}
}
}
idy
+=
BlockDimY
*
GridDimX
;
idy
+=
BlockDimY
*
GridDimX
;
}
}
...
@@ -67,6 +76,7 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
...
@@ -67,6 +76,7 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
auto
*
table_t
=
context
.
Input
<
LoDTensor
>
(
"W"
);
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
ids_t
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
output_t
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
auto
*
output_t
=
context
.
Output
<
LoDTensor
>
(
"Out"
);
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
size_t
N
=
table_t
->
dims
()[
0
];
size_t
N
=
table_t
->
dims
()[
0
];
size_t
D
=
table_t
->
dims
()[
1
];
size_t
D
=
table_t
->
dims
()[
1
];
...
@@ -77,10 +87,17 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
...
@@ -77,10 +87,17 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
dim3
threads
(
128
,
8
);
dim3
threads
(
128
,
8
);
dim3
grids
(
8
,
1
);
dim3
grids
(
8
,
1
);
LookupTable
<
T
,
128
,
8
,
if
(
padding_idx
==
-
1
)
8
><<<
grids
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
LookupTable
<
output
,
table
,
ids
,
N
,
K
,
D
);
T
,
128
,
8
,
8
,
false
><<<
grids
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
,
padding_idx
);
else
LookupTable
<
T
,
128
,
8
,
8
,
true
><<<
grids
,
threads
,
0
,
context
.
cuda_device_context
().
stream
()
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
,
padding_idx
);
}
}
};
};
...
@@ -91,6 +108,8 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
...
@@ -91,6 +108,8 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
auto
&
dev_ctx
=
auto
&
dev_ctx
=
context
.
template
device_context
<
platform
::
CUDADeviceContext
>();
context
.
template
device_context
<
platform
::
CUDADeviceContext
>();
bool
is_sparse
=
context
.
Attr
<
bool
>
(
"is_sparse"
);
bool
is_sparse
=
context
.
Attr
<
bool
>
(
"is_sparse"
);
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if
(
is_sparse
)
{
if
(
is_sparse
)
{
auto
*
ids
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
ids
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
table
=
context
.
Input
<
LoDTensor
>
(
"W"
);
auto
*
table
=
context
.
Input
<
LoDTensor
>
(
"W"
);
...
...
paddle/operators/lookup_table_op.h
浏览文件 @
ae676a60
...
@@ -39,14 +39,23 @@ class LookupTableKernel : public framework::OpKernel<T> {
...
@@ -39,14 +39,23 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto
*
ids
=
ids_t
->
data
<
int64_t
>
();
auto
*
ids
=
ids_t
->
data
<
int64_t
>
();
auto
*
table
=
table_t
->
data
<
T
>
();
auto
*
table
=
table_t
->
data
<
T
>
();
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
int64_t
i
=
0
;
i
<
ids_t
->
numel
();
++
i
)
{
if
(
ids
[
i
]
==
padding_idx
)
{
if
(
padding_idx
==
-
1
)
{
memset
(
output
+
i
*
D
,
0
,
D
*
sizeof
(
T
));
for
(
int64_t
i
=
0
;
i
<
ids_t
->
numel
();
++
i
)
{
}
else
{
PADDLE_ENFORCE_LT
(
ids
[
i
],
N
);
PADDLE_ENFORCE_LT
(
ids
[
i
],
N
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
);
memcpy
(
output
+
i
*
D
,
table
+
ids
[
i
]
*
D
,
D
*
sizeof
(
T
));
memcpy
(
output
+
i
*
D
,
table
+
ids
[
i
]
*
D
,
D
*
sizeof
(
T
));
}
}
}
else
{
for
(
int64_t
i
=
0
;
i
<
ids_t
->
numel
();
++
i
)
{
if
(
ids
[
i
]
==
padding_idx
)
{
memset
(
output
+
i
*
D
,
0
,
D
*
sizeof
(
T
));
}
else
{
PADDLE_ENFORCE_LT
(
ids
[
i
],
N
);
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
);
memcpy
(
output
+
i
*
D
,
table
+
ids
[
i
]
*
D
,
D
*
sizeof
(
T
));
}
}
}
}
}
}
};
};
...
@@ -56,8 +65,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
...
@@ -56,8 +65,8 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
bool
is_sparse
=
context
.
Attr
<
bool
>
(
"is_sparse"
);
bool
is_sparse
=
context
.
Attr
<
bool
>
(
"is_sparse"
);
int64_t
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
if
(
is_sparse
)
{
if
(
is_sparse
)
{
auto
*
ids
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
ids
=
context
.
Input
<
LoDTensor
>
(
"Ids"
);
auto
*
table
=
context
.
Input
<
LoDTensor
>
(
"W"
);
auto
*
table
=
context
.
Input
<
LoDTensor
>
(
"W"
);
...
@@ -70,9 +79,6 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
...
@@ -70,9 +79,6 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
framework
::
Vector
<
int64_t
>
new_rows
;
framework
::
Vector
<
int64_t
>
new_rows
;
new_rows
.
reserve
(
ids_dim
[
0
]);
new_rows
.
reserve
(
ids_dim
[
0
]);
for
(
int64_t
i
=
0
;
i
<
ids_dim
[
0
];
i
++
)
{
for
(
int64_t
i
=
0
;
i
<
ids_dim
[
0
];
i
++
)
{
if
(
ids_data
[
i
]
==
padding_idx
)
continue
;
// Paddings are not trainable and the gradient are not
// necessary.
new_rows
.
push_back
(
ids_data
[
i
]);
new_rows
.
push_back
(
ids_data
[
i
]);
}
}
d_table
->
set_rows
(
new_rows
);
d_table
->
set_rows
(
new_rows
);
...
@@ -106,9 +112,6 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
...
@@ -106,9 +112,6 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
memset
(
d_table_data
,
0
,
d_table
->
numel
()
*
sizeof
(
T
));
memset
(
d_table_data
,
0
,
d_table
->
numel
()
*
sizeof
(
T
));
for
(
int64_t
i
=
0
;
i
<
ids
->
numel
();
++
i
)
{
for
(
int64_t
i
=
0
;
i
<
ids
->
numel
();
++
i
)
{
if
(
ids_data
[
i
]
==
padding_idx
)
continue
;
// Paddings are not trainable and the gradient are not
// necessary.
PADDLE_ENFORCE_LT
(
ids_data
[
i
],
N
);
PADDLE_ENFORCE_LT
(
ids_data
[
i
],
N
);
PADDLE_ENFORCE_GE
(
ids_data
[
i
],
0
);
PADDLE_ENFORCE_GE
(
ids_data
[
i
],
0
);
for
(
int
j
=
0
;
j
<
D
;
++
j
)
{
for
(
int
j
=
0
;
j
<
D
;
++
j
)
{
...
...
python/paddle/v2/fluid/framework.py
浏览文件 @
ae676a60
...
@@ -471,6 +471,7 @@ class Operator(object):
...
@@ -471,6 +471,7 @@ class Operator(object):
self
.
desc
.
set_serialized_attr
(
self
.
desc
.
set_serialized_attr
(
attr_name
,
attrs
[
attr_name
].
serialize_to_string
())
attr_name
,
attrs
[
attr_name
].
serialize_to_string
())
else
:
else
:
# print 'haha', attrs[attr_name], type(attrs[attr_name])
self
.
desc
.
set_attr
(
attr_name
,
attrs
[
attr_name
])
self
.
desc
.
set_attr
(
attr_name
,
attrs
[
attr_name
])
self
.
desc
.
check_attrs
()
self
.
desc
.
check_attrs
()
...
...
python/paddle/v2/fluid/layers/nn.py
浏览文件 @
ae676a60
...
@@ -176,22 +176,35 @@ def fc(input,
...
@@ -176,22 +176,35 @@ def fc(input,
return
helper
.
append_activation
(
pre_activation
)
return
helper
.
append_activation
(
pre_activation
)
def
embedding
(
input
,
size
,
is_sparse
=
False
,
param_attr
=
None
,
dtype
=
'float32'
):
def
embedding
(
input
,
size
,
is_sparse
=
False
,
padding_idx
=
None
,
param_attr
=
None
,
dtype
=
'float32'
):
"""
"""
**Embedding Layer**
**Embedding Layer**
This layer is used to lookup a vector of IDs, provided by *input*, in a lookup table.
This layer is used to lookup embeddings of IDs, provided by :attr:`input`, in
The result of this lookup is the embedding of each ID in the *input*.
a lookup table. The result of this lookup is the embedding of each ID in the
:attr:`input`.
All the input variables are passed in as local variables to the LayerHelper
All the input variables are passed in as local variables to the LayerHelper
constructor.
constructor.
Args:
Args:
input(Variable): Input to the function
input(Variable): The tensor variable containing the IDs.
size(tuple|list|None): Shape of the look up table parameter
size(tuple|list): The shape of the look up table parameter. It should
is_sparse(bool): Boolean flag that specifying whether the input is sparse
have two elements which indicate the size of the dictionary of
param_attr(ParamAttr): Parameters for this layer
embeddings and the size of each embedding vector respectively.
dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc
is_sparse(bool): The flag indicating whether to use sparse update.
padding_idx(int|long|None): If :attr:`None`, it makes no effect to lookup.
Otherwise the given :attr:`padding_idx` indicates padding the output
with zeros whenever lookup encounters it in :attr:`input`. If
:math:`padding_idx < 0`, the padding_idx to use in lookup is
:math:`size[0] + dim`.
param_attr(ParamAttr): Parameters for this layer
dtype(np.dtype|core.DataType|str): The type of data : float32, float_16, int etc
Returns:
Returns:
Variable: The tensor variable storing the embeddings of the
\
Variable: The tensor variable storing the embeddings of the
\
...
@@ -209,12 +222,15 @@ def embedding(input, size, is_sparse=False, param_attr=None, dtype='float32'):
...
@@ -209,12 +222,15 @@ def embedding(input, size, is_sparse=False, param_attr=None, dtype='float32'):
w
=
helper
.
create_parameter
(
w
=
helper
.
create_parameter
(
attr
=
helper
.
param_attr
,
shape
=
size
,
dtype
=
dtype
,
is_bias
=
False
)
attr
=
helper
.
param_attr
,
shape
=
size
,
dtype
=
dtype
,
is_bias
=
False
)
tmp
=
helper
.
create_tmp_variable
(
dtype
)
tmp
=
helper
.
create_tmp_variable
(
dtype
)
padding_idx
=
-
1
if
padding_idx
is
None
else
padding_idx
if
padding_idx
>=
0
else
(
size
[
0
]
+
padding_idx
)
helper
.
append_op
(
helper
.
append_op
(
type
=
'lookup_table'
,
type
=
'lookup_table'
,
inputs
=
{
'Ids'
:
input
,
inputs
=
{
'Ids'
:
input
,
'W'
:
w
},
'W'
:
w
},
outputs
=
{
'Out'
:
tmp
},
outputs
=
{
'Out'
:
tmp
},
attrs
=
{
'is_sparse'
:
is_sparse
})
attrs
=
{
'is_sparse'
:
is_sparse
,
'padding_idx'
:
padding_idx
})
return
tmp
return
tmp
...
...
python/paddle/v2/fluid/tests/test_lookup_table_op.py
浏览文件 @
ae676a60
...
@@ -32,5 +32,19 @@ class TestLookupTableOp(OpTest):
...
@@ -32,5 +32,19 @@ class TestLookupTableOp(OpTest):
self
.
check_grad
([
'W'
],
'Out'
,
no_grad_set
=
set
(
'Ids'
))
self
.
check_grad
([
'W'
],
'Out'
,
no_grad_set
=
set
(
'Ids'
))
class
TestLookupTableOpWithPadding
(
TestLookupTableOp
):
def
test_check_output
(
self
):
ids
=
np
.
squeeze
(
self
.
inputs
[
'Ids'
])
padding_idx
=
np
.
random
.
choice
(
ids
,
1
)[
0
]
self
.
outputs
[
'Out'
][
ids
==
padding_idx
]
=
np
.
zeros
(
31
)
self
.
attrs
=
{
'padding_idx'
:
long
(
padding_idx
)}
self
.
check_output
()
def
test_check_grad
(
self
):
# Since paddings are not trainable and fixed in forward, the gradient of
# paddings makes no sense and we don't test the gradient here.
pass
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
unittest
.
main
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录