Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
2a5adc5a
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2a5adc5a
编写于
8月 24, 2023
作者:
C
csy0225
提交者:
GitHub
8月 24, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XPU] Add embedding plugin (#56488)
上级
7c5f4fde
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
725 addition
and
21 deletion
+725
-21
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+2
-0
paddle/fluid/framework/ir/xpu/cast_embedding_trans_ids_to_int32_pass.cc
...ramework/ir/xpu/cast_embedding_trans_ids_to_int32_pass.cc
+137
-0
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+1
-0
paddle/phi/kernels/xpu/embedding_kernel.cc
paddle/phi/kernels/xpu/embedding_kernel.cc
+51
-21
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
+11
-0
paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/embedding_fwd_tiny_dict.xpu
.../plugin/src/kernel/kunlun2cpp/embedding_fwd_tiny_dict.xpu
+240
-0
paddle/phi/kernels/xpu/plugin/src/wrapper/fast_embedding.cpp
paddle/phi/kernels/xpu/plugin/src/wrapper/fast_embedding.cpp
+189
-0
test/ir/inference/test_xpu_cast_embedding_trans_ids_to_int32_pass.py
...erence/test_xpu_cast_embedding_trans_ids_to_int32_pass.py
+94
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
2a5adc5a
...
...
@@ -239,6 +239,8 @@ if(WITH_XPU)
pass_library
(
cast_mixed_precision_op_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
yolo_box_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
cast_embedding_trans_ids_to_int32_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
conv1d_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
conv2d_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
redundant_unsqueeze_squeeze_elimination_pass inference DIR xpu
...
...
paddle/fluid/framework/ir/xpu/cast_embedding_trans_ids_to_int32_pass.cc
0 → 100644
浏览文件 @
2a5adc5a
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
// Forward declarations only: phi::DenseTensor and paddle::framework::Scope
// are declared here by name and are not defined in this file.
namespace
phi
{
class
DenseTensor
;
}
// namespace phi
namespace
paddle
{
namespace
framework
{
class
Scope
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
// Sub-graph pattern matched by the pass:
//
//   cast_x -> [cast] -> embedding_ids --+
//                                       +--> [lookup_table_v2] -> embedding_out
//                       embedding_w ----+
//
// Names in brackets are operator nodes, the others are variable nodes.
struct CastEmbeddingTransIdsToInt32Pattern : public PatternBase {
  CastEmbeddingTransIdsToInt32Pattern(PDPattern* pattern,
                                      const std::string& name_scope);

  // Operator nodes.
  PATTERN_DECL_NODE(cast);
  PATTERN_DECL_NODE(embedding);

  // Variable nodes.
  PATTERN_DECL_NODE(cast_x);
  PATTERN_DECL_NODE(embedding_ids);
  PATTERN_DECL_NODE(embedding_w);
  PATTERN_DECL_NODE(embedding_out);
};
// Builds the matcher for a `cast` op whose output is consumed (as the only
// consumer) by the "Ids" input of a `lookup_table_v2` op.
CastEmbeddingTransIdsToInt32Pattern::CastEmbeddingTransIdsToInt32Pattern(
    PDPattern* pattern, const std::string& name_scope)
    : PatternBase(pattern, name_scope, name_scope) {
  constexpr char kCastOp[] = "cast";
  constexpr char kLookupOp[] = "lookup_table_v2";

  // cast op, its non-persistable input and its output (the embedding ids).
  auto* cast_op = pattern->NewNode(cast_repr())->assert_is_op(kCastOp);
  auto* cast_in = pattern->NewNode(cast_x_repr())
                      ->assert_is_op_input(kCastOp, "X")
                      ->assert_var_not_persistable()
                      ->AsInput();
  // The ids variable must have exactly one output, so retyping it cannot
  // affect any op other than the matched lookup.
  auto* ids_var = pattern->NewNode(embedding_ids_repr())
                      ->assert_is_op_output(kCastOp, "Out")
                      ->assert_is_op_input(kLookupOp, "Ids")
                      ->assert_has_n_outputs(1);
  cast_op->LinksFrom({cast_in}).LinksTo({ids_var});

  // lookup_table_v2 op with its weight table and output.
  auto* table_var = pattern->NewNode(embedding_w_repr())
                        ->assert_is_op_input(kLookupOp, "W");
  auto* lookup_op =
      pattern->NewNode(embedding_repr())->assert_is_op(kLookupOp);
  auto* out_var = pattern->NewNode(embedding_out_repr())
                      ->assert_is_op_output(kLookupOp, "Out")
                      ->AsOutput();
  lookup_op->LinksFrom({ids_var, table_var}).LinksTo({out_var});
}
}
// namespace patterns
// IR pass that retypes the ids fed to lookup_table_v2: when they are produced
// by a `cast` op with out_dtype INT64, the cast is switched to emit INT32.
class CastEmbeddingTransIdsToInt32Pass : public FusePassBase {
 protected:
  void ApplyImpl(ir::Graph* graph) const override;

 private:
  // Scope name handed to Init() and to the pattern.
  const std::string name_scope_ = "cast_embedding_trans_ids_to_int32_pass";
};
// Runs the pass on `graph`.
//
// For every matched `cast -> lookup_table_v2` sub-graph whose cast has
// out_dtype == INT64, the cast's out_dtype attribute and the data type of the
// intermediate ids variable are changed to INT32. Matches whose cast has any
// other out_dtype are left untouched. The number of rewritten sub-graphs is
// reported through AddStatis().
//
// Raises PreconditionNotMet when `graph` is null.
void CastEmbeddingTransIdsToInt32Pass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);

  GraphPatternDetector gpd;
  patterns::CastEmbeddingTransIdsToInt32Pattern pattern(gpd.mutable_pattern(),
                                                        name_scope_);
  int found_subgraph_count = 0;

  // NOTE: GET_IR_NODE is used with the local names `subgraph` and `pattern`
  // in scope, so these identifiers are kept as they are.
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    VLOG(4) << "handle CastEmbeddingTransIdsToInt32Pass";
    GET_IR_NODE(cast);
    GET_IR_NODE(embedding);
    GET_IR_NODE(embedding_ids);

    // Only an int64-producing cast is rewritten.
    auto cast_node_attr_out_dtype =
        cast->Op()->GetAttrIfExists<int>("out_dtype");
    if (cast_node_attr_out_dtype !=
        static_cast<int>(paddle::framework::proto::VarType::INT64)) {
      return;
    }

    cast->Op()->SetAttr(
        "out_dtype",
        static_cast<int>(paddle::framework::proto::VarType::INT32));
    embedding_ids->Var()->SetDataType(
        paddle::framework::proto::VarType::INT32);
    embedding->Op()->Flush();
    found_subgraph_count++;
  };

  gpd(graph, handler);
  AddStatis(found_subgraph_count);

  if (found_subgraph_count) {
    // Adjacent string literals are concatenated verbatim, so every sentence
    // must carry its own trailing space to keep the message readable.
    VLOG(4) << "There is a risk of overflow when converting the data type of "
               "embedded ids from int64 to int32. "
               "Please ensure that the numerical range of ids is within the "
               "maximum value of int32. "
               "If it exceeds this range, it may result in incorrect results. "
               "You can try removing this pass.";
  }
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
// Register the pass under the name "cast_embedding_trans_ids_to_int32_pass".
REGISTER_PASS
(
cast_embedding_trans_ids_to_int32_pass
,
paddle
::
framework
::
ir
::
CastEmbeddingTransIdsToInt32Pass
);
// Declare the op-version combination attached to this pass; LE presumably
// means "lookup_table_v2 version <= 1" -- confirm against
// op_version_registry.h.
REGISTER_PASS_CAPABILITY
(
cast_embedding_trans_ids_to_int32_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
().
LE
(
"lookup_table_v2"
,
1
));
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
2a5adc5a
...
...
@@ -516,6 +516,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"reshape_unstack_concat_fuse_pass"
,
"delete_op_device_pass"
,
"constant_folding_pass"
,
"cast_embedding_trans_ids_to_int32_pass"
,
"delete_elementwise_mul_op_pass"
,
"generate_sequence_xpu_fuse_pass"
,
"embedding_with_eltwise_add_xpu_fuse_pass"
,
...
...
paddle/phi/kernels/xpu/embedding_kernel.cc
浏览文件 @
2a5adc5a
...
...
@@ -44,18 +44,6 @@ void EmbeddingKernel(const Context &ctx,
auto
*
table
=
table_t
->
data
<
T
>
();
auto
*
output
=
dev_ctx
.
template
Alloc
<
T
>(
output_t
);
xpu
::
ctx_guard
RAII_GUARD
(
ctx
.
x_context
());
const
int64_t
*
ids
;
if
(
ids_t
->
dtype
()
==
phi
::
DataType
::
INT64
)
{
ids
=
ids_t
->
data
<
int64_t
>
();
}
else
{
int64_t
*
ids_tt
=
RAII_GUARD
.
alloc_l3_or_gm
<
int64_t
>
(
ids_t
->
numel
());
int
r
=
xpu
::
cast
<
int32_t
,
int64_t
>
(
ctx
.
x_context
(),
ids_t
->
data
<
int
>
(),
ids_tt
,
ids_t
->
numel
());
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"cast"
);
ids
=
reinterpret_cast
<
const
int64_t
*>
(
ids_tt
);
}
PADDLE_ENFORCE_EQ
(
ids_numel
<=
std
::
numeric_limits
<
int32_t
>::
max
(),
true
,
...
...
@@ -68,7 +56,38 @@ void EmbeddingKernel(const Context &ctx,
size_t
xm
=
table_t
->
dims
()[
0
];
size_t
n
=
table_t
->
dims
()[
1
];
int
r
=
xpu
::
embedding
<
XPUType
>
(
dev_ctx
.
x_context
(),
int
r
;
xpu
::
ctx_guard
RAII_GUARD
(
ctx
.
x_context
());
if
(
ids_t
->
dtype
()
==
phi
::
DataType
::
INT64
)
{
#ifndef PADDLE_WITH_XPU_PLUGIN
r
=
xpu
::
embedding
<
XPUType
,
int64_t
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
table
),
ids_t
->
data
<
int64_t
>
(),
reinterpret_cast
<
XPUType
*>
(
output
),
xm
,
n
,
ym
,
padding_idx
);
#else
r
=
xpu
::
plugin
::
fast_embedding
<
XPUType
,
int64_t
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
table
),
ids_t
->
data
<
int64_t
>
(),
reinterpret_cast
<
XPUType
*>
(
output
),
xm
,
n
,
ym
,
padding_idx
);
#endif
}
else
{
#ifndef PADDLE_WITH_XPU_PLUGIN
int64_t
*
ids_tt
=
RAII_GUARD
.
alloc_l3_or_gm
<
int64_t
>
(
ids_t
->
numel
());
r
=
xpu
::
cast
<
int32_t
,
int64_t
>
(
ctx
.
x_context
(),
ids_t
->
data
<
int
>
(),
ids_tt
,
ids_t
->
numel
());
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"cast"
);
const
int64_t
*
ids
=
reinterpret_cast
<
const
int64_t
*>
(
ids_tt
);
r
=
xpu
::
embedding
<
XPUType
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
table
),
ids
,
reinterpret_cast
<
XPUType
*>
(
output
),
...
...
@@ -76,7 +95,18 @@ void EmbeddingKernel(const Context &ctx,
n
,
ym
,
padding_idx
);
#else
r
=
xpu
::
plugin
::
fast_embedding
<
XPUType
,
int
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
table
),
ids_t
->
data
<
int
>
(),
reinterpret_cast
<
XPUType
*>
(
output
),
xm
,
n
,
ym
,
padding_idx
);
#endif
}
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"embedding"
);
}
...
...
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
浏览文件 @
2a5adc5a
...
...
@@ -104,6 +104,17 @@ DLL_EXPORT int fast_reduce_min(Context* ctx,
const
std
::
vector
<
int
>&
xshape
,
const
std
::
vector
<
int
>&
rdims
);
// Embedding lookup: y[i, :] = x[indices[i] - start_index, :].
//
//   x           : table of shape [xm, n].
//   indices     : ym lookup ids.
//   y           : output of shape [ym, n].
//   padding_idx : an id equal to this value (after start_index has been
//                 subtracted) yields an all-zero output row.
//   start_index : offset subtracted from every id before it is used.
//
// Ids outside [0, xm) after the subtraction also yield all-zero rows.
// Returns an XDNN status code.
template <typename T, typename TID>
DLL_EXPORT int fast_embedding(Context* ctx,
                              const T* x,
                              const TID* indices,
                              T* y,
                              int64_t xm,
                              int64_t n,
                              int64_t ym,
                              int64_t padding_idx,
                              TID start_index = 0);
}
// namespace plugin
}
// namespace api
}
// namespace xpu
...
...
paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/embedding_fwd_tiny_dict.xpu
0 → 100644
浏览文件 @
2a5adc5a
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/xtdk_io.h"
namespace xpu2 {
namespace plugin {
/*
Kernel usage condition: the dict is tiny enough to be loaded into local memory
all at once.
Optimization idea:
- Avoid frequent small memory transfers: allocate fixed-size buffers, accumulate
data up to the buffer size and move it out together.
********** Local Memory Addr **********
Part 1: dict(size = dict_idx_len * emb_dim)
-----------------------------------
Part 2: index(size = idx_len * sizeof(emb_idx_type))
-----------------------------------
Part 3: result
-----------------------------------
*/
// Tiny-dictionary embedding lookup that copies rows with 64-byte vector
// load/stores. The koffset loops advance in 64-byte steps, so emb_dim is
// expected to be a multiple of 64.
//
//   idx          : lookup indices in global memory (idx_len entries).
//   dict         : embedding table in global memory, addressed as bytes;
//                  dict_idx_len rows of emb_dim bytes each.
//   featvec      : output in global memory, one emb_dim-byte row per index.
//   emb_dim      : row width in BYTES.
//   padding_idx  : an index equal to this value (after start_index has been
//                  subtracted) produces an all-zero row.
//   start_index  : subtracted from every raw index before it is compared.
//
// Indices outside [0, dict_idx_len) also produce all-zero rows.
template <typename emb_idx_type>
static inline __device__ void embedding_fwd_kl2_tiny_dict_align64(
_global_ptr_ const emb_idx_type* idx,
_global_ptr_ const char* dict,
_global_ptr_ char* featvec,
int64_t emb_dim,
int64_t dict_idx_len,
int64_t idx_len,
int64_t padding_idx,
emb_idx_type start_index) {
// Flat thread id over every core of every cluster.
int cid = core_id();
int ncores = core_num();
int tid = cid * cluster_num() + cluster_id();
int nthreads = ncores * cluster_num();
int64_t row_start = -1;
int64_t row_end = -1;
// Presumably assigns this thread the index range [row_start, row_end) out of
// idx_len -- confirm against cluster_partition.h.
partition(tid, nthreads, idx_len, 1, &row_start, &row_end);
// 1. Pre-allocate the whole local-memory buffer (6 KB).
const int TOTAL_LM_SIZE = 6144; // 6 KB
__simd__ char lm[TOTAL_LM_SIZE];
// 2. Load the dict from global memory to local memory only once.
int total_emb_dict_size = dict_idx_len * emb_dim;
GM2LM(dict, lm, total_emb_dict_size);
// residual_lm_space = space left for the index batch plus its result rows.
int residual_lm_space = TOTAL_LM_SIZE - total_emb_dict_size -
64; // Reserve 64 bytes: the index area is rounded up to a
// 64-byte boundary below and must not overflow the buffer.
// The maximum number of indices that can be processed in one iteration.
int idx_cnt = residual_lm_space / (sizeof(emb_idx_type) + emb_dim);
// The index batch is stored right after the dict; results follow it.
int index_lm_offset = total_emb_dict_size;
int result_lm_offset =
total_emb_dict_size +
(idx_cnt * sizeof(emb_idx_type) + 64) / 64 * 64; // Align to 64 bytes
// 3. Process this thread's range in batches of at most idx_cnt indices.
for (int64_t i = row_start; i < row_end; i += idx_cnt) {
int curr_idx_len = idx_cnt;
if (i + idx_cnt >= row_end) {
// Last (possibly partial) batch.
curr_idx_len = row_end - i;
}
// 3.1 Load the batch of indices into local memory.
GM2LM(idx + i, lm + index_lm_offset, curr_idx_len * sizeof(emb_idx_type));
// 3.2 Build the result rows in the local result buffer.
for (int j = 0; j < curr_idx_len; j++) {
emb_idx_type real_index =
*((emb_idx_type*)(lm + index_lm_offset + j * sizeof(emb_idx_type))) -
start_index;
if (real_index == padding_idx) {
// Padding index: write a zero row. The _mz load with mask 0 presumably
// yields an all-zero vector -- confirm against the XTDK intrinsics.
for (int koffset = 0; koffset < emb_dim; koffset += 64) {
float32x16_t v_src = vload_lm_float32x16_mz((void*)lm, 0);
vstore_lm_float32x16(
(void*)(lm + result_lm_offset + j * emb_dim + koffset), v_src);
}
} else {
if (real_index >= 0 && real_index < dict_idx_len) {
// Valid index: copy the dict row, 64 bytes per step.
for (int koffset = 0; koffset < emb_dim; koffset += 64) {
float32x16_t v_src = vload_lm_float32x16(
(void*)(lm + real_index * emb_dim + koffset));
vstore_lm_float32x16(
(void*)(lm + result_lm_offset + j * emb_dim + koffset), v_src);
}
} else {
// Out-of-range index: same zero-row path as the padding case.
for (int koffset = 0; koffset < emb_dim; koffset += 64) {
float32x16_t v_src = vload_lm_float32x16_mz((void*)lm, 0);
vstore_lm_float32x16(
(void*)(lm + result_lm_offset + j * emb_dim + koffset), v_src);
}
}
}
mfence_lm();
}
// 3.3 Write the whole batch of result rows back to global memory at once.
LM2GM(lm + result_lm_offset,
(_global_ptr_ char*)(featvec + i * emb_dim),
curr_idx_len * emb_dim);
}
}
// Tiny-dictionary embedding lookup that copies rows byte by byte, used for
// row widths that are not handled by the 64-byte vector variant.
//
//   idx          : lookup indices in global memory (idx_len entries).
//   dict         : embedding table in global memory, addressed as bytes;
//                  dict_idx_len rows of emb_dim bytes each.
//   featvec      : output in global memory, one emb_dim-byte row per index.
//   emb_dim      : row width in BYTES.
//   padding_idx  : an index equal to this value (after start_index has been
//                  subtracted) produces an all-zero row.
//   start_index  : subtracted from every raw index before it is compared.
//
// Indices outside [0, dict_idx_len) also produce all-zero rows.
template <typename emb_idx_type>
static inline __device__ void embedding_fwd_kl2_tiny_dict_not_align64(
    _global_ptr_ const emb_idx_type* idx,
    _global_ptr_ const char* dict,
    _global_ptr_ char* featvec,
    int64_t emb_dim,
    int64_t dict_idx_len,
    int64_t idx_len,
    int64_t padding_idx,
    emb_idx_type start_index) {
  // Flat thread id over every core of every cluster, then this thread's
  // share [first_row, last_row) of the lookups.
  int cid = core_id();
  int ncores = core_num();
  int tid = cid * cluster_num() + cluster_id();
  int nthreads = ncores * cluster_num();
  int64_t first_row = -1;
  int64_t last_row = -1;
  partition(tid, nthreads, idx_len, 1, &first_row, &last_row);

  // Local-memory layout: [ dict | index batch | result rows ], 6 KB in total.
  const int TOTAL_LM_SIZE = 6144;  // 6 KB
  __local__ char lm[TOTAL_LM_SIZE];

  // The dict is loaded from global memory exactly once.
  GM2LM(dict, lm, dict_idx_len * emb_dim);

  // Space left after the dict is shared by the index batch and its results.
  int space_left = TOTAL_LM_SIZE - dict_idx_len * emb_dim;
  // Largest number of lookups that fits in one batch.
  int batch_cap = space_left / (sizeof(emb_idx_type) + emb_dim);
  int ids_offset = dict_idx_len * emb_dim;
  int out_offset = ids_offset + batch_cap * sizeof(emb_idx_type);

  for (int64_t base = first_row; base < last_row; base += batch_cap) {
    // Clamp the final batch to what is left.
    int batch = batch_cap;
    if (base + batch_cap >= last_row) {
      batch = last_row - base;
    }

    // Fetch this batch of indices into local memory.
    GM2LM(idx + base, lm + ids_offset, batch * sizeof(emb_idx_type));

    for (int j = 0; j < batch; j++) {
      emb_idx_type real_index =
          *((emb_idx_type*)(lm + ids_offset + j * sizeof(emb_idx_type))) -
          start_index;
      // A row is copied only for a non-padding index inside the dict; every
      // other case produces zeros.
      bool copy_row = !(real_index == padding_idx) &&
                      (real_index >= 0 && real_index < dict_idx_len);
      if (copy_row) {
        for (int k = 0; k < emb_dim; k++) {
          lm[out_offset + j * emb_dim + k] = lm[real_index * emb_dim + k];
        }
      } else {
        for (int k = 0; k < emb_dim; k++) {
          lm[out_offset + j * emb_dim + k] = 0;
        }
      }
      mfence_lm();
    }

    // Flush the whole batch of result rows to global memory at once.
    LM2GM(lm + out_offset,
          (_global_ptr_ char*)(featvec + base * emb_dim),
          batch * emb_dim);
  }
}
// Kernel entry point for the tiny-dictionary embedding lookup.
// Dispatches on the row width in bytes (emb_dim): widths that are a multiple
// of 64 go to the *_align64 variant, all others to the *_not_align64 variant.
template <typename emb_idx_type>
__global__ void embedding_fwd_kl2_tiny_dict(const emb_idx_type* idx,
                                            const char* dict,
                                            char* featvec,
                                            int64_t emb_dim,
                                            int64_t dict_idx_len,
                                            int64_t idx_len,
                                            int64_t padding_idx,
                                            emb_idx_type start_index) {
  const bool width_is_multiple_of_64 = (emb_dim % 64 == 0);
  if (!width_is_multiple_of_64) {
    embedding_fwd_kl2_tiny_dict_not_align64<emb_idx_type>(
        idx, dict, featvec, emb_dim, dict_idx_len, idx_len, padding_idx,
        start_index);
    return;
  }
  embedding_fwd_kl2_tiny_dict_align64<emb_idx_type>(
      idx, dict, featvec, emb_dim, dict_idx_len, idx_len, padding_idx,
      start_index);
}
// Explicitly instantiates embedding_fwd_kl2_tiny_dict for one index type.
// The kernel is instantiated below for int and int64_t indices.
#define _XPU_DEF__EMBEDDING_FWD_KL2_TINY_DICT_(EMB_IDX_TYPE) \
template __global__ void embedding_fwd_kl2_tiny_dict<EMB_IDX_TYPE>( \
const EMB_IDX_TYPE* idx, \
const char* dict, \
char* featvec, \
int64_t emb_dim, \
int64_t dict_idx_len, \
int64_t idx_len, \
int64_t padding_idx, \
EMB_IDX_TYPE start_index);
_XPU_DEF__EMBEDDING_FWD_KL2_TINY_DICT_(int);
_XPU_DEF__EMBEDDING_FWD_KL2_TINY_DICT_(int64_t);
} // namespace plugin
} // namespace xpu2
paddle/phi/kernels/xpu/plugin/src/wrapper/fast_embedding.cpp
0 → 100644
浏览文件 @
2a5adc5a
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
#include "xpu/refactor/util/vector_util.h"
namespace
xpu2
{
namespace
plugin
{
// Host-side declaration of the tiny-dictionary embedding device kernel.
// `emb_dim` is the row width in bytes; `dict` and `featvec` are byte views
// of the table and the output.
template <typename emb_idx_type>
__attribute__((global)) void embedding_fwd_kl2_tiny_dict(
    const emb_idx_type* idx,
    const char* dict,
    char* featvec,
    int64_t emb_dim,
    int64_t dict_idx_len,
    int64_t idx_len,
    int64_t padding_idx,
    emb_idx_type start_index);
}
// namespace plugin
}
// namespace xpu2
namespace
baidu
{
namespace
xpu
{
namespace
api
{
namespace
plugin
{
// CPU implementation
// Reference implementation run when the context's device type is kCPU.
//
// For each of the ym ids: subtract start_index, then
//   - id == padding_idx      -> zero the output row,
//   - 0 <= id < xm           -> copy row `id` of x (n elements),
//   - anything else          -> zero the output row.
// Always returns api::SUCCESS.
template <typename T, typename TID>
static int cpu_wrapper(Context* ctx,
                       const T* x,
                       const TID* indices,
                       T* y,
                       int64_t xm,
                       int64_t n,
                       int64_t ym,
                       int64_t padding_idx,
                       TID start_index) {
  for (int64_t row = 0; row < ym; row++) {
    T* dst = y + row * n;
    // start_index is removed BEFORE the padding / range comparisons.
    TID real_index = indices[row] - start_index;

    if (real_index == padding_idx) {
      ::memset(dst, 0, sizeof(T) * n);
      continue;
    }
    if (real_index >= 0 && real_index < xm) {
      std::memcpy(dst, x + real_index * n, sizeof(T) * n);
      continue;
    }
    // Out-of-range id: emit zeros element by element.
    for (int64_t col = 0; col < n; ++col) {
      dst[col] = 0;
    }
  }
  return api::SUCCESS;
}
// XPU2 dispatch for fast_embedding.
//
// The tiny-dict plugin kernel keeps the whole table in a 6 KB local-memory
// buffer together with a batch of indices and their output rows. It is used
// only when, after the table and a 64-byte reserve, at least 16 lookups fit
// in one batch; in every other case the call is forwarded to the generic
// embedding<T, TID>().
//
// Returns api::SUCCESS after launching the plugin kernel, otherwise the
// status reported by embedding<T, TID>().
template <typename T, typename TID>
static int xpu2_wrapper(Context* ctx,
                        const T* x,
                        const TID* indices,
                        T* y,
                        int64_t xm,
                        int64_t n,
                        int64_t ym,
                        int64_t padding_idx,
                        TID start_index) {
  // All sizes are computed in signed 64-bit arithmetic. sizeof() yields an
  // unsigned type; mixing it with a negative "space left" (table larger than
  // local memory) would convert that value to a huge unsigned number and
  // could wrongly select the tiny-dict kernel for a big table.
  const int64_t TOTAL_LM_SIZE = 6144;  // 6 KB
  const int64_t row_bytes = n * static_cast<int64_t>(sizeof(T));
  const int64_t total_emb_dict_size = xm * row_bytes;
  // residual_lm_space = index + result
  const int64_t residual_lm_space = TOTAL_LM_SIZE - total_emb_dict_size - 64;
  // The maximum count that can be processed in one iteration.
  const int64_t bytes_per_lookup =
      static_cast<int64_t>(sizeof(TID)) + row_bytes;
  const int64_t idx_cnt = residual_lm_space / bytes_per_lookup;

  // This plugin is suitable for scenarios with relatively small dictionary
  // sizes: it must be able to process at least 16 indices per iteration, so
  // that the dictionary is loaded into local memory at once and enough local
  // memory is left to store the results.
  const bool plugin_entry_condition = idx_cnt >= 16;

  if (plugin_entry_condition && ctx->dev().type() == api::kXPU2) {
    using XPU_TID = typename XPUIndexType<TID>::type;
    const XPU_TID* casted_indices =
        static_cast<const XPU_TID*>(static_cast<const void*>(indices));
    XPU_TID casted_start_index = static_cast<XPU_TID>(start_index);
    // The kernel works on bytes, hence the char views and the byte row width.
    xpu2::plugin::embedding_fwd_kl2_tiny_dict<XPU_TID>
        <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
            casted_indices,
            reinterpret_cast<const char*>(x),
            reinterpret_cast<char*>(y),
            row_bytes,
            xm,
            ym,
            padding_idx,
            casted_start_index);
    return api::SUCCESS;
  }

  // Fallback: propagate the status of the generic implementation instead of
  // unconditionally reporting success.
  return embedding<T, TID>(
      ctx, x, indices, y, xm, n, ym, padding_idx, start_index);
}
// Public entry point of the embedding plugin.
//
// Validates the context, the [xm, n] table, the [ym, n] output and the ym
// indices through the WRAPPER_* macros, then forwards to the implementation
// matching the context's device type (kCPU or kXPU2). Any other device type
// ends in WRAPPER_UNIMPLEMENTED.
template <typename T, typename TID>
int fast_embedding(Context* ctx,
                   const T* x,
                   const TID* indices,
                   T* y,
                   int64_t xm,
                   int64_t n,
                   int64_t ym,
                   int64_t padding_idx,
                   TID start_index) {
  // Context check and call tracing.
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T2(ctx, "fast_embedding", T, TID);
  WRAPPER_DUMP_PARAM6(ctx, x, indices, y, xm, n, ym);
  WRAPPER_DUMP_PARAM3(ctx, padding_idx, start_index, ctx->_l3_mgr.get_size());
  WRAPPER_DUMP(ctx);

  // Element counts of the table and of the output, filled in by the shape
  // checks and then used by the pointer checks.
  int64_t xlen = -1;
  int64_t ylen = -1;
  WRAPPER_CHECK_SHAPE(ctx, &xlen, {xm, n});
  WRAPPER_CHECK_SHAPE(ctx, &ylen, {ym, n});
  WRAPPER_CHECK_PTR(ctx, T, xlen, x);
  WRAPPER_CHECK_PTR(ctx, T, ylen, y);
  WRAPPER_CHECK_PTR(ctx, TID, ym, indices);

  // Device dispatch.
  const auto dev_type = ctx->dev().type();
  if (dev_type == api::kCPU) {
    return cpu_wrapper<T, TID>(
        ctx, x, indices, y, xm, n, ym, padding_idx, start_index);
  }
  if (dev_type == api::kXPU2) {
    return xpu2_wrapper<T, TID>(
        ctx, x, indices, y, xm, n, ym, padding_idx, start_index);
  }
  WRAPPER_UNIMPLEMENTED(ctx);
}
template
int
fast_embedding
(
Context
*
,
const
float
*
,
const
int
*
,
float
*
,
int64_t
,
int64_t
,
int64_t
,
int64_t
,
int
);
template
int
fast_embedding
(
Context
*
,
const
float
*
,
const
int64_t
*
,
float
*
,
int64_t
,
int64_t
,
int64_t
,
int64_t
,
int64_t
);
template
int
fast_embedding
(
Context
*
,
const
float16
*
,
const
int
*
,
float16
*
,
int64_t
,
int64_t
,
int64_t
,
int64_t
,
int
);
template
int
fast_embedding
(
Context
*
,
const
float16
*
,
const
int64_t
*
,
float16
*
,
int64_t
,
int64_t
,
int64_t
,
int64_t
,
int64_t
);
}
// namespace plugin
}
// namespace api
}
// namespace xpu
}
// namespace baidu
test/ir/inference/test_xpu_cast_embedding_trans_ids_to_int32_pass.py
0 → 100644
浏览文件 @
2a5adc5a
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
from
functools
import
partial
import
hypothesis.strategies
as
st
import
numpy
as
np
from
auto_scan_test
import
PassAutoScanTest
from
program_config
import
OpConfig
,
ProgramConfig
,
TensorConfig
class TestXpuCastEmbeddingTransIdsToInt32Pass(PassAutoScanTest):
    """Auto-scan test for ``cast_embedding_trans_ids_to_int32_pass`` on XPU.

    Each sampled program is ``cast -> lookup_table_v2``; after the pass the
    program is still expected to consist of those two op types.
    """

    def sample_predictor_configs(self, program_config):
        # One XPU inference config, the op types expected after the pass, and
        # the (1e-5, 1e-5) tolerance pair.
        xpu_config = self.create_inference_config(use_xpu=True)
        yield xpu_config, ["cast", "lookup_table_v2"], (1e-5, 1e-5)

    def sample_program_config(self, draw):
        # Draw order is kept: ids length, table shape, padding index.
        ids_shape = draw(st.integers(min_value=1, max_value=128))
        w_shape = draw(
            st.sampled_from([[20, 64], [32, 32], [23, 15], [24, 33]])
        )
        padding_idx = draw(st.sampled_from([-1]))

        # in_dtype=5 / out_dtype=3 are presumably FP32 -> INT64 in the VarType
        # enum (the pass only rewrites INT64-producing casts) -- confirm.
        cast_op = OpConfig(
            "cast",
            inputs={"X": ["cast_input"]},
            outputs={"Out": ["cast_out"]},
            in_dtype=5,
            out_dtype=3,
        )
        lookup_table_op = OpConfig(
            "lookup_table_v2",
            inputs={"Ids": ["cast_out"], "W": ["lookup_table_w"]},
            outputs={"Out": ["lookup_table_out"]},
            padding_idx=padding_idx,
        )

        def gen_lookup_table_weights_data():
            # Embedding table with the sampled shape.
            return {"lookup_table_w": TensorConfig(shape=w_shape)}

        def generate_cast_input(*args, **kwargs):
            # Valid row ids in [0, w_shape[0]) stored as float32, so the cast
            # op has real work to do.
            ids = np.random.randint(0, w_shape[0], ids_shape)
            return ids.astype(np.float32)

        def gen_input_data(*args, **kwargs):
            return {
                "cast_input": TensorConfig(
                    data_gen=partial(generate_cast_input)
                )
            }

        inputs = gen_input_data()
        weights = gen_lookup_table_weights_data()
        return ProgramConfig(
            ops=[cast_op, lookup_table_op],
            weights=weights,
            inputs=inputs,
            outputs=["lookup_table_out"],
        )

    def test(self):
        self.run_and_statis(
            quant=False,
            max_examples=25,
            passes=["cast_embedding_trans_ids_to_int32_pass"],
        )
# Run the test case when this file is executed directly.
if __name__ == "__main__":
    unittest.main()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录