Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
e037504b
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
e037504b
编写于
2月 23, 2022
作者:
P
phlrain
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
move embeding to phi;
上级
2bb5aae8
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
1118 addition
and
213 deletion
+1118
-213
paddle/phi/kernels/cpu/embedding_grad_kernel.cc
paddle/phi/kernels/cpu/embedding_grad_kernel.cc
+125
-0
paddle/phi/kernels/cpu/embedding_kernel.cc
paddle/phi/kernels/cpu/embedding_kernel.cc
+108
-0
paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc
...le/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc
+125
-0
paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc
paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc
+111
-0
paddle/phi/kernels/embedding_grad_kernel.h
paddle/phi/kernels/embedding_grad_kernel.h
+29
-0
paddle/phi/kernels/embedding_kernel.h
paddle/phi/kernels/embedding_kernel.h
+28
-0
paddle/phi/kernels/funcs/embedding_util.h
paddle/phi/kernels/funcs/embedding_util.h
+37
-0
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
+131
-0
paddle/phi/kernels/gpu/embedding_kernel.cu
paddle/phi/kernels/gpu/embedding_kernel.cu
+124
-0
paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h
paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h
+30
-0
paddle/phi/kernels/sparse_weight_embedding_kernel.h
paddle/phi/kernels/sparse_weight_embedding_kernel.h
+29
-0
paddle/phi/ops/compat/embedding_sig.cc
paddle/phi/ops/compat/embedding_sig.cc
+38
-0
python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py
...n/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py
+203
-213
未找到文件。
paddle/phi/kernels/cpu/embedding_grad_kernel.cc
0 → 100644
浏览文件 @
e037504b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
// CPU functor that accumulates the dense gradient of the embedding table.
//
// The kernel entry point passes this object to VisitIntDataType, which is
// expected to call apply<IdT>() with the integer type of `input`
// (NOTE(review): confirm against VisitIntDataType).
//
// Members are references/pointers to the caller's tensors; the functor does
// not own any of them.
template <typename T, typename Context>
struct LookupTableV2GradCPUFunctor {
  LookupTableV2GradCPUFunctor(const Context& dev_ctx,
                              const DenseTensor& input,
                              const DenseTensor& weight,
                              const DenseTensor& out_grad,
                              int64_t padding_idx,
                              DenseTensor* weight_grad)
      : dev_ctx_(dev_ctx),
        input_(input),
        weight_(weight),
        out_grad_(out_grad),
        weight_grad_(weight_grad),
        padding_idx_(padding_idx) {}

  // Zero-fills `weight_grad` and then, for every id in `input`, adds the
  // matching row of `out_grad` to row `id` of `weight_grad`.
  //
  // IdT is the element type stored in `input`; ids are widened to int64_t
  // before use. Ids outside [0, N) raise InvalidArgument, where N is the
  // first dimension of `weight`.
  template <typename IdT>
  void apply() {
    DDim table_dim = weight_.dims();

    auto ids = CopyIdsToVector<IdT, int64_t>(input_);
    auto ids_num = static_cast<int64_t>(ids.size());

    // Since paddings are not trainable and fixed in forward, the gradient of
    // paddings makes no sense and we don't deal with it in backward.
    const int64_t* ids_data = ids.data();

    const int64_t N = table_dim[0];
    const int64_t D = table_dim[1];

    const auto* d_output_data = out_grad_.template data<T>();

    dev_ctx_.template Alloc<T>(weight_grad_);
    auto* d_table_data = weight_grad_->data<T>();

    memset(d_table_data, 0, weight_grad_->numel() * sizeof(T));

    for (int64_t i = 0; i < ids_num; ++i) {
      const int64_t id = ids_data[i];

      if (padding_idx_ != kNoPadding && id == padding_idx_) {
        // the gradient of padding_idx should be 0, already done by memset, so
        // do nothing.
        continue;
      }

      PADDLE_ENFORCE_LT(
          id,
          N,
          phi::errors::InvalidArgument(
              "Variable value (input) of OP(fluid.layers.embedding) "
              "expected >= 0 and < %ld, but got %ld. Please check input "
              "value.",
              N,
              id));
      PADDLE_ENFORCE_GE(
          id,
          0,
          phi::errors::InvalidArgument(
              "Variable value (input) of OP(fluid.layers.embedding) "
              "expected >= 0 and < %ld, but got %ld. Please check input "
              "value.",
              N,
              id));

      // The column index is int64_t so that it has the same type as its
      // bound D and as the offset arithmetic below.
      for (int64_t j = 0; j < D; ++j) {
        d_table_data[id * D + j] += d_output_data[i * D + j];
      }
    }
  }

 private:
  const Context& dev_ctx_;
  const DenseTensor& input_;
  const DenseTensor& weight_;
  const DenseTensor& out_grad_;
  DenseTensor* weight_grad_;
  int64_t padding_idx_;
};
template
<
typename
T
,
typename
Context
>
void
EmbeddingGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
weight
,
const
DenseTensor
&
out_grad
,
int64_t
padding_idx
,
DenseTensor
*
weight_grad
)
{
LookupTableV2GradCPUFunctor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
out_grad
,
padding_idx
,
weight_grad
);
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
}
}
// namespace phi
// Registers phi::EmbeddingGradKernel as the `embedding_grad` kernel for the
// CPU backend and ALL_LAYOUT, for the element types float, double and
// phi::dtype::float16. The trailing empty body adds no extra per-kernel
// setup.
PT_REGISTER_KERNEL(embedding_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::EmbeddingGradKernel,
                   float,
                   double,
                   phi::dtype::float16) {}
paddle/phi/kernels/cpu/embedding_kernel.cc
0 → 100644
浏览文件 @
e037504b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/embedding_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
// CPU functor implementing the dense embedding lookup.
//
// The kernel entry point passes this object to VisitIntDataType, which is
// expected to call apply<IdT>() with the integer type of `input`
// (NOTE(review): confirm against VisitIntDataType). The functor only keeps
// references/pointers to the caller's tensors.
template <typename T, typename Context>
struct LookupTableV2CPUFunctor {
  LookupTableV2CPUFunctor(const Context& dev_ctx,
                          const DenseTensor& input,
                          const DenseTensor& weight,
                          int64_t padding_idx,
                          DenseTensor* out)
      : dev_ctx_(dev_ctx),
        input_(input),
        weight_(weight),
        out_(out),
        padding_idx_(padding_idx) {}

