Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
867fc053
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
867fc053
编写于
3月 20, 2022
作者:
P
phlrain
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish code
上级
7ba14d74
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
179 addition
and
127 deletion
+179
-127
paddle/phi/kernels/cpu/embedding_grad_kernel.cc
paddle/phi/kernels/cpu/embedding_grad_kernel.cc
+36
-26
paddle/phi/kernels/cpu/embedding_kernel.cc
paddle/phi/kernels/cpu/embedding_kernel.cc
+17
-12
paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc
...le/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc
+39
-25
paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc
paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc
+18
-12
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
+43
-31
paddle/phi/kernels/gpu/embedding_kernel.cu
paddle/phi/kernels/gpu/embedding_kernel.cu
+26
-21
未找到文件。
paddle/phi/kernels/cpu/embedding_grad_kernel.cc
浏览文件 @
867fc053
...
...
@@ -15,16 +15,15 @@
#include "paddle/phi/kernels/embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
struct
LookupTableV2
GradCPUFunctor
{
LookupTableV2
GradCPUFunctor
(
const
Context
&
dev_ctx
,
struct
Embedding
GradCPUFunctor
{
Embedding
GradCPUFunctor
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
weight
,
const
DenseTensor
&
out_grad
,
...
...
@@ -48,7 +47,6 @@ struct LookupTableV2GradCPUFunctor {
// paddings makes no sense and we don't deal with it in backward.
{
auto
*
d_output
=
&
out_grad_
;
// auto d_table = weight_grad_;
auto
*
ids_data
=
ids
.
data
();
int64_t
N
=
table_dim
[
0
];
...
...
@@ -70,7 +68,8 @@ struct LookupTableV2GradCPUFunctor {
ids_data
[
i
],
N
,
phi
::
errors
::
InvalidArgument
(
"Variable value (input) of OP(fluid.layers.embedding) "
"Variable value (input) of "
"OP(paddle.nn.functional.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value."
,
N
,
...
...
@@ -79,7 +78,8 @@ struct LookupTableV2GradCPUFunctor {
ids_data
[
i
],
0
,
phi
::
errors
::
InvalidArgument
(
"Variable value (input) of OP(fluid.layers.embedding) "
"Variable value (input) of "
"OP(paddle.nn.functional.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value."
,
N
,
...
...
@@ -108,15 +108,20 @@ void EmbeddingGradKernel(const Context& ctx,
const
DenseTensor
&
out_grad
,
int64_t
padding_idx
,
DenseTensor
*
weight_grad
)
{
LookupTableV2
GradCPUFunctor
<
T
,
Context
>
functor
(
Embedding
GradCPUFunctor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
out_grad
,
padding_idx
,
weight_grad
);
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
if
(
input
.
dtype
()
==
phi
::
DataType
::
INT32
)
{
functor
.
template
apply
<
int
>();
}
else
if
(
input
.
dtype
()
==
phi
::
DataType
::
INT64
)
{
functor
.
template
apply
<
int64_t
>();
}
else
{
PADDLE_THROW
(
"emebdding input only support int32 and int64"
);
}
}
template
<
typename
T
,
typename
Context
>
struct
LookupTableV2
SparseGradCPUFunctor
{
LookupTableV2
SparseGradCPUFunctor
(
const
Context
&
dev_ctx
,
struct
Embedding
SparseGradCPUFunctor
{
Embedding
SparseGradCPUFunctor
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
weight
,
const
DenseTensor
&
out_grad
,
...
...
@@ -145,7 +150,7 @@ struct LookupTableV2SparseGradCPUFunctor {
auto
*
d_table_value
=
d_table
->
mutable_value
();
d_table_value
->
Resize
({
ids_num
,
table_dim
[
1
]});
d
_table_value
->
template
mutable_data
<
T
>(
dev_ctx_
.
GetPlace
()
);
d
ev_ctx_
.
template
Alloc
<
T
>(
d_table_value
);
d_table
->
set_height
(
table_dim
[
0
]);
...
...
@@ -183,10 +188,15 @@ void EmbeddingSparseGradKernel(const Context& ctx,
const
DenseTensor
&
out_grad
,
int64_t
padding_idx
,
SelectedRows
*
weight_grad
)
{
LookupTableV2
SparseGradCPUFunctor
<
T
,
Context
>
functor
(
Embedding
SparseGradCPUFunctor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
out_grad
,
padding_idx
,
weight_grad
);
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
if
(
input
.
dtype
()
==
phi
::
DataType
::
INT32
)
{
functor
.
template
apply
<
int
>();
}
else
if
(
input
.
dtype
()
==
phi
::
DataType
::
INT64
)
{
functor
.
template
apply
<
int64_t
>();
}
else
{
PADDLE_THROW
(
"emebdding input only support int32 and int64"
);
}
}
}
// namespace phi
...
...
paddle/phi/kernels/cpu/embedding_kernel.cc
浏览文件 @
867fc053
...
...
@@ -15,16 +15,16 @@
#include "paddle/phi/kernels/embedding_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
struct
LookupTableV2
CPUFunctor
{
LookupTableV2
CPUFunctor
(
const
Context
&
dev_ctx
,
struct
Embedding
CPUFunctor
{
Embedding
CPUFunctor
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
weight
,
int64_t
padding_idx
,
...
...
@@ -91,10 +91,15 @@ void EmbeddingKernel(const Context& ctx,
const
DenseTensor
&
weight
,
int64_t
padding_idx
,
DenseTensor
*
out
)
{
LookupTableV2CPUFunctor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
padding_idx
,
out
);
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
EmbeddingCPUFunctor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
padding_idx
,
out
);
if
(
input
.
dtype
()
==
phi
::
DataType
::
INT32
)
{
functor
.
template
apply
<
int
>();
}
else
if
(
input
.
dtype
()
==
phi
::
DataType
::
INT64
)
{
functor
.
template
apply
<
int64_t
>();
}
else
{
PADDLE_THROW
(
"emebdding input only support int32 and int64"
);
}
}
}
// namespace phi
...
...
paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc
浏览文件 @
867fc053
...
...
@@ -15,16 +15,16 @@
#include "paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
struct
SparseWeight
LookupTableV2
GradCPUFunctor
{
SparseWeight
LookupTableV2
GradCPUFunctor
(
const
Context
&
dev_ctx
,
struct
SparseWeight
Embedding
GradCPUFunctor
{
SparseWeight
Embedding
GradCPUFunctor
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
SelectedRows
&
weight
,
const
DenseTensor
&
out_grad
,
...
...
@@ -70,7 +70,8 @@ struct SparseWeightLookupTableV2GradCPUFunctor {
ids_data
[
i
],
N
,
phi
::
errors
::
InvalidArgument
(
"Variable value (input) of OP(fluid.layers.embedding) "
"Variable value (input) of "
"OP(paddle.nn.functional.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value."
,
N
,
...
...
@@ -79,7 +80,8 @@ struct SparseWeightLookupTableV2GradCPUFunctor {
ids_data
[
i
],
0
,
phi
::
errors
::
InvalidArgument
(
"Variable value (input) of OP(fluid.layers.embedding) "
"Variable value (input) of "
"OP(paddle.nn.functional.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value."
,
N
,
...
...
@@ -102,8 +104,8 @@ struct SparseWeightLookupTableV2GradCPUFunctor {
};
template
<
typename
T
,
typename
Context
>
struct
SparseWeight
LookupTableV2
SparseGradCPUFunctor
{
SparseWeight
LookupTableV2
SparseGradCPUFunctor
(
const
Context
&
dev_ctx
,
struct
SparseWeight
Embedding
SparseGradCPUFunctor
{
SparseWeight
Embedding
SparseGradCPUFunctor
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
SelectedRows
&
weight
,
const
DenseTensor
&
out_grad
,
...
...
@@ -132,7 +134,7 @@ struct SparseWeightLookupTableV2SparseGradCPUFunctor {
auto
*
d_table_value
=
d_table
->
mutable_value
();
d_table_value
->
Resize
({
ids_num
,
table_dim
[
1
]});
d
_table_value
->
template
mutable_data
<
T
>(
dev_ctx_
.
GetPlace
()
);
d
ev_ctx_
.
template
Alloc
<
T
>(
d_table_value
);
d_table
->
set_height
(
table_dim
[
0
]);
...
...
@@ -170,10 +172,16 @@ void SparseWeightEmbeddingGradKernel(const Context& ctx,
const
DenseTensor
&
out_grad
,
int64_t
padding_idx
,
DenseTensor
*
weight_grad
)
{
SparseWeight
LookupTableV2
GradCPUFunctor
<
T
,
Context
>
functor
(
SparseWeight
Embedding
GradCPUFunctor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
out_grad
,
padding_idx
,
weight_grad
);
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
if
(
input
.
dtype
()
==
phi
::
DataType
::
INT32
)
{
functor
.
template
apply
<
int
>();
}
else
if
(
input
.
dtype
()
==
phi
::
DataType
::
INT64
)
{
functor
.
template
apply
<
int64_t
>();
}
else
{
PADDLE_THROW
(
"emebdding input only support int32 and int64"
);
}
}
template
<
typename
T
,
typename
Context
>
...
...
@@ -183,10 +191,16 @@ void SparseWeightEmbeddingSparseGradKernel(const Context& ctx,
const
DenseTensor
&
out_grad
,
int64_t
padding_idx
,
SelectedRows
*
weight_grad
)
{
SparseWeight
LookupTableV2
SparseGradCPUFunctor
<
T
,
Context
>
functor
(
SparseWeight
Embedding
SparseGradCPUFunctor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
out_grad
,
padding_idx
,
weight_grad
);
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
if
(
input
.
dtype
()
==
phi
::
DataType
::
INT32
)
{
functor
.
template
apply
<
int
>();
}
else
if
(
input
.
dtype
()
==
phi
::
DataType
::
INT64
)
{
functor
.
template
apply
<
int64_t
>();
}
else
{
PADDLE_THROW
(
"emebdding input only support int32 and int64"
);
}
}
}
// namespace phi
...
...
paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc
浏览文件 @
867fc053
...
...
@@ -15,17 +15,17 @@
#include "paddle/phi/kernels/embedding_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
struct
LookupTableV2
CPUSparseFunctor
{
LookupTableV2
CPUSparseFunctor
(
const
Context
&
dev_ctx
,
struct
Embedding
CPUSparseFunctor
{
Embedding
CPUSparseFunctor
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
SelectedRows
&
weight
,
int64_t
padding_idx
,
...
...
@@ -45,7 +45,7 @@ struct LookupTableV2CPUSparseFunctor {
auto
output_t
=
out_
;
int64_t
row_width
=
table_t
.
value
().
dims
()[
1
];
const
auto
*
table
=
table_t
.
value
().
template
data
<
T
>();
auto
*
output
=
output_t
->
template
mutable_data
<
T
>(
dev_ctx_
.
GetPlace
()
);
auto
*
output
=
dev_ctx_
.
template
Alloc
<
T
>(
output_t
);
auto
input_data_type
=
paddle
::
framework
::
TransToProtoVarType
(
table_t
.
value
().
dtype
());
...
...
@@ -94,10 +94,16 @@ void SparseWeightEmbeddingKernel(const Context& ctx,
const
SelectedRows
&
weight
,
int64_t
padding_idx
,
DenseTensor
*
out
)
{
LookupTableV2
CPUSparseFunctor
<
T
,
Context
>
functor
(
Embedding
CPUSparseFunctor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
padding_idx
,
out
);
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
if
(
input
.
dtype
()
==
phi
::
DataType
::
INT32
)
{
functor
.
template
apply
<
int
>();
}
else
if
(
input
.
dtype
()
==
phi
::
DataType
::
INT64
)
{
functor
.
template
apply
<
int64_t
>();
}
else
{
PADDLE_THROW
(
"emebdding input only support int32 and int64"
);
}
}
}
// namespace phi
...
...
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
浏览文件 @
867fc053
...
...
@@ -15,9 +15,9 @@
#include "paddle/phi/kernels/embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
...
...
@@ -36,7 +36,7 @@ __global__ void InputTypeConvert(const InT* in_ids,
}
template
<
typename
T
,
typename
IdT
>
__global__
void
LookupTableV2
Grad
(
T
*
table
,
__global__
void
Embedding
Grad
(
T
*
table
,
const
T
*
output
,
const
IdT
*
ids
,
const
int64_t
N
,
...
...
@@ -61,8 +61,8 @@ __global__ void LookupTableV2Grad(T* table,
}
template
<
typename
T
,
typename
Context
>
struct
LookupTableV2
GradCUDAFunctor
{
LookupTableV2
GradCUDAFunctor
(
const
Context
&
dev_ctx
,
struct
Embedding
GradCUDAFunctor
{
Embedding
GradCUDAFunctor
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
weight
,
const
DenseTensor
&
out_grad
,
...
...
@@ -89,7 +89,7 @@ struct LookupTableV2GradCUDAFunctor {
const
T
*
d_output
=
d_output_t
.
template
data
<
T
>();
const
auto
*
ids
=
input_
.
template
data
<
IdT
>();
T
*
d_table
=
d
_table_t
->
mutable_data
<
T
>
(
dev_ctx_
.
GetPlace
()
);
T
*
d_table
=
d
ev_ctx_
.
template
Alloc
<
T
>(
d_table_t
);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS
(
...
...
@@ -102,7 +102,7 @@ struct LookupTableV2GradCUDAFunctor {
const
int
gridx
=
2
*
dev_ctx_
.
GetSMCount
();
dim3
threads
(
128
,
8
);
dim3
grids
(
gridx
,
1
);
LookupTableV2
Grad
<
T
,
IdT
><<<
grids
,
threads
,
0
,
dev_ctx_
.
stream
()
>>>
(
Embedding
Grad
<
T
,
IdT
><<<
grids
,
threads
,
0
,
dev_ctx_
.
stream
()
>>>
(
d_table
,
d_output
,
ids
,
N
,
K
,
D
);
}
}
...
...
@@ -123,15 +123,21 @@ void EmbeddingGradKernel(const Context& ctx,
const
DenseTensor
&
out_grad
,
int64_t
padding_idx
,
DenseTensor
*
weight_grad
)
{
LookupTableV2
GradCUDAFunctor
<
T
,
Context
>
functor
(
Embedding
GradCUDAFunctor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
out_grad
,
padding_idx
,
weight_grad
);
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
if
(
input
.
dtype
()
==
phi
::
DataType
::
INT32
)
{
functor
.
template
apply
<
int
>();
}
else
if
(
input
.
dtype
()
==
phi
::
DataType
::
INT64
)
{
functor
.
template
apply
<
int64_t
>();
}
else
{
PADDLE_THROW
(
"emebdding input only support int32 and int64"
);
}
}
template
<
typename
T
,
typename
Context
>
struct
LookupTableV2
SparseGradCUDAFunctor
{
LookupTableV2
SparseGradCUDAFunctor
(
const
Context
&
dev_ctx
,
struct
Embedding
SparseGradCUDAFunctor
{
Embedding
SparseGradCUDAFunctor
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
weight
,
const
DenseTensor
&
out_grad
,
...
...
@@ -179,7 +185,7 @@ struct LookupTableV2SparseGradCUDAFunctor {
auto
*
d_table_value
=
d_table
->
mutable_value
();
d_table_value
->
Resize
({
ids_num
,
table
->
dims
()[
1
]});
d
_table_value
->
template
mutable_data
<
T
>(
gpu_plac
e
);
d
ev_ctx_
.
template
Alloc
<
T
>(
d_table_valu
e
);
auto
*
d_table_data
=
d_table_value
->
template
data
<
T
>();
auto
*
d_output_data
=
d_output
->
template
data
<
T
>();
...
...
@@ -219,10 +225,16 @@ void EmbeddingSparseGradKernel(const Context& ctx,
const
DenseTensor
&
out_grad
,
int64_t
padding_idx
,
SelectedRows
*
weight_grad
)
{
LookupTableV2
SparseGradCUDAFunctor
<
T
,
Context
>
functor
(
Embedding
SparseGradCUDAFunctor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
out_grad
,
padding_idx
,
weight_grad
);
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
if
(
input
.
dtype
()
==
phi
::
DataType
::
INT32
)
{
functor
.
template
apply
<
int
>();
}
else
if
(
input
.
dtype
()
==
phi
::
DataType
::
INT64
)
{
functor
.
template
apply
<
int64_t
>();
}
else
{
PADDLE_THROW
(
"emebdding input only support int32 and int64"
);
}
}
}
// namespace phi
...
...
paddle/phi/kernels/gpu/embedding_kernel.cu
浏览文件 @
867fc053
...
...
@@ -15,16 +15,15 @@
#include "paddle/phi/kernels/embedding_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
namespace
phi
{
template
<
typename
T
,
typename
IdT
,
bool
PaddingFlag
>
__global__
void
LookupTableV2
(
T
*
output
,
__global__
void
EmbeddingFW
(
T
*
output
,
const
T
*
table
,
const
IdT
*
ids
,
const
int64_t
N
,
...
...
@@ -53,8 +52,8 @@ __global__ void LookupTableV2(T *output,
}
template
<
typename
T
,
typename
Context
>
struct
LookupTableV2
CUDAFunctor
{
LookupTableV2
CUDAFunctor
(
const
Context
&
dev_ctx
,
struct
Embedding
CUDAFunctor
{
Embedding
CUDAFunctor
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
weight
,
int64_t
padding_idx
,
...
...
@@ -77,14 +76,14 @@ struct LookupTableV2CUDAFunctor {
const
T
*
table
=
weight_
.
template
data
<
T
>();
const
IdT
*
ids
=
input_
.
template
data
<
IdT
>();
auto
*
output
=
out_
->
template
mutable_data
<
T
>(
dev_ctx_
.
GetPlace
()
);
auto
*
output
=
dev_ctx_
.
template
Alloc
<
T
>(
out_
);
auto
stream
=
dev_ctx_
.
stream
();
if
(
padding_idx_
==
-
1
)
{
LookupTableV2
<
T
,
IdT
,
false
><<<
grids
,
threads
,
0
,
stream
>>>
(
EmbeddingFW
<
T
,
IdT
,
false
><<<
grids
,
threads
,
0
,
stream
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
,
padding_idx_
);
}
else
{
LookupTableV2
<
T
,
IdT
,
true
><<<
grids
,
threads
,
0
,
stream
>>>
(
EmbeddingFW
<
T
,
IdT
,
true
><<<
grids
,
threads
,
0
,
stream
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
,
padding_idx_
);
}
}
...
...
@@ -103,10 +102,16 @@ void EmbeddingKernel(const Context &ctx,
const
DenseTensor
&
weight
,
int64_t
padding_idx
,
DenseTensor
*
out
)
{
LookupTableV2
CUDAFunctor
<
T
,
Context
>
functor
(
Embedding
CUDAFunctor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
padding_idx
,
out
);
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
if
(
input
.
dtype
()
==
phi
::
DataType
::
INT32
)
{
functor
.
template
apply
<
int32_t
>();
}
else
if
(
input
.
dtype
()
==
phi
::
DataType
::
INT64
)
{
functor
.
template
apply
<
int64_t
>();
}
else
{
PADDLE_THROW
(
"emebdding input only support int32 and int64"
);
}
}
}
// namespace phi
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录