Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e81773c9
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e81773c9
编写于
2月 28, 2022
作者:
P
phlrain
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move reset impl to phi; test=develop
上级
ec0e8391
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
551 addition
and
239 deletion
+551
-239
paddle/phi/kernels/cpu/embedding_grad_kernel.cc
paddle/phi/kernels/cpu/embedding_grad_kernel.cc
+84
-1
paddle/phi/kernels/cpu/embedding_kernel.cc
paddle/phi/kernels/cpu/embedding_kernel.cc
+1
-1
paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc
...le/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc
+93
-10
paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc
paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc
+8
-8
paddle/phi/kernels/embedding_grad_kernel.h
paddle/phi/kernels/embedding_grad_kernel.h
+9
-0
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
+108
-1
paddle/phi/kernels/gpu/embedding_kernel.cu
paddle/phi/kernels/gpu/embedding_kernel.cu
+1
-1
paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h
paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h
+8
-0
paddle/phi/ops/compat/embedding_sig.cc
paddle/phi/ops/compat/embedding_sig.cc
+26
-14
python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py
...n/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py
+213
-203
未找到文件。
paddle/phi/kernels/cpu/embedding_grad_kernel.cc
浏览文件 @
e81773c9
...
@@ -114,12 +114,95 @@ void EmbeddingGradKernel(const Context& ctx,
...
@@ -114,12 +114,95 @@ void EmbeddingGradKernel(const Context& ctx,
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
}
}
template
<
typename
T
,
typename
Context
>
struct
LookupTableV2SparseGradCPUFunctor
{
LookupTableV2SparseGradCPUFunctor
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
weight
,
const
DenseTensor
&
out_grad
,
int64_t
padding_idx
,
SelectedRows
*
weight_grad
)
:
dev_ctx_
(
dev_ctx
),
input_
(
input
),
weight_
(
weight
),
out_grad_
(
out_grad
),
weight_grad_
(
weight_grad
),
padding_idx_
(
padding_idx
)
{}
template
<
typename
IdT
>
void
apply
()
{
DDim
table_dim
=
weight_
.
dims
();
auto
ids
=
CopyIdsToVector
<
IdT
,
int64_t
>
(
input_
);
auto
ids_num
=
static_cast
<
int64_t
>
(
ids
.
size
());
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
auto
*
d_table
=
weight_grad_
;
auto
*
d_output
=
&
out_grad_
;
d_table
->
set_rows
(
ids
);
auto
*
d_table_value
=
d_table
->
mutable_value
();
d_table_value
->
Resize
({
ids_num
,
table_dim
[
1
]});
d_table_value
->
template
mutable_data
<
T
>(
dev_ctx_
.
GetPlace
());
d_table
->
set_height
(
table_dim
[
0
]);
auto
*
d_output_data
=
d_output
->
template
data
<
T
>();
auto
*
d_table_data
=
d_table_value
->
template
data
<
T
>();
auto
d_output_dims
=
d_output
->
dims
();
auto
d_output_dims_2d
=
flatten_to_2d
(
d_output_dims
,
d_output_dims
.
size
()
-
1
);
PADDLE_ENFORCE_EQ
(
d_table_value
->
dims
(),
d_output_dims_2d
,
phi
::
errors
::
InvalidArgument
(
"ShapeError: The shape of lookup_table@Grad and "
"output@Grad should be same. "
"But received lookup_table@Grad's shape = [%s], "
"output@Grad's shape = [%s]."
,
d_table_value
->
dims
(),
d_output_dims_2d
));
memcpy
(
d_table_data
,
d_output_data
,
sizeof
(
T
)
*
d_output
->
numel
());
}
private:
const
Context
&
dev_ctx_
;
const
DenseTensor
&
input_
;
const
DenseTensor
&
weight_
;
const
DenseTensor
&
out_grad_
;
SelectedRows
*
weight_grad_
;
int64_t
padding_idx_
;
};
template
<
typename
T
,
typename
Context
>
void
EmbeddingSparseGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
weight
,
const
DenseTensor
&
out_grad
,
int64_t
padding_idx
,
SelectedRows
*
weight_grad
)
{
LookupTableV2SparseGradCPUFunctor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
out_grad
,
padding_idx
,
weight_grad
);
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
}
}
// namespace phi
}
// namespace phi
P
T
_REGISTER_KERNEL
(
embedding_grad
,
P
D
_REGISTER_KERNEL
(
embedding_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
EmbeddingGradKernel
,
phi
::
EmbeddingGradKernel
,
float
,
float
,
double
,
double
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
)
{}
PD_REGISTER_KERNEL
(
embedding_sparse_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
EmbeddingSparseGradKernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{}
paddle/phi/kernels/cpu/embedding_kernel.cc
浏览文件 @
e81773c9
...
@@ -99,7 +99,7 @@ void EmbeddingKernel(const Context& ctx,
...
@@ -99,7 +99,7 @@ void EmbeddingKernel(const Context& ctx,
}
// namespace phi
}
// namespace phi
P
T
_REGISTER_KERNEL
(
embedding
,
P
D
_REGISTER_KERNEL
(
embedding
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
EmbeddingKernel
,
phi
::
EmbeddingKernel
,
...
...
paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc
浏览文件 @
e81773c9
...
@@ -12,7 +12,7 @@
...
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/phi/kernels/embedding_grad_kernel.h"
#include "paddle/phi/kernels/
sparse_weight_
embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/convert_utils.h"
...
@@ -23,13 +23,13 @@
...
@@ -23,13 +23,13 @@
namespace
phi
{
namespace
phi
{
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
struct
LookupTableV2GradCPUFunctor
{
struct
SparseWeight
LookupTableV2GradCPUFunctor
{
LookupTableV2GradCPUFunctor
(
const
Context
&
dev_ctx
,
SparseWeight
LookupTableV2GradCPUFunctor
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
input
,
const
SelectedRows
&
weight
,
const
SelectedRows
&
weight
,
const
DenseTensor
&
out_grad
,
const
DenseTensor
&
out_grad
,
int64_t
padding_idx
,
int64_t
padding_idx
,
DenseTensor
*
weight_grad
)
DenseTensor
*
weight_grad
)
:
dev_ctx_
(
dev_ctx
),
:
dev_ctx_
(
dev_ctx
),
input_
(
input
),
input_
(
input
),
weight_
(
weight
),
weight_
(
weight
),
...
@@ -101,6 +101,68 @@ struct LookupTableV2GradCPUFunctor {
...
@@ -101,6 +101,68 @@ struct LookupTableV2GradCPUFunctor {
int64_t
padding_idx_
;
int64_t
padding_idx_
;
};
};
template
<
typename
T
,
typename
Context
>
struct
SparseWeightLookupTableV2SparseGradCPUFunctor
{
SparseWeightLookupTableV2SparseGradCPUFunctor
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
SelectedRows
&
weight
,
const
DenseTensor
&
out_grad
,
int64_t
padding_idx
,
SelectedRows
*
weight_grad
)
:
dev_ctx_
(
dev_ctx
),
input_
(
input
),
weight_
(
weight
),
out_grad_
(
out_grad
),
weight_grad_
(
weight_grad
),
padding_idx_
(
padding_idx
)
{}
template
<
typename
IdT
>
void
apply
()
{
DDim
table_dim
=
weight_
.
dims
();
auto
ids
=
CopyIdsToVector
<
IdT
,
int64_t
>
(
input_
);
auto
ids_num
=
static_cast
<
int64_t
>
(
ids
.
size
());
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
auto
*
d_table
=
weight_grad_
;
auto
*
d_output
=
&
out_grad_
;
d_table
->
set_rows
(
ids
);
auto
*
d_table_value
=
d_table
->
mutable_value
();
d_table_value
->
Resize
({
ids_num
,
table_dim
[
1
]});
d_table_value
->
template
mutable_data
<
T
>(
dev_ctx_
.
GetPlace
());
d_table
->
set_height
(
table_dim
[
0
]);
auto
*
d_output_data
=
d_output
->
template
data
<
T
>();
auto
*
d_table_data
=
d_table_value
->
template
data
<
T
>();
auto
d_output_dims
=
d_output
->
dims
();
auto
d_output_dims_2d
=
phi
::
flatten_to_2d
(
d_output_dims
,
d_output_dims
.
size
()
-
1
);
PADDLE_ENFORCE_EQ
(
d_table_value
->
dims
(),
d_output_dims_2d
,
phi
::
errors
::
InvalidArgument
(
"ShapeError: The shape of lookup_table@Grad and "
"output@Grad should be same. "
"But received lookup_table@Grad's shape = [%s], "
"output@Grad's shape = [%s]."
,
d_table_value
->
dims
(),
d_output_dims_2d
));
memcpy
(
d_table_data
,
d_output_data
,
sizeof
(
T
)
*
d_output
->
numel
());
}
private:
const
Context
&
dev_ctx_
;
const
DenseTensor
&
input_
;
const
SelectedRows
&
weight_
;
const
DenseTensor
&
out_grad_
;
SelectedRows
*
weight_grad_
;
int64_t
padding_idx_
;
};
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
SparseWeightEmbeddingGradKernel
(
const
Context
&
ctx
,
void
SparseWeightEmbeddingGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
input
,
...
@@ -108,7 +170,20 @@ void SparseWeightEmbeddingGradKernel(const Context& ctx,
...
@@ -108,7 +170,20 @@ void SparseWeightEmbeddingGradKernel(const Context& ctx,
const
DenseTensor
&
out_grad
,
const
DenseTensor
&
out_grad
,
int64_t
padding_idx
,
int64_t
padding_idx
,
DenseTensor
*
weight_grad
)
{
DenseTensor
*
weight_grad
)
{
LookupTableV2GradCPUFunctor
<
T
,
Context
>
functor
(
SparseWeightLookupTableV2GradCPUFunctor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
out_grad
,
padding_idx
,
weight_grad
);
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
}
template
<
typename
T
,
typename
Context
>
void
SparseWeightEmbeddingSparseGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
input
,
const
SelectedRows
&
weight
,
const
DenseTensor
&
out_grad
,
int64_t
padding_idx
,
SelectedRows
*
weight_grad
)
{
SparseWeightLookupTableV2SparseGradCPUFunctor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
out_grad
,
padding_idx
,
weight_grad
);
ctx
,
input
,
weight
,
out_grad
,
padding_idx
,
weight_grad
);
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
...
@@ -116,10 +191,18 @@ void SparseWeightEmbeddingGradKernel(const Context& ctx,
...
@@ -116,10 +191,18 @@ void SparseWeightEmbeddingGradKernel(const Context& ctx,
}
// namespace phi
}
// namespace phi
P
T
_REGISTER_KERNEL
(
sparse_weight_embedding_grad
,
P
D
_REGISTER_KERNEL
(
sparse_weight_embedding_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
SparseWeightEmbeddingGradKernel
,
phi
::
SparseWeightEmbeddingGradKernel
,
float
,
float
,
double
,
double
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
)
{}
PD_REGISTER_KERNEL
(
sparse_weight_embedding_sparse_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
SparseWeightEmbeddingSparseGradKernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{}
paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc
浏览文件 @
e81773c9
...
@@ -24,12 +24,12 @@
...
@@ -24,12 +24,12 @@
namespace
phi
{
namespace
phi
{
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
struct
LookupTableV2CPUFunctor
{
struct
LookupTableV2CPU
Sparse
Functor
{
LookupTableV2CPUFunctor
(
const
Context
&
dev_ctx
,
LookupTableV2CPU
Sparse
Functor
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
input
,
const
SelectedRows
&
weight
,
const
SelectedRows
&
weight
,
int64_t
padding_idx
,
int64_t
padding_idx
,
DenseTensor
*
out
)
DenseTensor
*
out
)
:
dev_ctx_
(
dev_ctx
),
:
dev_ctx_
(
dev_ctx
),
input_
(
input
),
input_
(
input
),
weight_
(
weight
),
weight_
(
weight
),
...
@@ -94,7 +94,7 @@ void SparseWeightEmbeddingKernel(const Context& ctx,
...
@@ -94,7 +94,7 @@ void SparseWeightEmbeddingKernel(const Context& ctx,
const
SelectedRows
&
weight
,
const
SelectedRows
&
weight
,
int64_t
padding_idx
,
int64_t
padding_idx
,
DenseTensor
*
out
)
{
DenseTensor
*
out
)
{
LookupTableV2CPUFunctor
<
T
,
Context
>
functor
(
LookupTableV2CPU
Sparse
Functor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
padding_idx
,
out
);
ctx
,
input
,
weight
,
padding_idx
,
out
);
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
...
@@ -102,7 +102,7 @@ void SparseWeightEmbeddingKernel(const Context& ctx,
...
@@ -102,7 +102,7 @@ void SparseWeightEmbeddingKernel(const Context& ctx,
}
// namespace phi
}
// namespace phi
P
T
_REGISTER_KERNEL
(
sparse_weight_embedding
,
P
D
_REGISTER_KERNEL
(
sparse_weight_embedding
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
SparseWeightEmbeddingKernel
,
phi
::
SparseWeightEmbeddingKernel
,
...
...
paddle/phi/kernels/embedding_grad_kernel.h
浏览文件 @
e81773c9
...
@@ -15,6 +15,7 @@
...
@@ -15,6 +15,7 @@
#pragma once
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
namespace
phi
{
namespace
phi
{
...
@@ -26,4 +27,12 @@ void EmbeddingGradKernel(const Context& ctx,
...
@@ -26,4 +27,12 @@ void EmbeddingGradKernel(const Context& ctx,
int64_t
padding_idx
,
int64_t
padding_idx
,
DenseTensor
*
weight_grad
);
DenseTensor
*
weight_grad
);
template
<
typename
T
,
typename
Context
>
void
EmbeddingSparseGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
weight
,
const
DenseTensor
&
out_grad
,
int64_t
padding_idx
,
SelectedRows
*
weight_grad
);
}
// namespace phi
}
// namespace phi
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
浏览文件 @
e81773c9
...
@@ -21,7 +21,9 @@
...
@@ -21,7 +21,9 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace
phi
{
namespace
phi
{
template
<
typename
InT
,
typename
OutT
>
template
<
typename
InT
,
typename
OutT
>
...
@@ -120,12 +122,117 @@ void EmbeddingGradKernel(const Context& ctx,
...
@@ -120,12 +122,117 @@ void EmbeddingGradKernel(const Context& ctx,
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
}
}
template
<
typename
T
,
typename
Context
>
struct
LookupTableV2SparseGradCUDAFunctor
{
LookupTableV2SparseGradCUDAFunctor
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
weight
,
const
DenseTensor
&
out_grad
,
int64_t
padding_idx
,
SelectedRows
*
weight_grad
)
:
dev_ctx_
(
dev_ctx
),
input_
(
input
),
weight_
(
weight
),
out_grad_
(
out_grad
),
padding_idx_
(
padding_idx
),
weight_grad_
(
weight_grad
)
{}
template
<
typename
IdT
>
void
apply
()
{
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
const
auto
*
ids_data
=
input_
.
template
data
<
IdT
>();
auto
*
d_table
=
weight_grad_
;
auto
*
table
=
&
weight_
;
auto
*
d_output
=
&
out_grad_
;
int64_t
ids_num
=
input_
.
numel
();
dim3
threads
(
128
,
8
);
dim3
grids
(
8
,
1
);
auto
stream
=
dev_ctx_
.
stream
();
paddle
::
framework
::
Vector
<
int64_t
>
new_rows
;
new_rows
.
resize
(
ids_num
);
auto
gpu_place
=
dev_ctx_
.
GetPlace
();
paddle
::
framework
::
MixVector
<
int64_t
>
mixv_new_rows
(
&
new_rows
);
if
(
!
std
::
is_same
<
IdT
,
int64_t
>::
value
)
{
InputTypeConvert
<<<
grids
,
threads
,
0
,
stream
>>>
(
ids_data
,
ids_num
,
mixv_new_rows
.
MutableData
(
gpu_place
));
}
else
{
paddle
::
memory
::
Copy
(
gpu_place
,
mixv_new_rows
.
CUDAMutableData
(
gpu_place
),
gpu_place
,
ids_data
,
ids_num
*
sizeof
(
int64_t
),
stream
);
}
mixv_new_rows
.
CopyToCPU
();
d_table
->
set_rows
(
new_rows
);
auto
*
d_table_value
=
d_table
->
mutable_value
();
d_table_value
->
Resize
({
ids_num
,
table
->
dims
()[
1
]});
d_table_value
->
template
mutable_data
<
T
>(
gpu_place
);
auto
*
d_table_data
=
d_table_value
->
template
data
<
T
>();
auto
*
d_output_data
=
d_output
->
template
data
<
T
>();
auto
d_output_dims
=
d_output
->
dims
();
auto
d_output_dims_2d
=
phi
::
flatten_to_2d
(
d_output_dims
,
d_output_dims
.
size
()
-
1
);
PADDLE_ENFORCE_EQ
(
d_table_value
->
dims
(),
d_output_dims_2d
,
phi
::
errors
::
InvalidArgument
(
"ShapeError: The shape of lookup_table@Grad and "
"output@Grad should be same. "
"But received lookup_table@Grad's shape = [%s], "
"output@Grad's shape = [%s]."
,
d_table_value
->
dims
(),
d_output_dims_2d
));
paddle
::
memory
::
Copy
(
gpu_place
,
d_table_data
,
gpu_place
,
d_output_data
,
d_output
->
numel
()
*
sizeof
(
T
),
stream
);
}
private:
const
phi
::
GPUContext
&
dev_ctx_
;
const
DenseTensor
&
input_
;
const
DenseTensor
&
weight_
;
const
DenseTensor
&
out_grad_
;
int64_t
padding_idx_
;
SelectedRows
*
weight_grad_
;
};
template
<
typename
T
,
typename
Context
>
void
EmbeddingSparseGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
weight
,
const
DenseTensor
&
out_grad
,
int64_t
padding_idx
,
SelectedRows
*
weight_grad
)
{
LookupTableV2SparseGradCUDAFunctor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
out_grad
,
padding_idx
,
weight_grad
);
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
}
}
// namespace phi
}
// namespace phi
P
T
_REGISTER_KERNEL
(
embedding_grad
,
P
D
_REGISTER_KERNEL
(
embedding_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
EmbeddingGradKernel
,
phi
::
EmbeddingGradKernel
,
float
,
float
,
double
,
double
,
phi
::
dtype
::
float16
)
{}
phi
::
dtype
::
float16
)
{}
PD_REGISTER_KERNEL
(
embedding_sparse_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
EmbeddingSparseGradKernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{}
paddle/phi/kernels/gpu/embedding_kernel.cu
浏览文件 @
e81773c9
...
@@ -115,7 +115,7 @@ void EmbeddingKernel(const Context &ctx,
...
@@ -115,7 +115,7 @@ void EmbeddingKernel(const Context &ctx,
}
// namespace phi
}
// namespace phi
P
T
_REGISTER_KERNEL
(
embedding
,
P
D
_REGISTER_KERNEL
(
embedding
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
EmbeddingKernel
,
phi
::
EmbeddingKernel
,
...
...
paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h
浏览文件 @
e81773c9
...
@@ -27,4 +27,12 @@ void SparseWeightEmbeddingGradKernel(const Context& ctx,
...
@@ -27,4 +27,12 @@ void SparseWeightEmbeddingGradKernel(const Context& ctx,
int64_t
padding_idx
,
int64_t
padding_idx
,
DenseTensor
*
weight_grad
);
DenseTensor
*
weight_grad
);
template
<
typename
T
,
typename
Context
>
void
SparseWeightEmbeddingSparseGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
input
,
const
SelectedRows
&
weight
,
const
DenseTensor
&
out_grad
,
int64_t
padding_idx
,
SelectedRows
*
weight_grad
);
}
// namespace phi
}
// namespace phi
paddle/phi/ops/compat/embedding_sig.cc
浏览文件 @
e81773c9
...
@@ -18,10 +18,8 @@ namespace phi {
...
@@ -18,10 +18,8 @@ namespace phi {
KernelSignature
EmbeddingOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
KernelSignature
EmbeddingOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
if
(
ctx
.
IsDenseTensorInput
(
"W"
))
{
if
(
ctx
.
IsDenseTensorInput
(
"W"
))
{
LOG
(
ERROR
)
<<
"is dense here"
;
return
KernelSignature
(
"embedding"
,
{
"Ids"
,
"W"
},
{
"padding_idx"
},
{
"Out"
});
return
KernelSignature
(
"embedding"
,
{
"Ids"
,
"W"
},
{
"padding_idx"
},
{
"Out"
});
}
else
{
}
else
{
LOG
(
ERROR
)
<<
"is selcted rows"
;
return
KernelSignature
(
return
KernelSignature
(
"sparse_weight_embedding"
,
{
"Ids"
,
"W"
},
{
"padding_idx"
},
{
"Out"
});
"sparse_weight_embedding"
,
{
"Ids"
,
"W"
},
{
"padding_idx"
},
{
"Out"
});
}
}
...
@@ -30,23 +28,37 @@ KernelSignature EmbeddingOpArgumentMapping(const ArgumentMappingContext& ctx) {
...
@@ -30,23 +28,37 @@ KernelSignature EmbeddingOpArgumentMapping(const ArgumentMappingContext& ctx) {
KernelSignature
EmbeddingGradOpArgumentMapping
(
KernelSignature
EmbeddingGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
const
ArgumentMappingContext
&
ctx
)
{
if
(
ctx
.
IsDenseTensorInput
(
"W"
))
{
if
(
ctx
.
IsDenseTensorInput
(
"W"
))
{
return
KernelSignature
(
"embedding_grad"
,
if
((
paddle
::
any_cast
<
bool
>
(
ctx
.
Attr
(
"is_sparse"
)))
==
true
)
{
{
"Ids"
,
"W"
,
GradVarName
(
"Out"
)},
return
KernelSignature
(
"embedding_sparse_grad"
,
{
"padding_idx"
},
{
"Ids"
,
"W"
,
GradVarName
(
"Out"
)},
{
GradVarName
(
"W"
)});
{
"padding_idx"
},
{
GradVarName
(
"W"
)});
}
else
{
return
KernelSignature
(
"embedding_grad"
,
{
"Ids"
,
"W"
,
GradVarName
(
"Out"
)},
{
"padding_idx"
},
{
GradVarName
(
"W"
)});
}
}
else
{
}
else
{
return
KernelSignature
(
"sparse_weight_embedding_grad"
,
if
((
paddle
::
any_cast
<
bool
>
(
ctx
.
Attr
(
"is_sparse"
)))
==
true
)
{
{
"Ids"
,
"W"
,
GradVarName
(
"Out"
)},
return
KernelSignature
(
"sparse_weight_embedding_sparse_grad"
,
{
"padding_idx"
},
{
"Ids"
,
"W"
,
GradVarName
(
"Out"
)},
{
GradVarName
(
"W"
)});
{
"padding_idx"
},
{
GradVarName
(
"W"
)});
}
else
{
return
KernelSignature
(
"sparse_weight_embedding_grad"
,
{
"Ids"
,
"W"
,
GradVarName
(
"Out"
)},
{
"padding_idx"
},
{
GradVarName
(
"W"
)});
}
}
}
}
}
}
// namespace phi
}
// namespace phi
P
T
_REGISTER_BASE_KERNEL_NAME
(
lookup_table_v2
,
embedding
);
P
D
_REGISTER_BASE_KERNEL_NAME
(
lookup_table_v2
,
embedding
);
P
T
_REGISTER_BASE_KERNEL_NAME
(
lookup_table_v2_grad
,
embedding_grad
);
P
D
_REGISTER_BASE_KERNEL_NAME
(
lookup_table_v2_grad
,
embedding_grad
);
P
T
_REGISTER_ARG_MAPPING_FN
(
lookup_table_v2
,
phi
::
EmbeddingOpArgumentMapping
);
P
D
_REGISTER_ARG_MAPPING_FN
(
lookup_table_v2
,
phi
::
EmbeddingOpArgumentMapping
);
P
T
_REGISTER_ARG_MAPPING_FN
(
lookup_table_v2_grad
,
P
D
_REGISTER_ARG_MAPPING_FN
(
lookup_table_v2_grad
,
phi
::
EmbeddingGradOpArgumentMapping
);
phi
::
EmbeddingGradOpArgumentMapping
);
python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py
浏览文件 @
e81773c9
...
@@ -25,23 +25,24 @@ import paddle.compat as cpt
...
@@ -25,23 +25,24 @@ import paddle.compat as cpt
import
paddle.fluid
as
fluid
import
paddle.fluid
as
fluid
from
paddle.fluid
import
Program
,
program_guard
from
paddle.fluid
import
Program
,
program_guard
# class TestStaticGraphSupportMultipleInt(unittest.TestCase):
# def test_main(self):
class
TestStaticGraphSupportMultipleInt
(
unittest
.
TestCase
):
# dtypes = ['uint8', 'int8', 'int16', 'int32', 'int64']
def
test_main
(
self
):
# if paddle.in_dynamic_mode():
dtypes
=
[
'uint8'
,
'int8'
,
'int16'
,
'int32'
,
'int64'
]
# paddle.enable_static()
if
paddle
.
in_dynamic_mode
():
# disable_static = True
paddle
.
enable_static
()
# else:
disable_static
=
True
# disable_static = False
else
:
# for i, dtype in enumerate(dtypes):
disable_static
=
False
# with paddle.static.program_guard(paddle.static.Program(),
for
i
,
dtype
in
enumerate
(
dtypes
):
# paddle.static.Program()):
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
(),
# x = paddle.static.data(name='x', shape=[-1, 7, 30], dtype=dtype)
paddle
.
static
.
Program
()):
# emb = paddle.nn.Embedding(10, 20)
x
=
paddle
.
static
.
data
(
name
=
'x'
,
shape
=
[
-
1
,
7
,
30
],
dtype
=
dtype
)
# y = emb(x)
emb
=
paddle
.
nn
.
Embedding
(
10
,
20
)
y
=
emb
(
x
)
# if disable_static:
# paddle.disable_static()
if
disable_static
:
paddle
.
disable_static
()
class
TestLookupTableOp
(
OpTest
):
class
TestLookupTableOp
(
OpTest
):
...
@@ -62,17 +63,19 @@ class TestLookupTableOp(OpTest):
...
@@ -62,17 +63,19 @@ class TestLookupTableOp(OpTest):
self
.
check_grad
([
'W'
],
'Out'
,
no_grad_set
=
set
(
'Ids'
))
self
.
check_grad
([
'W'
],
'Out'
,
no_grad_set
=
set
(
'Ids'
))
#
class TestLookupTableOpInt16(OpTest):
class
TestLookupTableOpInt16
(
OpTest
):
#
def id_dtype(self):
def
id_dtype
(
self
):
#
return "int16"
return
"int16"
# class TestLookupTableOpInt8(OpTest):
# def id_dtype(self):
# return "int8"
# class TestLookupTableOpUInt8(OpTest):
class
TestLookupTableOpInt8
(
OpTest
):
# def id_dtype(self):
def
id_dtype
(
self
):
# return "uint8"
return
"int8"
class
TestLookupTableOpUInt8
(
OpTest
):
def
id_dtype
(
self
):
return
"uint8"
class
TestLookupTableOpWithTensorIds
(
OpTest
):
class
TestLookupTableOpWithTensorIds
(
OpTest
):
...
@@ -90,183 +93,190 @@ class TestLookupTableOpWithTensorIds(OpTest):
...
@@ -90,183 +93,190 @@ class TestLookupTableOpWithTensorIds(OpTest):
self
.
check_grad
([
'W'
],
'Out'
,
no_grad_set
=
set
(
'Ids'
))
self
.
check_grad
([
'W'
],
'Out'
,
no_grad_set
=
set
(
'Ids'
))
# @skip_check_grad_ci(
@
skip_check_grad_ci
(
# reason="Since paddings are not trainable and fixed in forward,"
reason
=
"Since paddings are not trainable and fixed in forward,"
# "the gradient of paddings makes no sense and we don't "
"the gradient of paddings makes no sense and we don't "
# "test the gradient here.")
"test the gradient here."
)
# class TestLookupTableOpWithPadding(TestLookupTableOp):
class
TestLookupTableOpWithPadding
(
TestLookupTableOp
):
# def test_check_output(self):
def
test_check_output
(
self
):
# ids = np.squeeze(self.inputs['Ids'])
ids
=
np
.
squeeze
(
self
.
inputs
[
'Ids'
])
# padding_idx = np.random.choice(ids, 1)[0]
padding_idx
=
np
.
random
.
choice
(
ids
,
1
)[
0
]
# self.outputs['Out'][ids == padding_idx] = np.zeros(31)
self
.
outputs
[
'Out'
][
ids
==
padding_idx
]
=
np
.
zeros
(
31
)
# self.attrs = {'padding_idx': int(padding_idx)}
self
.
attrs
=
{
'padding_idx'
:
int
(
padding_idx
)}
# self.check_output()
self
.
check_output
()
# @skip_check_grad_ci(
# reason="Since paddings are not trainable and fixed in forward,"
@
skip_check_grad_ci
(
# "the gradient of paddings makes no sense and we don't "
reason
=
"Since paddings are not trainable and fixed in forward,"
# "test the gradient here.")
"the gradient of paddings makes no sense and we don't "
# class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds):
"test the gradient here."
)
# def test_check_output(self):
class
TestLookupTableOpWithTensorIdsAndPadding
(
TestLookupTableOpWithTensorIds
):
# ids = self.inputs['Ids']
def
test_check_output
(
self
):
# flatten_idx = ids.flatten()
ids
=
self
.
inputs
[
'Ids'
]
# padding_idx = np.random.choice(flatten_idx, 1)[0]
flatten_idx
=
ids
.
flatten
()
# self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
padding_idx
=
np
.
random
.
choice
(
flatten_idx
,
1
)[
0
]
# self.attrs = {'padding_idx': cpt.long_type(padding_idx)}
self
.
outputs
[
'Out'
][
np
.
squeeze
(
ids
==
padding_idx
)]
=
np
.
zeros
(
31
)
# self.check_output()
self
.
attrs
=
{
'padding_idx'
:
cpt
.
long_type
(
padding_idx
)}
self
.
check_output
()
# class TestLookupTableWIsSelectedRows(unittest.TestCase):
# def prepare_ids(self, scope, place):
# ids_tensor = scope.var('Ids').get_tensor()
class
TestLookupTableWIsSelectedRows
(
unittest
.
TestCase
):
# ids_array = np.array([0, 4, 3, 5]).astype("int32")
def
prepare_ids
(
self
,
scope
,
place
):
# ids_tensor.set(ids_array, place)
ids_tensor
=
scope
.
var
(
'Ids'
).
get_tensor
()
# return ids_array
ids_array
=
np
.
array
([
0
,
4
,
3
,
5
]).
astype
(
"int32"
)
ids_tensor
.
set
(
ids_array
,
place
)
# def prepare_w(self, scope, place):
return
ids_array
# rows = [0, 1, 2, 3, 4, 5, 6]
# row_numel = 12
def
prepare_w
(
self
,
scope
,
place
):
rows
=
[
0
,
1
,
2
,
3
,
4
,
5
,
6
]
# w_selected_rows = scope.var('W').get_selected_rows()
row_numel
=
12
# w_selected_rows.set_height(len(rows))
# w_selected_rows.set_rows(rows)
w_selected_rows
=
scope
.
var
(
'W'
).
get_selected_rows
()
# w_array = np.ones((len(rows), row_numel)).astype("float32")
w_selected_rows
.
set_height
(
len
(
rows
))
# for i in range(len(rows)):
w_selected_rows
.
set_rows
(
rows
)
# w_array[i] *= i
w_array
=
np
.
ones
((
len
(
rows
),
row_numel
)).
astype
(
"float32"
)
# w_tensor = w_selected_rows.get_tensor()
for
i
in
range
(
len
(
rows
)):
# w_tensor.set(w_array, place)
w_array
[
i
]
*=
i
w_tensor
=
w_selected_rows
.
get_tensor
()
# def create_out_tensor(self, scope, place):
w_tensor
.
set
(
w_array
,
place
)
# return scope.var('Out').get_tensor()
def
create_out_tensor
(
self
,
scope
,
place
):
# def check_result(self, ids_array, result_array):
return
scope
.
var
(
'Out'
).
get_tensor
()
# # all(): return True if all elements of the iterable are true (or if the iterable is empty)
# for idx, row in enumerate(ids_array):
def
check_result
(
self
,
ids_array
,
result_array
):
# assert (row == result_array[idx]).all()
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
for
idx
,
row
in
enumerate
(
ids_array
):
# def check_with_place(self, place):
assert
(
row
==
result_array
[
idx
]).
all
()
# scope = core.Scope()
def
check_with_place
(
self
,
place
):
# ids_array = self.prepare_ids(scope, place)
scope
=
core
.
Scope
()
# self.prepare_w(scope, place)
ids_array
=
self
.
prepare_ids
(
scope
,
place
)
# out_tensor = self.create_out_tensor(scope, place)
self
.
prepare_w
(
scope
,
place
)
# # create and run lookup_table operator
out_tensor
=
self
.
create_out_tensor
(
scope
,
place
)
# lookup_table = Operator("lookup_table_v2", W='W', Ids='Ids', Out='Out')
# lookup_table.run(scope, place)
# create and run lookup_table operator
lookup_table
=
Operator
(
"lookup_table_v2"
,
W
=
'W'
,
Ids
=
'Ids'
,
Out
=
'Out'
)
# # get result from Out
lookup_table
.
run
(
scope
,
place
)
# result_array = np.array(out_tensor)
# get result from Out
# self.check_result(ids_array, result_array)
result_array
=
np
.
array
(
out_tensor
)
# def test_w_is_selected_rows(self):
self
.
check_result
(
ids_array
,
result_array
)
# places = [core.CPUPlace()]
# # currently only support CPU
def
test_w_is_selected_rows
(
self
):
# for place in places:
places
=
[
core
.
CPUPlace
()]
# self.check_with_place(place)
# currently only support CPU
for
place
in
places
:
# class TestLookupTableWithTensorIdsWIsSelectedRows(
self
.
check_with_place
(
place
)
# TestLookupTableWIsSelectedRows):
# def prepare_ids(self, scope, place):
# ids_tensor = scope.var('Ids').get_tensor()
class
TestLookupTableWithTensorIdsWIsSelectedRows
(
# ids_array = np.random.randint(
TestLookupTableWIsSelectedRows
):
# low=0, high=6, size=(2, 4, 3)).astype("int64")
def
prepare_ids
(
self
,
scope
,
place
):
# ids_tensor.set(ids_array, place)
ids_tensor
=
scope
.
var
(
'Ids'
).
get_tensor
()
# return ids_array
ids_array
=
np
.
random
.
randint
(
low
=
0
,
high
=
6
,
size
=
(
2
,
4
,
3
)).
astype
(
"int64"
)
# def check_result(self, ids_array, result_array):
ids_tensor
.
set
(
ids_array
,
place
)
# for idx, row in np.ndenumerate(ids_array):
return
ids_array
# assert (row == result_array[idx]).all()
def
check_result
(
self
,
ids_array
,
result_array
):
# class TestLookupTableIsSparse(unittest.TestCase):
for
idx
,
row
in
np
.
ndenumerate
(
ids_array
):
# def init_data(self):
assert
(
row
==
result_array
[
idx
]).
all
()
# self.x_data = np.array([[1, 3, 0, 4, 7]]).astype("int64")
# self.y_data = np.array([[0.1, 0.3, 0, 0.4, 0.7]]).astype("float32")
class
TestLookupTableIsSparse
(
unittest
.
TestCase
):
# def get_w_grad(self, is_sparse):
def
init_data
(
self
):
# self.init_data()
self
.
x_data
=
np
.
array
([[
1
,
3
,
0
,
4
,
7
]]).
astype
(
"int64"
)
# main_program = fluid.Program()
self
.
y_data
=
np
.
array
([[
0.1
,
0.3
,
0
,
0.4
,
0.7
]]).
astype
(
"float32"
)
# with fluid.program_guard(main_program, fluid.Program()):
# x = fluid.layers.data(name='x', shape=[5], dtype='int64')
def
get_w_grad
(
self
,
is_sparse
):
# y_ = fluid.layers.data(name='y_', shape=[5], dtype='float32')
self
.
init_data
()
# emb = fluid.input.embedding(
main_program
=
fluid
.
Program
()
# input=x,
with
fluid
.
program_guard
(
main_program
,
fluid
.
Program
()):
# size=[10, 16],
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
5
],
dtype
=
'int64'
)
# param_attr=fluid.ParamAttr(
y_
=
fluid
.
layers
.
data
(
name
=
'y_'
,
shape
=
[
5
],
dtype
=
'float32'
)
# name="emb_weight",
emb
=
fluid
.
input
.
embedding
(
# learning_rate=10,
input
=
x
,
# initializer=fluid.initializer.NumpyArrayInitializer(
size
=
[
10
,
16
],
# self.w_data)),
param_attr
=
fluid
.
ParamAttr
(
# is_sparse=is_sparse)
name
=
"emb_weight"
,
# y = fluid.layers.reduce_sum(emb, dim=-1)
learning_rate
=
10
,
initializer
=
fluid
.
initializer
.
NumpyArrayInitializer
(
# loss = fluid.layers.square_error_cost(input=y, label=y_)
self
.
w_data
)),
# loss = fluid.layers.mean(loss)
is_sparse
=
is_sparse
)
y
=
fluid
.
layers
.
reduce_sum
(
emb
,
dim
=-
1
)
# sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-4)
# sgd_optimizer.minimize(loss)
loss
=
fluid
.
layers
.
square_error_cost
(
input
=
y
,
label
=
y_
)
loss
=
fluid
.
layers
.
mean
(
loss
)
# place = fluid.CPUPlace()
# exe = fluid.Executor(place)
sgd_optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
1e-4
)
# exe.run(fluid.default_startup_program())
sgd_optimizer
.
minimize
(
loss
)
# ret = exe.run(feed={'x': self.x_data,
# 'y_': self.y_data},
place
=
fluid
.
CPUPlace
()
# fetch_list=['emb_weight'],
exe
=
fluid
.
Executor
(
place
)
# return_numpy=False)
exe
.
run
(
fluid
.
default_startup_program
())
# return np.array(ret[0])
ret
=
exe
.
run
(
feed
=
{
'x'
:
self
.
x_data
,
'y_'
:
self
.
y_data
},
# def test_w_grad(self):
fetch_list
=
[
'emb_weight'
],
# self.w_data = np.random.random(size=(10, 16)).astype("float32")
return_numpy
=
False
)
# w_grad = self.get_w_grad(False)
return
np
.
array
(
ret
[
0
])
# w_grad_with_sparse = self.get_w_grad(True)
# self.check_grad(w_grad, w_grad_with_sparse)
def
test_w_grad
(
self
):
self
.
w_data
=
np
.
random
.
random
(
size
=
(
10
,
16
)).
astype
(
"float32"
)
# def check_grad(self, w_grad1, w_grad2, tolerance=1e-6):
w_grad
=
self
.
get_w_grad
(
False
)
# np.testing.assert_allclose(
w_grad_with_sparse
=
self
.
get_w_grad
(
True
)
# w_grad1, w_grad2, rtol=tolerance, atol=tolerance)
self
.
check_grad
(
w_grad
,
w_grad_with_sparse
)
# class TestLookupTableApi(unittest.TestCase):
def
check_grad
(
self
,
w_grad1
,
w_grad2
,
tolerance
=
1e-6
):
# def test_api(self):
np
.
testing
.
assert_allclose
(
# x = fluid.layers.data(name='x', shape=[20], dtype='int64')
w_grad1
,
w_grad2
,
rtol
=
tolerance
,
atol
=
tolerance
)
# emb = fluid.embedding(input=x, size=[128, 64])
# place = fluid.CPUPlace()
class
TestLookupTableApi
(
unittest
.
TestCase
):
# x_data = np.random.randint(0, 127, [2, 20]).astype("int64")
def
test_api
(
self
):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
20
],
dtype
=
'int64'
)
# exe = fluid.Executor(place)
emb
=
fluid
.
embedding
(
input
=
x
,
size
=
[
128
,
64
])
# exe.run(fluid.default_startup_program())
# ret = exe.run(feed={'x': x_data, },
place
=
fluid
.
CPUPlace
()
# fetch_list=[emb],
x_data
=
np
.
random
.
randint
(
0
,
127
,
[
2
,
20
]).
astype
(
"int64"
)
# return_numpy=False)
exe
=
fluid
.
Executor
(
place
)
# class TestEmbedOpError(unittest.TestCase):
exe
.
run
(
fluid
.
default_startup_program
())
# def test_errors(self):
ret
=
exe
.
run
(
feed
=
{
'x'
:
x_data
,
},
# with program_guard(Program(), Program()):
fetch_list
=
[
emb
],
# input_data = np.random.randint(0, 10, (4, 6)).astype("int64")
return_numpy
=
False
)
# def test_Variable():
# # the input type must be Variable
class
TestEmbedOpError
(
unittest
.
TestCase
):
# fluid.embedding(input=input_data, size=(10, 64))
def
test_errors
(
self
):
with
program_guard
(
Program
(),
Program
()):
# self.assertRaises(TypeError, test_Variable)
input_data
=
np
.
random
.
randint
(
0
,
10
,
(
4
,
6
)).
astype
(
"int64"
)
# def test_input_dtype():
def
test_Variable
():
# # the input dtype must be int64
# the input type must be Variable
# input = fluid.data(name='x1', shape=[4, 6], dtype='float32')
fluid
.
embedding
(
input
=
input_data
,
size
=
(
10
,
64
))
# fluid.embedding(input=input, size=(10, 64))
self
.
assertRaises
(
TypeError
,
test_Variable
)
# self.assertRaises(TypeError, test_input_dtype)
def
test_input_dtype
():
# def test_param_dtype():
# the input dtype must be int64
# # dtype must be float32 or float64
input
=
fluid
.
data
(
name
=
'x1'
,
shape
=
[
4
,
6
],
dtype
=
'float32'
)
# input2 = fluid.data(name='x2', shape=[4, 6], dtype='int64')
fluid
.
embedding
(
input
=
input
,
size
=
(
10
,
64
))
# fluid.embedding(input=input2, size=(10, 64), dtype='int64')
self
.
assertRaises
(
TypeError
,
test_input_dtype
)
# self.assertRaises(TypeError, test_param_dtype)
# input3 = fluid.data(name='x3', shape=[4, 6], dtype='int64')
def
test_param_dtype
():
# fluid.embedding(input=input3, size=(10, 64), dtype='float16')
# dtype must be float32 or float64
input2
=
fluid
.
data
(
name
=
'x2'
,
shape
=
[
4
,
6
],
dtype
=
'int64'
)
fluid
.
embedding
(
input
=
input2
,
size
=
(
10
,
64
),
dtype
=
'int64'
)
self
.
assertRaises
(
TypeError
,
test_param_dtype
)
input3
=
fluid
.
data
(
name
=
'x3'
,
shape
=
[
4
,
6
],
dtype
=
'int64'
)
fluid
.
embedding
(
input
=
input3
,
size
=
(
10
,
64
),
dtype
=
'float16'
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
paddle
.
enable_static
()
paddle
.
enable_static
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录