  // For the i-th id in `input`, writes one row of `row_width` elements at
  // offset i * row_width in `out`: zeros when the id equals padding_idx
  // (and padding is enabled), otherwise a copy of row `id` of `weight`.
  // Ids outside [0, weight.dims()[0]) raise InvalidArgument.
  template <typename IdT>
  void apply() {
    const auto id_list = CopyIdsToVector<IdT, int64_t>(input_);
    const auto id_count = static_cast<int64_t>(id_list.size());

    const int64_t table_rows = weight_.dims()[0];
    const int64_t embed_width = weight_.dims()[1];

    const auto* table_data = weight_.data<T>();

    dev_ctx_.template Alloc<T>(out_);
    auto* out_data = out_->data<T>();

    // Byte size of one embedding row, shared by the memset and memcpy below.
    const size_t row_bytes = embed_width * sizeof(T);

    for (int64_t pos = 0; pos < id_count; ++pos) {
      const int64_t id = id_list[pos];
      auto* dst_row = out_data + pos * embed_width;

      if (padding_idx_ != kNoPadding && id == padding_idx_) {
        memset(dst_row, 0, row_bytes);
        continue;
      }

      PADDLE_ENFORCE_LT(
          id,
          table_rows,
          phi::errors::InvalidArgument(
              "Variable value (input) of OP(fluid.layers.embedding) "
              "expected >= 0 and < %ld, but got %ld. Please check input "
              "value.",
              table_rows,
              id));
      PADDLE_ENFORCE_GE(
          id,
          0,
          phi::errors::InvalidArgument(
              "Variable value (input) of OP(fluid.layers.embedding) "
              "expected >= 0 and < %ld, but got %ld. Please check input "
              "value.",
              table_rows,
              id));

      memcpy(dst_row, table_data + id * embed_width, row_bytes);
    }
  }

 private:
  const Context& dev_ctx_;
  const DenseTensor& input_;
  const DenseTensor& weight_;
  DenseTensor* out_;
  int64_t padding_idx_;
};
template
<
typename
T
,
typename
Context
>
void
EmbeddingKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
weight
,
int64_t
padding_idx
,
DenseTensor
*
out
)
{
LookupTableV2CPUFunctor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
padding_idx
,
out
);
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
}
}
// namespace phi
// Registers phi::EmbeddingKernel as the `embedding` kernel for the CPU
// backend and ALL_LAYOUT, for the element types float, double and
// phi::dtype::float16. The trailing empty body adds no extra per-kernel
// setup.
PT_REGISTER_KERNEL(embedding,
                   CPU,
                   ALL_LAYOUT,
                   phi::EmbeddingKernel,
                   float,
                   double,
                   phi::dtype::float16) {}
paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc
0 → 100644
浏览文件 @
e037504b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
struct
LookupTableV2GradCPUFunctor
{
LookupTableV2GradCPUFunctor
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
SelectedRows
&
weight
,
const
DenseTensor
&
out_grad
,
int64_t
padding_idx
,
DenseTensor
*
weight_grad
)
:
dev_ctx_
(
dev_ctx
),
input_
(
input
),
weight_
(
weight
),
out_grad_
(
out_grad
),
weight_grad_
(
weight_grad
),
padding_idx_
(
padding_idx
)
{}
template
<
typename
IdT
>
void
apply
()
{
DDim
table_dim
=
weight_
.
dims
();
auto
ids
=
CopyIdsToVector
<
IdT
,
int64_t
>
(
input_
);
auto
ids_num
=
static_cast
<
int64_t
>
(
ids
.
size
());
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
{
auto
*
d_output
=
&
out_grad_
;
// auto d_table = weight_grad_;
auto
*
ids_data
=
ids
.
data
();
int64_t
N
=
table_dim
[
0
];
int64_t
D
=
table_dim
[
1
];
auto
*
d_output_data
=
d_output
->
template
data
<
T
>();
dev_ctx_
.
template
Alloc
<
T
>(
weight_grad_
);
auto
*
d_table_data
=
weight_grad_
->
data
<
T
>
();
memset
(
d_table_data
,
0
,
weight_grad_
->
numel
()
*
sizeof
(
T
));
for
(
int64_t
i
=
0
;
i
<
ids_num
;
++
i
)
{
if
(
padding_idx_
!=
kNoPadding
&&
ids_data
[
i
]
==
padding_idx_
)
{
// the gradient of padding_idx should be 0, already done by memset, so
// do nothing.
}
else
{
PADDLE_ENFORCE_LT
(
ids_data
[
i
],
N
,
phi
::
errors
::
InvalidArgument
(
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value."
,
N
,
ids_data
[
i
]));
PADDLE_ENFORCE_GE
(
ids_data
[
i
],
0
,
phi
::
errors
::
InvalidArgument
(
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value."
,
N
,
ids_data
[
i
]));
for
(
int
j
=
0
;
j
<
D
;
++
j
)
{
d_table_data
[
ids_data
[
i
]
*
D
+
j
]
+=
d_output_data
[
i
*
D
+
j
];
}
}
}
}
}
private:
const
Context
&
dev_ctx_
;
const
DenseTensor
&
input_
;
const
SelectedRows
&
weight_
;
const
DenseTensor
&
out_grad_
;
DenseTensor
*
weight_grad_
;
int64_t
padding_idx_
;
};
template
<
typename
T
,
typename
Context
>
void
SparseWeightEmbeddingGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
input
,
const
SelectedRows
&
weight
,
const
DenseTensor
&
out_grad
,
int64_t
padding_idx
,
DenseTensor
*
weight_grad
)
{
LookupTableV2GradCPUFunctor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
out_grad
,
padding_idx
,
weight_grad
);
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
}
}
// namespace phi
// Registers phi::SparseWeightEmbeddingGradKernel as the
// `sparse_weight_embedding_grad` kernel for the CPU backend and ALL_LAYOUT,
// for the element types float, double and phi::dtype::float16.
// NOTE(review): the forward `sparse_weight_embedding` registration lists
// phi::dtype::bfloat16 as its third type while this one lists float16 --
// confirm the asymmetry is intended.
PT_REGISTER_KERNEL(sparse_weight_embedding_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::SparseWeightEmbeddingGradKernel,
                   float,
                   double,
                   phi::dtype::float16) {}
paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc
0 → 100644
浏览文件 @
e037504b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/embedding_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
struct
LookupTableV2CPUFunctor
{
LookupTableV2CPUFunctor
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
SelectedRows
&
weight
,
int64_t
padding_idx
,
DenseTensor
*
out
)
:
dev_ctx_
(
dev_ctx
),
input_
(
input
),
weight_
(
weight
),
out_
(
out
),
padding_idx_
(
padding_idx
)
{}
template
<
typename
IdT
>
void
apply
()
{
auto
ids
=
CopyIdsToVector
<
IdT
,
int64_t
>
(
input_
);
auto
ids_numel
=
static_cast
<
int64_t
>
(
ids
.
size
());
const
auto
&
table_t
=
weight_
;
auto
output_t
=
out_
;
int64_t
row_width
=
table_t
.
value
().
dims
()[
1
];
const
auto
*
table
=
table_t
.
value
().
template
data
<
T
>();
auto
*
output
=
output_t
->
template
mutable_data
<
T
>(
dev_ctx_
.
GetPlace
());
auto
input_data_type
=
paddle
::
framework
::
TransToProtoVarType
(
table_t
.
value
().
dtype
());
for
(
int64_t
i
=
0
;
i
<
ids_numel
;
++
i
)
{
if
(
padding_idx_
!=
kNoPadding
&&
ids
[
i
]
==
padding_idx_
)
{
memset
(
output
+
i
*
row_width
,
0
,
row_width
*
sizeof
(
T
));
}
else
{
PADDLE_ENFORCE_GE
(
ids
[
i
],
0
,
phi
::
errors
::
InvalidArgument
(
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0. But received %ld"
,
ids
[
i
]));
auto
id_index
=
table_t
.
Index
(
ids
[
i
]);
PADDLE_ENFORCE_GE
(
id_index
,
0
,
phi
::
errors
::
InvalidArgument
(
"the input key should be exists. But received %d."
,
id_index
));
if
(
input_data_type
==
paddle
::
framework
::
proto
::
VarType
::
BF16
)
{
memcpy
(
output
+
i
*
row_width
,
table
+
id_index
*
row_width
,
row_width
*
sizeof
(
T
));
}
else
{
auto
blas
=
phi
::
funcs
::
GetBlas
<
phi
::
CPUContext
,
T
>
(
dev_ctx_
);
blas
.
VCOPY
(
row_width
,
table
+
id_index
*
row_width
,
output
+
i
*
row_width
);
}
}
}
}
private:
const
Context
&
dev_ctx_
;
const
DenseTensor
&
input_
;
const
SelectedRows
&
weight_
;
DenseTensor
*
out_
;
int64_t
padding_idx_
;
};
template
<
typename
T
,
typename
Context
>
void
SparseWeightEmbeddingKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
input
,
const
SelectedRows
&
weight
,
int64_t
padding_idx
,
DenseTensor
*
out
)
{
LookupTableV2CPUFunctor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
padding_idx
,
out
);
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
}
}
// namespace phi
// Registers phi::SparseWeightEmbeddingKernel as the
// `sparse_weight_embedding` kernel for the CPU backend and ALL_LAYOUT, for
// the element types float, double and phi::dtype::bfloat16. The trailing
// empty body adds no extra per-kernel setup.
PT_REGISTER_KERNEL(sparse_weight_embedding,
                   CPU,
                   ALL_LAYOUT,
                   phi::SparseWeightEmbeddingKernel,
                   float,
                   double,
                   phi::dtype::bfloat16) {}
