Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
a8d072c7
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
695
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a8d072c7
编写于
8月 23, 2017
作者:
D
dangqingqing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix bug.
上级
9bc1a1a1
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
25 addition
and
26 deletion
+25
-26
paddle/operators/lookup_table_op.cc
paddle/operators/lookup_table_op.cc
+4
-3
paddle/operators/lookup_table_op.cu
paddle/operators/lookup_table_op.cu
+16
-16
paddle/operators/lookup_table_op.h
paddle/operators/lookup_table_op.h
+3
-3
python/paddle/v2/framework/tests/test_lookup_table.py
python/paddle/v2/framework/tests/test_lookup_table.py
+2
-4
未找到文件。
paddle/operators/lookup_table_op.cc
浏览文件 @
a8d072c7
...
@@ -41,8 +41,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -41,8 +41,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
" which is a learnable parameter."
);
" which is a learnable parameter."
);
AddInput
(
"Ids"
,
AddInput
(
"Ids"
,
"An input with type int32 or int64"
"An input with type int32 or int64"
"contains the ids to be looked up in W."
)
"contains the ids to be looked up in W."
);
.
NotInGradient
();
AddOutput
(
"Out"
,
"The lookup results, which have the same type with W."
);
AddOutput
(
"Out"
,
"The lookup results, which have the same type with W."
);
AddComment
(
AddComment
(
"This operator is used to perform lookups on the parameter W,"
"This operator is used to perform lookups on the parameter W,"
...
@@ -56,7 +55,9 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
...
@@ -56,7 +55,9 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
protected:
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
context
)
const
override
{
void
InferShape
(
const
framework
::
InferShapeContext
&
context
)
const
override
{
context
.
Output
<
Tensor
>
(
0
)
->
Resize
(
context
.
Input
<
Tensor
>
(
0
)
->
dims
());
auto
table
=
context
.
Input
<
Tensor
>
(
"W"
);
auto
d_table
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"W"
));
d_table
->
Resize
(
table
->
dims
());
}
}
};
};
...
...
paddle/operators/lookup_table_op.cu
浏览文件 @
a8d072c7
...
@@ -23,7 +23,7 @@ namespace operators {
...
@@ -23,7 +23,7 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
int
blockDimX
,
int
blockDimY
,
int
gridDimX
>
template
<
typename
T
,
int
blockDimX
,
int
blockDimY
,
int
gridDimX
>
__global__
void
LookupTable
(
T
*
output
,
const
T
*
table
,
const
u
int32_t
*
ids
,
__global__
void
LookupTable
(
T
*
output
,
const
T
*
table
,
const
int32_t
*
ids
,
const
int
N
,
const
int
K
,
const
int
D
)
{
const
int
N
,
const
int
K
,
const
int
D
)
{
int
idx
=
threadIdx
.
x
;
int
idx
=
threadIdx
.
x
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
gridDimX
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
gridDimX
;
...
@@ -32,8 +32,8 @@ __global__ void LookupTable(T* output, const T* table, const uint32_t* ids,
...
@@ -32,8 +32,8 @@ __global__ void LookupTable(T* output, const T* table, const uint32_t* ids,
int
id
=
ids
[
idy
];
int
id
=
ids
[
idy
];
PADDLE_ASSERT
(
id
>=
0
);
PADDLE_ASSERT
(
id
>=
0
);
PADDLE_ASSERT
(
id
<
N
);
PADDLE_ASSERT
(
id
<
N
);
T
*
out
=
output
+
idy
;
T
*
out
=
output
+
idy
*
D
;
const
T
*
tab
=
table
+
id
;
const
T
*
tab
=
table
+
id
*
D
;
for
(
int
i
=
idx
;
i
<
D
;
i
+=
blockDimX
)
{
for
(
int
i
=
idx
;
i
<
D
;
i
+=
blockDimX
)
{
out
[
i
]
=
tab
[
i
];
out
[
i
]
=
tab
[
i
];
}
}
...
@@ -42,9 +42,8 @@ __global__ void LookupTable(T* output, const T* table, const uint32_t* ids,
...
@@ -42,9 +42,8 @@ __global__ void LookupTable(T* output, const T* table, const uint32_t* ids,
}
}
template
<
typename
T
,
int
blockDimX
,
int
blockDimY
,
int
gridDimX
>
template
<
typename
T
,
int
blockDimX
,
int
blockDimY
,
int
gridDimX
>
__global__
void
LookupTableGradKernel
(
T
*
table
,
const
T
*
output
,
__global__
void
LookupTableGrad
(
T
*
table
,
const
T
*
output
,
const
int32_t
*
ids
,
const
uint32_t
*
ids
,
const
int
N
,
const
int
N
,
const
int
K
,
const
int
D
)
{
const
int
K
,
const
int
D
)
{
int
idx
=
threadIdx
.
x
;
int
idx
=
threadIdx
.
x
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
gridDimX
;
int
idy
=
blockIdx
.
x
+
threadIdx
.
y
*
gridDimX
;
...
@@ -52,10 +51,10 @@ __global__ void LookupTableGradKernel(T* table, const T* output,
...
@@ -52,10 +51,10 @@ __global__ void LookupTableGradKernel(T* table, const T* output,
int
id
=
ids
[
idy
];
int
id
=
ids
[
idy
];
PADDLE_ASSERT
(
id
>=
0
);
PADDLE_ASSERT
(
id
>=
0
);
PADDLE_ASSERT
(
id
<
N
);
PADDLE_ASSERT
(
id
<
N
);
const
T
*
out
=
output
+
idy
;
const
T
*
out
=
output
+
idy
*
D
;
T
*
tab
=
table
+
id
;
T
*
tab
=
table
+
id
*
D
;
for
(
int
i
=
idx
;
i
<
D
;
i
+=
blockDimX
)
{
for
(
int
i
=
idx
;
i
<
D
;
i
+=
blockDimX
)
{
paddle
::
platform
::
CudaAtomicAdd
(
tab
+
i
,
out
[
i
]);
paddle
::
platform
::
CudaAtomicAdd
(
&
tab
[
i
]
,
out
[
i
]);
}
}
idy
+=
blockDimY
*
gridDimX
;
idy
+=
blockDimY
*
gridDimX
;
}
}
...
@@ -72,7 +71,7 @@ class LookupTableCUDAKernel : public framework::OpKernel {
...
@@ -72,7 +71,7 @@ class LookupTableCUDAKernel : public framework::OpKernel {
size_t
N
=
table_t
->
dims
()[
0
];
size_t
N
=
table_t
->
dims
()[
0
];
size_t
D
=
table_t
->
dims
()[
1
];
size_t
D
=
table_t
->
dims
()[
1
];
size_t
K
=
product
(
ids_t
->
dims
());
size_t
K
=
product
(
ids_t
->
dims
());
auto
ids
=
ids_t
->
data
<
u
int32_t
>
();
auto
ids
=
ids_t
->
data
<
int32_t
>
();
auto
table
=
table_t
->
data
<
T
>
();
auto
table
=
table_t
->
data
<
T
>
();
auto
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
...
@@ -83,7 +82,7 @@ class LookupTableCUDAKernel : public framework::OpKernel {
...
@@ -83,7 +82,7 @@ class LookupTableCUDAKernel : public framework::OpKernel {
};
};
template
<
typename
T
>
template
<
typename
T
>
class
LookupTableGrad
:
public
framework
::
OpKernel
{
class
LookupTableGrad
CUDAKernel
:
public
framework
::
OpKernel
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
ids_t
=
context
.
Input
<
Tensor
>
(
"Ids"
);
auto
ids_t
=
context
.
Input
<
Tensor
>
(
"Ids"
);
...
@@ -93,9 +92,9 @@ class LookupTableGrad : public framework::OpKernel {
...
@@ -93,9 +92,9 @@ class LookupTableGrad : public framework::OpKernel {
int
N
=
d_table_t
->
dims
()[
0
];
int
N
=
d_table_t
->
dims
()[
0
];
int
D
=
d_table_t
->
dims
()[
1
];
int
D
=
d_table_t
->
dims
()[
1
];
int
K
=
product
(
ids_t
->
dims
());
int
K
=
product
(
ids_t
->
dims
());
const
uint32_t
*
ids
=
ids_t
->
data
<
uint32_t
>
();
const
int32_t
*
ids
=
ids_t
->
data
<
int32_t
>
();
T
*
d_table
=
d_table_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
const
T
*
d_output
=
d_output_t
->
data
<
T
>
();
const
T
*
d_output
=
d_output_t
->
data
<
T
>
();
T
*
d_table
=
d_table_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
device_context
=
auto
*
device_context
=
const_cast
<
platform
::
DeviceContext
*>
(
context
.
device_context_
);
const_cast
<
platform
::
DeviceContext
*>
(
context
.
device_context_
);
...
@@ -103,8 +102,8 @@ class LookupTableGrad : public framework::OpKernel {
...
@@ -103,8 +102,8 @@ class LookupTableGrad : public framework::OpKernel {
device_context
);
device_context
);
dim3
threads
(
128
,
8
);
dim3
threads
(
128
,
8
);
dim3
grids
(
8
,
1
);
dim3
grids
(
8
,
1
);
LookupTableGrad
Kernel
<
T
,
128
,
8
,
8
><<<
grids
,
threads
>>>
(
d_table
,
d_output
,
LookupTableGrad
<
T
,
128
,
8
,
8
><<<
grids
,
threads
>>>
(
d_table
,
d_output
,
ids
,
N
,
ids
,
N
,
K
,
D
);
K
,
D
);
}
}
};
};
...
@@ -113,4 +112,5 @@ class LookupTableGrad : public framework::OpKernel {
...
@@ -113,4 +112,5 @@ class LookupTableGrad : public framework::OpKernel {
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
lookup_table
,
ops
::
LookupTableCUDAKernel
<
float
>
);
REGISTER_OP_GPU_KERNEL
(
lookup_table
,
ops
::
LookupTableCUDAKernel
<
float
>
);
REGISTER_OP_GPU_KERNEL
(
lookup_table_grad
,
ops
::
LookupTableGrad
<
float
>
);
REGISTER_OP_GPU_KERNEL
(
lookup_table_grad
,
ops
::
LookupTableGradCUDAKernel
<
float
>
);
paddle/operators/lookup_table_op.h
浏览文件 @
a8d072c7
...
@@ -32,7 +32,7 @@ class LookupTableKernel : public framework::OpKernel {
...
@@ -32,7 +32,7 @@ class LookupTableKernel : public framework::OpKernel {
size_t
N
=
table_t
->
dims
()[
0
];
size_t
N
=
table_t
->
dims
()[
0
];
size_t
D
=
table_t
->
dims
()[
1
];
size_t
D
=
table_t
->
dims
()[
1
];
auto
ids
=
ids_t
->
data
<
u
int32_t
>
();
auto
ids
=
ids_t
->
data
<
int32_t
>
();
auto
table
=
table_t
->
data
<
T
>
();
auto
table
=
table_t
->
data
<
T
>
();
auto
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
output
=
output_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
for
(
size_t
i
=
0
;
i
<
product
(
ids_t
->
dims
());
++
i
)
{
for
(
size_t
i
=
0
;
i
<
product
(
ids_t
->
dims
());
++
i
)
{
...
@@ -53,9 +53,9 @@ class LookupTableGradKernel : public framework::OpKernel {
...
@@ -53,9 +53,9 @@ class LookupTableGradKernel : public framework::OpKernel {
size_t
N
=
d_table_t
->
dims
()[
0
];
size_t
N
=
d_table_t
->
dims
()[
0
];
size_t
D
=
d_table_t
->
dims
()[
1
];
size_t
D
=
d_table_t
->
dims
()[
1
];
auto
ids
=
ids_t
->
data
<
uint32_t
>
();
auto
ids
=
ids_t
->
data
<
int32_t
>
();
T
*
d_table
=
d_table_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
const
T
*
d_output
=
d_output_t
->
data
<
T
>
();
const
T
*
d_output
=
d_output_t
->
data
<
T
>
();
T
*
d_table
=
d_table_t
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
device_context
=
auto
*
device_context
=
const_cast
<
platform
::
DeviceContext
*>
(
context
.
device_context_
);
const_cast
<
platform
::
DeviceContext
*>
(
context
.
device_context_
);
...
...
python/paddle/v2/framework/tests/test_lookup_table.py
浏览文件 @
a8d072c7
...
@@ -10,7 +10,7 @@ class TestSigmoidOp(unittest.TestCase):
...
@@ -10,7 +10,7 @@ class TestSigmoidOp(unittest.TestCase):
def
setUp
(
self
):
def
setUp
(
self
):
self
.
type
=
'lookup_table'
self
.
type
=
'lookup_table'
table
=
np
.
random
.
random
((
17
,
31
)).
astype
(
'float32'
)
table
=
np
.
random
.
random
((
17
,
31
)).
astype
(
'float32'
)
ids
=
np
.
random
.
randint
(
0
,
17
,
4
)
ids
=
np
.
random
.
randint
(
0
,
17
,
4
)
.
astype
(
'int32'
)
self
.
inputs
=
{
'W'
:
table
,
'Ids'
:
ids
}
self
.
inputs
=
{
'W'
:
table
,
'Ids'
:
ids
}
self
.
outputs
=
{
'Out'
:
table
[
ids
]}
self
.
outputs
=
{
'Out'
:
table
[
ids
]}
...
@@ -19,10 +19,8 @@ class TestSigmoidGradOp(GradientChecker):
...
@@ -19,10 +19,8 @@ class TestSigmoidGradOp(GradientChecker):
def
test_grad
(
self
):
def
test_grad
(
self
):
op
=
create_op
(
'lookup_table'
)
op
=
create_op
(
'lookup_table'
)
table
=
np
.
random
.
random
((
17
,
31
)).
astype
(
'float32'
)
table
=
np
.
random
.
random
((
17
,
31
)).
astype
(
'float32'
)
ids
=
np
.
random
.
randint
(
0
,
17
,
4
)
ids
=
np
.
random
.
randint
(
0
,
17
,
4
)
.
astype
(
'int32'
)
inputs
=
{
'W'
:
table
,
'Ids'
:
ids
}
inputs
=
{
'W'
:
table
,
'Ids'
:
ids
}
# compare gradients between cpu and gpu
self
.
compare_grad
(
op
,
inputs
)
# check gradients
# check gradients
self
.
check_grad
(
op
,
inputs
,
set
(
'W'
),
'Out'
)
self
.
check_grad
(
op
,
inputs
,
set
(
'W'
),
'Out'
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录