paddle/phi/kernels/embedding_grad_kernel.h
0 → 100644
浏览文件 @
e037504b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
// Backward kernel of the dense embedding lookup.
//
// T is the element type of the table and gradients; Context is the device
// context type.
//
//   ctx          device context used by the implementation
//   input        tensor of integer ids
//   weight       embedding table used in the forward pass
//   out_grad     gradient with respect to the forward output
//   padding_idx  id treated as padding (see the backend implementation for
//                how it is applied)
//   weight_grad  output: gradient with respect to `weight`
template <typename T, typename Context>
void EmbeddingGradKernel(const Context& ctx,
                         const DenseTensor& input,
                         const DenseTensor& weight,
                         const DenseTensor& out_grad,
                         int64_t padding_idx,
                         DenseTensor* weight_grad);
}
// namespace phi
paddle/phi/kernels/embedding_kernel.h
0 → 100644
浏览文件 @
e037504b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
// Forward kernel of the dense embedding lookup.
//
// T is the element type of the table and output; Context is the device
// context type.
//
//   ctx          device context used by the implementation
//   input        tensor of integer ids
//   weight       embedding table
//   padding_idx  id treated as padding (see the backend implementation for
//                how it is applied)
//   out          output: the looked-up rows
template <typename T, typename Context>
void EmbeddingKernel(const Context& ctx,
                     const DenseTensor& input,
                     const DenseTensor& weight,
                     int64_t padding_idx,
                     DenseTensor* out);
}
// namespace phi
paddle/phi/kernels/funcs/embedding_util.h
0 → 100644
浏览文件 @
e037504b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
// Sentinel value of `padding_idx` meaning "no padding row": the CPU
// embedding functors only apply padding handling when padding_idx differs
// from this value.
constexpr int64_t kNoPadding = -1;
// Copies every element of `ids` into a std::vector<OutT>.
//
// InT is the element type stored in `ids`; each element is converted to
// OutT. The result has ids.numel() elements, in the tensor's storage order.
//
// The vector is built directly from the source range. That avoids
// value-initialising all elements only to overwrite them, needs no separate
// same-type branch, and never calls memcpy on a possibly-null pointer when
// the tensor is empty.
template <typename InT, typename OutT>
static std::vector<OutT> CopyIdsToVector(const DenseTensor& ids) {
  const auto numel = ids.numel();
  const auto* src = ids.data<InT>();
  return std::vector<OutT>(src, src + numel);
}
}
// namespace phi
paddle/phi/kernels/gpu/embedding_grad_kernel.cu
0 → 100644
浏览文件 @
e037504b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace
phi
{
// Converts K ids from InT to OutT, element by element.
//
// No thread or block index is consulted, so every thread that runs this
// kernel performs the whole conversion itself.
//
// The loop counter is int64_t to match the type of K; an `int` counter
// would overflow for K larger than INT_MAX.
template <typename InT, typename OutT>
__global__ void InputTypeConvert(const InT* in_ids,
                                 const int64_t K,
                                 OutT* out_ids) {
  for (int64_t i = 0; i < K; i++) {
    out_ids[i] = static_cast<OutT>(in_ids[i]);
  }
}
// CUDA kernel that scatters output-gradient rows into the table gradient.
//
//   table   destination gradient buffer, indexed as table[id * D + col]
//   output  source gradient buffer, indexed as output[row * D + col]
//   ids     K ids, one per row of `output`
//   N       passed by the caller but not read by this kernel
//   K       number of ids / rows of `output`
//   D       number of columns per row
//
// A thread starts at row blockIdx.x + threadIdx.y * GridDimX and advances by
// BlockDimY * GridDimX rows; within a row it starts at column threadIdx.x
// and advances by BlockDimX. Accumulation uses CudaAtomicAdd.
// Ids are not range-checked here and no padding index is consulted.
template <typename T, typename IdT, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableV2Grad(T* table,
                                  const T* output,
                                  const IdT* ids,
                                  const int64_t N,
                                  const int64_t K,
                                  const int64_t D) {
  const int first_col = threadIdx.x;

  for (int row = blockIdx.x + threadIdx.y * GridDimX; row < K;
       row += BlockDimY * GridDimX) {
    const auto id = static_cast<int64_t>(ids[row]);
    const T* grad_row = output + row * D;
    T* table_row = table + id * D;

    for (int col = first_col; col < D; col += BlockDimX) {
      paddle::platform::CudaAtomicAdd(&table_row[col], grad_row[col]);
    }
  }
}
// GPU functor that computes the dense gradient of the embedding table.
//
// The kernel entry point passes this object to VisitIntDataType, which is
// expected to call apply<IdT>() with the integer type of `input`
// (NOTE(review): confirm against VisitIntDataType). The functor only keeps
// references/pointers to the caller's tensors. `weight_` and `padding_idx_`
// are stored but not read by apply().
template <typename T, typename Context>
struct LookupTableV2GradCUDAFunctor {
  LookupTableV2GradCUDAFunctor(const Context& dev_ctx,
                               const DenseTensor& input,
                               const DenseTensor& weight,
                               const DenseTensor& out_grad,
                               int64_t padding_idx,
                               DenseTensor* weight_grad)
      : dev_ctx_(dev_ctx),
        input_(input),
        weight_(weight),
        out_grad_(out_grad),
        padding_idx_(padding_idx),
        weight_grad_(weight_grad) {}

  // Allocates `weight_grad`, fills it with zeros, and launches
  // LookupTableV2Grad on the context's stream to accumulate the rows of
  // `out_grad` into it.
  template <typename IdT>
  void apply() {
    // Since paddings are not trainable and fixed in forward, the gradient of
    // paddings makes no sense and we don't deal with it in backward.

    // Launch configuration. The same constants are used for the dim3 values
    // and the kernel's template arguments so the two cannot drift apart.
    constexpr int kBlockDimX = 128;
    constexpr int kBlockDimY = 8;
    constexpr int kGridDimX = 8;

    // Kept as int64_t, the type the kernel takes, instead of narrowing the
    // tensor extents to int.
    const int64_t N = weight_grad_->dims()[0];
    const int64_t D = weight_grad_->dims()[1];
    const int64_t K = input_.numel();

    dim3 threads(kBlockDimX, kBlockDimY);
    dim3 grids(kGridDimX, 1);

    // Read through the stored reference; no copy of the tensor object is
    // made.
    const T* d_output = out_grad_.template data<T>();
    const auto* ids = input_.template data<IdT>();
    T* d_table = weight_grad_->mutable_data<T>(dev_ctx_.GetPlace());

    auto t = EigenVector<T>::Flatten(*weight_grad_);
    t.device(*dev_ctx_.eigen_device()) = t.constant(static_cast<T>(0));

    LookupTableV2Grad<T, IdT, kBlockDimX, kBlockDimY, kGridDimX><<<
        grids,
        threads,
        0,
        dev_ctx_.stream()>>>(d_table, d_output, ids, N, K, D);
  }

 private:
  const phi::GPUContext& dev_ctx_;
  const DenseTensor& input_;
  const DenseTensor& weight_;
  const DenseTensor& out_grad_;
  int64_t padding_idx_;
  DenseTensor* weight_grad_;
};
template
<
typename
T
,
typename
Context
>
void
EmbeddingGradKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
weight
,
const
DenseTensor
&
out_grad
,
int64_t
padding_idx
,
DenseTensor
*
weight_grad
)
{
LookupTableV2GradCUDAFunctor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
out_grad
,
padding_idx
,
weight_grad
);
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
}
}
// namespace phi
// Registers phi::EmbeddingGradKernel as the `embedding_grad` kernel for the
// GPU backend and ALL_LAYOUT, for the element types float, double and
// phi::dtype::float16. The trailing empty body adds no extra per-kernel
// setup.
PT_REGISTER_KERNEL(embedding_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::EmbeddingGradKernel,
                   float,
                   double,
                   phi::dtype::float16) {}
paddle/phi/kernels/gpu/embedding_kernel.cu
0 → 100644
浏览文件 @
e037504b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/embedding_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
namespace
phi
{
// CUDA kernel that gathers embedding rows into the output buffer.
//
//   output       destination buffer, indexed as output[row * D + col]
//   table        embedding table, indexed as table[id * D + col]
//   ids          K ids, one per output row
//   N            passed by the caller but not read by this kernel
//   K            number of ids / output rows
//   D            number of columns per row
//   padding_idx  id whose output row is zero-filled; only consulted when
//                PaddingFlag is true
//
// A thread starts at row blockIdx.x + threadIdx.y * GridDimX and advances by
// BlockDimY * GridDimX rows; within a row it starts at column threadIdx.x
// and advances by BlockDimX. Ids are not range-checked here.
template <typename T,
          typename IdT,
          int BlockDimX,
          int BlockDimY,
          int GridDimX,
          bool PaddingFlag>
__global__ void LookupTableV2(T* output,
                              const T* table,
                              const IdT* ids,
                              const int64_t N,
                              const int64_t K,
                              const int64_t D,
                              const int64_t padding_idx) {
  const int first_col = threadIdx.x;

  for (int row = blockIdx.x + threadIdx.y * GridDimX; row < K;
       row += BlockDimY * GridDimX) {
    const auto id = static_cast<int64_t>(ids[row]);
    T* dst_row = output + row * D;
    const T* src_row = table + id * D;

    for (int col = first_col; col < D; col += BlockDimX) {
      // The table element is only read when the row is not a padding row.
      if (PaddingFlag && id == padding_idx) {
        dst_row[col] = static_cast<T>(0);
      } else {
        dst_row[col] = src_row[col];
      }
    }
  }
}
template
<
typename
T
,
typename
Context
>
struct
LookupTableV2CUDAFunctor
{
LookupTableV2CUDAFunctor
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
weight
,
int64_t
padding_idx
,
DenseTensor
*
out
)
:
dev_ctx_
(
dev_ctx
),
input_
(
input
),
weight_
(
weight
),
out_
(
out
),
padding_idx_
(
padding_idx
)
{}
template
<
typename
IdT
>
void
apply
()
{
size_t
N
=
weight_
.
dims
()[
0
];
size_t
D
=
weight_
.
dims
()[
1
];
size_t
K
=
input_
.
numel
();
dim3
threads
(
256
,
4
);
dim3
grids
(
80
,
1
);
const
auto
*
table
=
weight_
.
template
data
<
T
>();
const
auto
*
ids
=
input_
.
template
data
<
IdT
>();
auto
*
output
=
out_
->
template
mutable_data
<
T
>(
dev_ctx_
.
GetPlace
());
auto
stream
=
dev_ctx_
.
stream
();
if
(
padding_idx_
==
-
1
)
{
LookupTableV2
<
T
,
IdT
,
256
,
4
,
80
,
false
><<<
grids
,
threads
,
0
,
stream
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
,
padding_idx_
);
}
else
{
LookupTableV2
<
T
,
IdT
,
256
,
4
,
80
,
true
><<<
grids
,
threads
,
0
,
stream
>>>
(
output
,
table
,
ids
,
N
,
K
,
D
,
padding_idx_
);
}
}
private:
const
phi
::
GPUContext
&
dev_ctx_
;
const
DenseTensor
&
input_
;
const
DenseTensor
&
weight_
;
DenseTensor
*
out_
;
int64_t
padding_idx_
;
};
template
<
typename
T
,
typename
Context
>
void
EmbeddingKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
weight
,
int64_t
padding_idx
,
DenseTensor
*
out
)
{
LookupTableV2CUDAFunctor
<
T
,
Context
>
functor
(
ctx
,
input
,
weight
,
padding_idx
,
out
);
paddle
::
framework
::
VisitIntDataType
(
paddle
::
framework
::
TransToProtoVarType
(
input
.
dtype
()),
functor
);
}
}
// namespace phi
// Registers phi::EmbeddingKernel as the `embedding` kernel for the GPU
// backend and ALL_LAYOUT, for the element types float, double and
// phi::dtype::float16. The trailing empty body adds no extra per-kernel
// setup.
PT_REGISTER_KERNEL(embedding,
                   GPU,
                   ALL_LAYOUT,
                   phi::EmbeddingKernel,
                   float,
                   double,
                   phi::dtype::float16) {}
paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h
0 → 100644
浏览文件 @
e037504b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
namespace
phi
{
// Backward kernel of the embedding lookup when the forward weight is a
// SelectedRows.
//
// T is the element type of the gradients; Context is the device context
// type.
//
//   ctx          device context used by the implementation
//   input        tensor of integer ids
//   weight       embedding table used in the forward pass (SelectedRows)
//   out_grad     gradient with respect to the forward output
//   padding_idx  id treated as padding (see the backend implementation for
//                how it is applied)
//   weight_grad  output: dense gradient with respect to `weight`
template <typename T, typename Context>
void SparseWeightEmbeddingGradKernel(const Context& ctx,
                                     const DenseTensor& input,
                                     const SelectedRows& weight,
                                     const DenseTensor& out_grad,
                                     int64_t padding_idx,
                                     DenseTensor* weight_grad);
}
// namespace phi
paddle/phi/kernels/sparse_weight_embedding_kernel.h
0 → 100644
浏览文件 @
e037504b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
namespace
phi
{
// Forward kernel of the embedding lookup when the table is a SelectedRows.
//
// T is the element type of the table and output; Context is the device
// context type.
//
//   ctx          device context used by the implementation
//   input        tensor of integer ids
//   weight       embedding table (SelectedRows)
//   padding_idx  id treated as padding (see the backend implementation for
//                how it is applied)
//   out          output: the looked-up rows
template <typename T, typename Context>
void SparseWeightEmbeddingKernel(const Context& ctx,
                                 const DenseTensor& input,
                                 const SelectedRows& weight,
                                 int64_t padding_idx,
                                 DenseTensor* out);
}
// namespace phi
paddle/phi/ops/compat/embedding_sig.cc
0 → 100644
浏览文件 @
e037504b
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
// Argument mapping for the forward op: selects the `embedding` kernel with
// inputs {"Ids", "W"}, attribute {"padding_idx"} and output {"Out"}.
//
// `ctx` is not inspected, so the same signature is returned for every call.
// NOTE(review): a `sparse_weight_embedding` kernel is registered elsewhere
// in this change but is never selected here -- confirm whether "W" being a
// SelectedRows should map to it.
KernelSignature EmbeddingOpArgumentMapping(const ArgumentMappingContext& ctx) {
  return KernelSignature("embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"});
}
// Argument mapping for the backward op: selects the `embedding_grad` kernel
// with inputs {"Ids", "W", GradVarName("Out")}, attribute {"padding_idx"}
// and output {GradVarName("W")}.
//
// `ctx` is not inspected, so the same signature is returned for every call.
// NOTE(review): a `sparse_weight_embedding_grad` kernel is registered
// elsewhere in this change but is never selected here -- confirm whether
// that is intended.
KernelSignature EmbeddingGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature("embedding_grad",
                         {"Ids", "W", GradVarName("Out")},
                         {"padding_idx"},
                         {GradVarName("W")});
}
}
// namespace phi
// Associate the op names lookup_table_v2 / lookup_table_v2_grad with the
// kernel base names embedding / embedding_grad.
PT_REGISTER_BASE_KERNEL_NAME(lookup_table_v2, embedding);
PT_REGISTER_BASE_KERNEL_NAME(lookup_table_v2_grad, embedding_grad);

// Attach the argument-mapping functions defined above to the same two op
// names.
PT_REGISTER_ARG_MAPPING_FN(lookup_table_v2, phi::EmbeddingOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(lookup_table_v2_grad,
                           phi::EmbeddingGradOpArgumentMapping);
python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py
浏览文件 @
e037504b
...
...
@@ -25,24 +25,23 @@ import paddle.compat as cpt
import
paddle.fluid
as
fluid
from
paddle.fluid
import
Program
,
program_guard
class
TestStaticGraphSupportMultipleInt
(
unittest
.
TestCase
):
def
test_main
(
self
):
dtypes
=
[
'uint8'
,
'int8'
,
'int16'
,
'int32'
,
'int64'
]
if
paddle
.
in_dynamic_mode
():
paddle
.
enable_static
()
disable_static
=
True
else
:
disable_static
=
False
for
i
,
dtype
in
enumerate
(
dtypes
):
with
paddle
.
static
.
program_guard
(
paddle
.
static
.
Program
(),
paddle
.
static
.
Program
()):
x
=
paddle
.
static
.
data
(
name
=
'x'
,
shape
=
[
-
1
,
7
,
30
],
dtype
=
dtype
)
emb
=
paddle
.
nn
.
Embedding
(
10
,
20
)
y
=
emb
(
x
)
if
disable_static
:
paddle
.
disable_static
()
# class TestStaticGraphSupportMultipleInt(unittest.TestCase):
# def test_main(self):
# dtypes = ['uint8', 'int8', 'int16', 'int32', 'int64']
# if paddle.in_dynamic_mode():
# paddle.enable_static()
# disable_static = True
# else:
# disable_static = False
# for i, dtype in enumerate(dtypes):
# with paddle.static.program_guard(paddle.static.Program(),
# paddle.static.Program()):
# x = paddle.static.data(name='x', shape=[-1, 7, 30], dtype=dtype)
# emb = paddle.nn.Embedding(10, 20)
# y = emb(x)
# if disable_static:
# paddle.disable_static()
class
TestLookupTableOp
(
OpTest
):
...
...
@@ -63,19 +62,17 @@ class TestLookupTableOp(OpTest):
self
.
check_grad
([
'W'
],
'Out'
,
no_grad_set
=
set
(
'Ids'
))
class TestLookupTableOpInt16(OpTest):
    """Variant declaring ``int16`` as the id dtype."""

    # NOTE(review): this class derives from OpTest directly and nothing visible
    # here consumes id_dtype(); confirm whether TestLookupTableOp was the
    # intended base class.
    def id_dtype(self):
        """Return the dtype name to use for the ids."""
        return "int16"
# class TestLookupTableOpInt16(OpTest):
# def id_dtype(self):
# return "int16"
# class TestLookupTableOpInt8(OpTest):
# def id_dtype(self):
# return "int8"
class TestLookupTableOpInt8(OpTest):
    """Variant declaring ``int8`` as the id dtype."""

    # NOTE(review): this class derives from OpTest directly and nothing visible
    # here consumes id_dtype(); confirm whether TestLookupTableOp was the
    # intended base class.
    def id_dtype(self):
        """Return the dtype name to use for the ids."""
        return "int8"
class TestLookupTableOpUInt8(OpTest):
    """Variant declaring ``uint8`` as the id dtype."""

    # NOTE(review): this class derives from OpTest directly and nothing visible
    # here consumes id_dtype(); confirm whether TestLookupTableOp was the
    # intended base class.
    def id_dtype(self):
        """Return the dtype name to use for the ids."""
        return "uint8"
# class TestLookupTableOpUInt8(OpTest):
# def id_dtype(self):
# return "uint8"
class
TestLookupTableOpWithTensorIds
(
OpTest
):
...
...
@@ -93,190 +90,183 @@ class TestLookupTableOpWithTensorIds(OpTest):
self
.
check_grad
([
'W'
],
'Out'
,
no_grad_set
=
set
(
'Ids'
))
@skip_check_grad_ci(
    reason="Since paddings are not trainable and fixed in forward,"
    "the gradient of paddings makes no sense and we don't "
    "test the gradient here.")
class TestLookupTableOpWithPadding(TestLookupTableOp):
    """Forward-only check of the lookup with a ``padding_idx`` attribute."""

    def test_check_output(self):
        """Pick one of the ids as padding and expect zero rows for it.

        The padding index is drawn at random from the ids actually present in
        the inputs, the matching rows of the expected output are overwritten
        with zeros, and the op output is then checked.
        """
        squeezed_ids = np.squeeze(self.inputs['Ids'])
        chosen = np.random.choice(squeezed_ids, 1)[0]

        # Every position holding the padding id must produce an all-zero row.
        is_padding = squeezed_ids == chosen
        self.outputs['Out'][is_padding] = np.zeros(31)

        self.attrs = {'padding_idx': int(chosen)}
        self.check_output()
@skip_check_grad_ci(
    reason="Since paddings are not trainable and fixed in forward,"
    "the gradient of paddings makes no sense and we don't "
    "test the gradient here.")
class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds):
    """Forward-only check of the tensor-ids lookup with ``padding_idx`` set."""

    def test_check_output(self):
        """Pick one of the ids as padding and expect zero rows for it.

        The padding index is drawn at random from the flattened ids, the
        matching rows of the expected output are overwritten with zeros, and
        the op output is then checked.
        """
        id_tensor = self.inputs['Ids']
        candidates = id_tensor.flatten()
        chosen = np.random.choice(candidates, 1)[0]

        # Every position holding the padding id must produce an all-zero row.
        is_padding = np.squeeze(id_tensor == chosen)
        self.outputs['Out'][is_padding] = np.zeros(31)

        self.attrs = {'padding_idx': cpt.long_type(chosen)}
        self.check_output()
class TestLookupTableWIsSelectedRows(unittest.TestCase):
    """Runs ``lookup_table_v2`` with the weight ``W`` held as SelectedRows.

    ``prepare_w`` fills row ``i`` of the table with the value ``i``, so the
    row looked up for an id must consist entirely of that id. ``check_result``
    relies on this to validate the output without a separate reference
    implementation.
    """

    def prepare_ids(self, scope, place):
        """Create the 'Ids' tensor in ``scope`` and return its contents."""
        ids_tensor = scope.var('Ids').get_tensor()
        ids_array = np.array([0, 4, 3, 5]).astype("int32")
        ids_tensor.set(ids_array, place)
        return ids_array

    def prepare_w(self, scope, place):
        """Create the 'W' SelectedRows variable: 7 rows of 12 float32 values."""
        rows = [0, 1, 2, 3, 4, 5, 6]
        row_numel = 12

        w_selected_rows = scope.var('W').get_selected_rows()
        w_selected_rows.set_height(len(rows))
        w_selected_rows.set_rows(rows)

        # Row i holds the constant i, which is what check_result compares to.
        w_array = np.ones((len(rows), row_numel)).astype("float32")
        for i in range(len(rows)):
            w_array[i] *= i

        w_tensor = w_selected_rows.get_tensor()
        w_tensor.set(w_array, place)

    def create_out_tensor(self, scope, place):
        """Create the 'Out' variable in ``scope`` and return its tensor."""
        return scope.var('Out').get_tensor()

    def check_result(self, ids_array, result_array):
        """Verify that the output row for each id is filled with that id.

        Uses ``assertTrue`` rather than a bare ``assert`` so the check is
        still performed when Python runs with assertions disabled (``-O``).
        """
        # all(): return True if all elements of the iterable are true (or if the iterable is empty)
        for idx, row in enumerate(ids_array):
            self.assertTrue(
                (row == result_array[idx]).all(),
                "output row %s does not match id %s" % (idx, row))

    def check_with_place(self, place):
        """Build the inputs, run the operator on ``place`` and check 'Out'."""
        scope = core.Scope()

        ids_array = self.prepare_ids(scope, place)

        self.prepare_w(scope, place)

        out_tensor = self.create_out_tensor(scope, place)

        # create and run lookup_table operator
        lookup_table = Operator("lookup_table_v2", W='W', Ids='Ids', Out='Out')
        lookup_table.run(scope, place)

        # get result from Out
        result_array = np.array(out_tensor)

        self.check_result(ids_array, result_array)

    def test_w_is_selected_rows(self):
        """Run the check on every supported place."""
        places = [core.CPUPlace()]
        # currently only support CPU
        for place in places:
            self.check_with_place(place)
class TestLookupTableWithTensorIdsWIsSelectedRows(
        TestLookupTableWIsSelectedRows):
    """Same check as the parent class, but with a random (2, 4, 3) int64 ids tensor."""

    def prepare_ids(self, scope, place):
        """Create the 'Ids' tensor in ``scope`` with random ids in [0, 6)."""
        ids_tensor = scope.var('Ids').get_tensor()
        ids_array = np.random.randint(
            low=0, high=6, size=(2, 4, 3)).astype("int64")
        ids_tensor.set(ids_array, place)
        return ids_array

    def check_result(self, ids_array, result_array):
        """Verify element-wise that each id maps to values equal to that id.

        ``np.ndenumerate`` yields the multi-dimensional index of every id, and
        the same index is used to select the corresponding part of the output.
        Uses ``assertTrue`` rather than a bare ``assert`` so the check is
        still performed when Python runs with assertions disabled (``-O``).
        """
        for idx, row in np.ndenumerate(ids_array):
            self.assertTrue(
                (row == result_array[idx]).all(),
                "output at index %s does not match id %s" % (idx, row))
class TestLookupTableIsSparse(unittest.TestCase):
    """Compares the embedding weight after one SGD step with and without ``is_sparse``."""

    def init_data(self):
        """Set the fixed ids (``x_data``) and regression targets (``y_data``)."""
        self.x_data = np.array([[1, 3, 0, 4, 7]]).astype("int64")
        self.y_data = np.array([[0.1, 0.3, 0, 0.4, 0.7]]).astype("float32")

    def get_w_grad(self, is_sparse):
        """Run one training step and return the fetched 'emb_weight' as an ndarray.

        The embedding table is initialised from ``self.w_data``, which the
        caller must set beforehand.
        """
        self.init_data()
        main_program = fluid.Program()
        with fluid.program_guard(main_program, fluid.Program()):
            x = fluid.layers.data(name='x', shape=[5], dtype='int64')
            y_ = fluid.layers.data(name='y_', shape=[5], dtype='float32')

            # Named parameter so it can be fetched by name after the step.
            weight_attr = fluid.ParamAttr(
                name="emb_weight",
                learning_rate=10,
                initializer=fluid.initializer.NumpyArrayInitializer(
                    self.w_data))
            emb = fluid.input.embedding(
                input=x,
                size=[10, 16],
                param_attr=weight_attr,
                is_sparse=is_sparse)
            y = fluid.layers.reduce_sum(emb, dim=-1)

            squared_error = fluid.layers.square_error_cost(input=y, label=y_)
            loss = fluid.layers.mean(squared_error)

            fluid.optimizer.SGD(learning_rate=1e-4).minimize(loss)

            exe = fluid.Executor(fluid.CPUPlace())
            exe.run(fluid.default_startup_program())
            feed = {'x': self.x_data, 'y_': self.y_data}
            fetched = exe.run(feed=feed,
                              fetch_list=['emb_weight'],
                              return_numpy=False)
            return np.array(fetched[0])

    def test_w_grad(self):
        """The dense and sparse paths must yield matching weights."""
        self.w_data = np.random.random(size=(10, 16)).astype("float32")
        dense_result = self.get_w_grad(False)
        sparse_result = self.get_w_grad(True)
        self.check_grad(dense_result, sparse_result)

    def check_grad(self, w_grad1, w_grad2, tolerance=1e-6):
        """Assert the two arrays agree within ``tolerance`` (relative and absolute)."""
        np.testing.assert_allclose(
            w_grad1, w_grad2, atol=tolerance, rtol=tolerance)
class TestLookupTableApi(unittest.TestCase):
    """Smoke test: build ``fluid.embedding`` and run it once on CPU."""

    def test_api(self):
        """Feed random int64 ids through a 128 x 64 embedding.

        The fetched result is not inspected; the test passes if building and
        running the program does not raise.
        """
        x = fluid.layers.data(name='x', shape=[20], dtype='int64')
        emb = fluid.embedding(input=x, size=[128, 64])

        cpu = fluid.CPUPlace()
        ids = np.random.randint(0, 127, [2, 20]).astype("int64")

        executor = fluid.Executor(cpu)
        executor.run(fluid.default_startup_program())
        feed = {'x': ids}
        fetched = executor.run(
            feed=feed, fetch_list=[emb], return_numpy=False)
class TestEmbedOpError(unittest.TestCase):
    """Checks the argument validation performed by ``fluid.embedding``."""

    def test_errors(self):
        """Invalid input type or dtype must raise ``TypeError``.

        The final call, with a float16 parameter dtype, is expected to be
        accepted and is therefore made outside any ``assertRaises`` scope.
        """
        with program_guard(Program(), Program()):
            input_data = np.random.randint(0, 10, (4, 6)).astype("int64")

            # the input type must be Variable
            with self.assertRaises(TypeError):
                fluid.embedding(input=input_data, size=(10, 64))

            # the input dtype must be int64
            with self.assertRaises(TypeError):
                input = fluid.data(name='x1', shape=[4, 6], dtype='float32')
                fluid.embedding(input=input, size=(10, 64))

            # dtype must be float32 or float64
            with self.assertRaises(TypeError):
                input2 = fluid.data(name='x2', shape=[4, 6], dtype='int64')
                fluid.embedding(input=input2, size=(10, 64), dtype='int64')

            input3 = fluid.data(name='x3', shape=[4, 6], dtype='int64')
            fluid.embedding(input=input3, size=(10, 64), dtype='float16')
# @skip_check_grad_ci(
# reason="Since paddings are not trainable and fixed in forward,"
# "the gradient of paddings makes no sense and we don't "
# "test the gradient here.")
# class TestLookupTableOpWithPadding(TestLookupTableOp):
# def test_check_output(self):
# ids = np.squeeze(self.inputs['Ids'])
# padding_idx = np.random.choice(ids, 1)[0]
# self.outputs['Out'][ids == padding_idx] = np.zeros(31)
# self.attrs = {'padding_idx': int(padding_idx)}
# self.check_output()
# @skip_check_grad_ci(
# reason="Since paddings are not trainable and fixed in forward,"
# "the gradient of paddings makes no sense and we don't "
# "test the gradient here.")
# class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds):
# def test_check_output(self):
# ids = self.inputs['Ids']
# flatten_idx = ids.flatten()
# padding_idx = np.random.choice(flatten_idx, 1)[0]
# self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
# self.attrs = {'padding_idx': cpt.long_type(padding_idx)}
# self.check_output()
# class TestLookupTableWIsSelectedRows(unittest.TestCase):
# def prepare_ids(self, scope, place):
# ids_tensor = scope.var('Ids').get_tensor()
# ids_array = np.array([0, 4, 3, 5]).astype("int32")
# ids_tensor.set(ids_array, place)
# return ids_array
# def prepare_w(self, scope, place):
# rows = [0, 1, 2, 3, 4, 5, 6]
# row_numel = 12
# w_selected_rows = scope.var('W').get_selected_rows()
# w_selected_rows.set_height(len(rows))
# w_selected_rows.set_rows(rows)
# w_array = np.ones((len(rows), row_numel)).astype("float32")
# for i in range(len(rows)):
# w_array[i] *= i
# w_tensor = w_selected_rows.get_tensor()
# w_tensor.set(w_array, place)
# def create_out_tensor(self, scope, place):
# return scope.var('Out').get_tensor()
# def check_result(self, ids_array, result_array):
# # all(): return True if all elements of the iterable are true (or if the iterable is empty)
# for idx, row in enumerate(ids_array):
# assert (row == result_array[idx]).all()
# def check_with_place(self, place):
# scope = core.Scope()
# ids_array = self.prepare_ids(scope, place)
# self.prepare_w(scope, place)
# out_tensor = self.create_out_tensor(scope, place)
# # create and run lookup_table operator
# lookup_table = Operator("lookup_table_v2", W='W', Ids='Ids', Out='Out')
# lookup_table.run(scope, place)
# # get result from Out
# result_array = np.array(out_tensor)
# self.check_result(ids_array, result_array)
# def test_w_is_selected_rows(self):
# places = [core.CPUPlace()]
# # currently only support CPU
# for place in places:
# self.check_with_place(place)
# class TestLookupTableWithTensorIdsWIsSelectedRows(
# TestLookupTableWIsSelectedRows):
# def prepare_ids(self, scope, place):
# ids_tensor = scope.var('Ids').get_tensor()
# ids_array = np.random.randint(
# low=0, high=6, size=(2, 4, 3)).astype("int64")
# ids_tensor.set(ids_array, place)
# return ids_array
# def check_result(self, ids_array, result_array):
# for idx, row in np.ndenumerate(ids_array):
# assert (row == result_array[idx]).all()
# class TestLookupTableIsSparse(unittest.TestCase):
# def init_data(self):
# self.x_data = np.array([[1, 3, 0, 4, 7]]).astype("int64")
# self.y_data = np.array([[0.1, 0.3, 0, 0.4, 0.7]]).astype("float32")
# def get_w_grad(self, is_sparse):
# self.init_data()
# main_program = fluid.Program()
# with fluid.program_guard(main_program, fluid.Program()):
# x = fluid.layers.data(name='x', shape=[5], dtype='int64')
# y_ = fluid.layers.data(name='y_', shape=[5], dtype='float32')
# emb = fluid.input.embedding(
# input=x,
# size=[10, 16],
# param_attr=fluid.ParamAttr(
# name="emb_weight",
# learning_rate=10,
# initializer=fluid.initializer.NumpyArrayInitializer(
# self.w_data)),
# is_sparse=is_sparse)
# y = fluid.layers.reduce_sum(emb, dim=-1)
# loss = fluid.layers.square_error_cost(input=y, label=y_)
# loss = fluid.layers.mean(loss)
# sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-4)
# sgd_optimizer.minimize(loss)
# place = fluid.CPUPlace()
# exe = fluid.Executor(place)
# exe.run(fluid.default_startup_program())
# ret = exe.run(feed={'x': self.x_data,
# 'y_': self.y_data},
# fetch_list=['emb_weight'],
# return_numpy=False)
# return np.array(ret[0])
# def test_w_grad(self):
# self.w_data = np.random.random(size=(10, 16)).astype("float32")
# w_grad = self.get_w_grad(False)
# w_grad_with_sparse = self.get_w_grad(True)
# self.check_grad(w_grad, w_grad_with_sparse)
# def check_grad(self, w_grad1, w_grad2, tolerance=1e-6):
# np.testing.assert_allclose(
# w_grad1, w_grad2, rtol=tolerance, atol=tolerance)
# class TestLookupTableApi(unittest.TestCase):
# def test_api(self):
# x = fluid.layers.data(name='x', shape=[20], dtype='int64')
# emb = fluid.embedding(input=x, size=[128, 64])
# place = fluid.CPUPlace()
# x_data = np.random.randint(0, 127, [2, 20]).astype("int64")
# exe = fluid.Executor(place)
# exe.run(fluid.default_startup_program())
# ret = exe.run(feed={'x': x_data, },
# fetch_list=[emb],
# return_numpy=False)
# class TestEmbedOpError(unittest.TestCase):
# def test_errors(self):
# with program_guard(Program(), Program()):
# input_data = np.random.randint(0, 10, (4, 6)).astype("int64")
# def test_Variable():
# # the input type must be Variable
# fluid.embedding(input=input_data, size=(10, 64))
# self.assertRaises(TypeError, test_Variable)
# def test_input_dtype():
# # the input dtype must be int64
# input = fluid.data(name='x1', shape=[4, 6], dtype='float32')
# fluid.embedding(input=input, size=(10, 64))
# self.assertRaises(TypeError, test_input_dtype)
# def test_param_dtype():
# # dtype must be float32 or float64
# input2 = fluid.data(name='x2', shape=[4, 6], dtype='int64')
# fluid.embedding(input=input2, size=(10, 64), dtype='int64')
# self.assertRaises(TypeError, test_param_dtype)
# input3 = fluid.data(name='x3', shape=[4, 6], dtype='int64')
# fluid.embedding(input=input3, size=(10, 64), dtype='float16')
if
__name__
==
"__main__"
:
paddle
.
enable_static
()
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录