Unverified commit 2a5adc5a authored by csy0225, committed by GitHub

[XPU] Add embedding plugin (#56488)

Parent 7c5f4fde
......@@ -239,6 +239,8 @@ if(WITH_XPU)
pass_library(cast_mixed_precision_op_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(yolo_box_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(cast_embedding_trans_ids_to_int32_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(conv1d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(conv2d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(redundant_unsqueeze_squeeze_elimination_pass inference DIR xpu
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct CastEmbeddingTransIdsToInt32Pattern : public PatternBase {
CastEmbeddingTransIdsToInt32Pattern(PDPattern* pattern,
const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(cast);
PATTERN_DECL_NODE(embedding);
// declare variable node's name
PATTERN_DECL_NODE(cast_x);
PATTERN_DECL_NODE(embedding_ids);
PATTERN_DECL_NODE(embedding_w);
PATTERN_DECL_NODE(embedding_out);
};
CastEmbeddingTransIdsToInt32Pattern::CastEmbeddingTransIdsToInt32Pattern(
PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto cast = pattern->NewNode(cast_repr())->assert_is_op("cast");
auto cast_x = pattern->NewNode(cast_x_repr())
->assert_is_op_input("cast", "X")
->assert_var_not_persistable()
->AsInput();
auto embedding_ids = pattern->NewNode(embedding_ids_repr())
->assert_is_op_output("cast", "Out")
->assert_is_op_input("lookup_table_v2", "Ids")
->assert_has_n_outputs(1);
cast->LinksFrom({cast_x}).LinksTo({embedding_ids});
auto embedding_w = pattern->NewNode(embedding_w_repr())
->assert_is_op_input("lookup_table_v2", "W");
auto embedding =
pattern->NewNode(embedding_repr())->assert_is_op("lookup_table_v2");
auto embedding_out = pattern->NewNode(embedding_out_repr())
->assert_is_op_output("lookup_table_v2", "Out")
->AsOutput();
embedding->LinksFrom({embedding_ids, embedding_w}).LinksTo({embedding_out});
}
} // namespace patterns
class CastEmbeddingTransIdsToInt32Pass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
const std::string name_scope_{"cast_embedding_trans_ids_to_int32_pass"};
};
void CastEmbeddingTransIdsToInt32Pass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
GraphPatternDetector gpd;
patterns::CastEmbeddingTransIdsToInt32Pattern pattern(gpd.mutable_pattern(),
name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle CastEmbeddingTransIdsToInt32Pass";
GET_IR_NODE(cast);
GET_IR_NODE(embedding);
GET_IR_NODE(embedding_ids);
auto cast_node_attr_out_dtype =
cast->Op()->GetAttrIfExists<int>("out_dtype");
if (cast_node_attr_out_dtype !=
static_cast<int>(paddle::framework::proto::VarType::INT64)) {
return;
}
cast->Op()->SetAttr(
"out_dtype",
static_cast<int>(paddle::framework::proto::VarType::INT32));
embedding_ids->Var()->SetDataType(paddle::framework::proto::VarType::INT32);
embedding->Op()->Flush();
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
if (found_subgraph_count) {
VLOG(4) << "There is a risk of overflow when converting the data type of "
"embedded ids from int64 to int32."
"Please ensure that the numerical range of ids is within the "
"maximum value of int32."
"If it exceeds this range, it may result in incorrect results. "
"You can try removing this pass.";
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(cast_embedding_trans_ids_to_int32_pass,
paddle::framework::ir::CastEmbeddingTransIdsToInt32Pass);
REGISTER_PASS_CAPABILITY(cast_embedding_trans_ids_to_int32_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().LE(
"lookup_table_v2", 1));
......@@ -516,6 +516,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"reshape_unstack_concat_fuse_pass",
"delete_op_device_pass",
"constant_folding_pass",
"cast_embedding_trans_ids_to_int32_pass",
"delete_elementwise_mul_op_pass",
"generate_sequence_xpu_fuse_pass",
"embedding_with_eltwise_add_xpu_fuse_pass",
......
......@@ -44,18 +44,6 @@ void EmbeddingKernel(const Context &ctx,
auto *table = table_t->data<T>();
auto *output = dev_ctx.template Alloc<T>(output_t);
xpu::ctx_guard RAII_GUARD(ctx.x_context());
const int64_t *ids;
if (ids_t->dtype() == phi::DataType::INT64) {
ids = ids_t->data<int64_t>();
} else {
int64_t *ids_tt = RAII_GUARD.alloc_l3_or_gm<int64_t>(ids_t->numel());
int r = xpu::cast<int32_t, int64_t>(
ctx.x_context(), ids_t->data<int>(), ids_tt, ids_t->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
ids = reinterpret_cast<const int64_t *>(ids_tt);
}
PADDLE_ENFORCE_EQ(
ids_numel <= std::numeric_limits<int32_t>::max(),
true,
......@@ -68,15 +56,57 @@ void EmbeddingKernel(const Context &ctx,
size_t xm = table_t->dims()[0];
size_t n = table_t->dims()[1];
int r = xpu::embedding<XPUType>(dev_ctx.x_context(),
reinterpret_cast<const XPUType *>(table),
ids,
reinterpret_cast<XPUType *>(output),
xm,
n,
ym,
padding_idx);
int r;
xpu::ctx_guard RAII_GUARD(ctx.x_context());
if (ids_t->dtype() == phi::DataType::INT64) {
#ifndef PADDLE_WITH_XPU_PLUGIN
r = xpu::embedding<XPUType, int64_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType *>(table),
ids_t->data<int64_t>(),
reinterpret_cast<XPUType *>(output),
xm,
n,
ym,
padding_idx);
#else
r = xpu::plugin::fast_embedding<XPUType, int64_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType *>(table),
ids_t->data<int64_t>(),
reinterpret_cast<XPUType *>(output),
xm,
n,
ym,
padding_idx);
#endif
} else {
#ifndef PADDLE_WITH_XPU_PLUGIN
int64_t *ids_tt = RAII_GUARD.alloc_l3_or_gm<int64_t>(ids_t->numel());
r = xpu::cast<int32_t, int64_t>(
ctx.x_context(), ids_t->data<int>(), ids_tt, ids_t->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
const int64_t *ids = reinterpret_cast<const int64_t *>(ids_tt);
r = xpu::embedding<XPUType>(dev_ctx.x_context(),
reinterpret_cast<const XPUType *>(table),
ids,
reinterpret_cast<XPUType *>(output),
xm,
n,
ym,
padding_idx);
#else
r = xpu::plugin::fast_embedding<XPUType, int>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType *>(table),
ids_t->data<int>(),
reinterpret_cast<XPUType *>(output),
xm,
n,
ym,
padding_idx);
#endif
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding");
}
......
......@@ -104,6 +104,17 @@ DLL_EXPORT int fast_reduce_min(Context* ctx,
const std::vector<int>& xshape,
const std::vector<int>& rdims);
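// Added descriptive comment (summarizing the cpu_wrapper reference
// implementation in this commit): fast_embedding looks up `ym` ids from
// `indices` in the [xm, n] table `x` and writes a [ym, n] result to `y`.
// `start_index` is subtracted from each id first; a resulting index equal to
// `padding_idx`, or outside [0, xm), produces a zero-filled row.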
template <typename T, typename TID>
DLL_EXPORT int fast_embedding(Context* ctx,
const T* x,
const TID* indices,
T* y,
int64_t xm,
int64_t n,
int64_t ym,
int64_t padding_idx,
TID start_index = 0);
} // namespace plugin
} // namespace api
} // namespace xpu
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
#include "xpu/kernel/xtdk_io.h"
namespace xpu2 {
namespace plugin {
/*
Kernel usage condition: the dict is tiny enough to be loaded into local memory
at once.
Optimization idea:
- Reduce frequent memory transfers: allocate fixed-size buffers, accumulate
data up to the buffer size, and move it out in one batch.
********** Local Memory Addr **********
Part 1: dict(size = dict_idx_len * emb_dim)
-----------------------------------
Part 2: index(size = idx_len * sizeof(emb_idx_type))
-----------------------------------
Part 3: result
-----------------------------------
*/
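// Note on units (as launched from the host-side wrapper in this commit):
// emb_dim is measured in bytes (n * sizeof(T)), dict_idx_len is the number of
// table rows (xm), and idx_len is the number of ids (ym).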
template <typename emb_idx_type>
static inline __device__ void embedding_fwd_kl2_tiny_dict_align64(
_global_ptr_ const emb_idx_type* idx,
_global_ptr_ const char* dict,
_global_ptr_ char* featvec,
int64_t emb_dim,
int64_t dict_idx_len,
int64_t idx_len,
int64_t padding_idx,
emb_idx_type start_index) {
int cid = core_id();
int ncores = core_num();
int tid = cid * cluster_num() + cluster_id();
int nthreads = ncores * cluster_num();
int64_t row_start = -1;
int64_t row_end = -1;
partition(tid, nthreads, idx_len, 1, &row_start, &row_end);
// 1. Pre-allocate local memory: total size = 6 KB.
const int TOTAL_LM_SIZE = 6144; // 6 KB
__simd__ char lm[TOTAL_LM_SIZE];
// 2. Load dict from Global Memory to Local memory only once.
int total_emb_dict_size = dict_idx_len * emb_dim;
GM2LM(dict, lm, total_emb_dict_size);
// residual_lm_space = index + result
int residual_lm_space = TOTAL_LM_SIZE - total_emb_dict_size -
64; // reserve 64 bytes to avoid overflow, because the
// index buffer must be aligned to 64 bytes.
// The maximum count that can be processed in one iteration.
int idx_cnt = residual_lm_space / (sizeof(emb_idx_type) + emb_dim);
int index_lm_offset = total_emb_dict_size;
int result_lm_offset =
total_emb_dict_size +
(idx_cnt * sizeof(emb_idx_type) + 64) / 64 * 64; // Align to 64 bytes
// 3. Loop Calc
for (int64_t i = row_start; i < row_end; i += idx_cnt) {
int curr_idx_len = idx_cnt;
if (i + idx_cnt >= row_end) {
curr_idx_len = row_end - i;
}
// 3.1 Load idx to Local Memory
GM2LM(idx + i, lm + index_lm_offset, curr_idx_len * sizeof(emb_idx_type));
// 3.2 Save result into result memory buffer.
for (int j = 0; j < curr_idx_len; j++) {
emb_idx_type real_index =
*((emb_idx_type*)(lm + index_lm_offset + j * sizeof(emb_idx_type))) -
start_index;
if (real_index == padding_idx) {
for (int koffset = 0; koffset < emb_dim; koffset += 64) {
float32x16_t v_src = vload_lm_float32x16_mz((void*)lm, 0);
vstore_lm_float32x16(
(void*)(lm + result_lm_offset + j * emb_dim + koffset), v_src);
}
} else {
if (real_index >= 0 && real_index < dict_idx_len) {
for (int koffset = 0; koffset < emb_dim; koffset += 64) {
float32x16_t v_src = vload_lm_float32x16(
(void*)(lm + real_index * emb_dim + koffset));
vstore_lm_float32x16(
(void*)(lm + result_lm_offset + j * emb_dim + koffset), v_src);
}
} else {
for (int koffset = 0; koffset < emb_dim; koffset += 64) {
float32x16_t v_src = vload_lm_float32x16_mz((void*)lm, 0);
vstore_lm_float32x16(
(void*)(lm + result_lm_offset + j * emb_dim + koffset), v_src);
}
}
}
mfence_lm();
}
// 3.3 Save result into global memory buffer.
LM2GM(lm + result_lm_offset,
(_global_ptr_ char*)(featvec + i * emb_dim),
curr_idx_len * emb_dim);
}
}
template <typename emb_idx_type>
static inline __device__ void embedding_fwd_kl2_tiny_dict_not_align64(
_global_ptr_ const emb_idx_type* idx,
_global_ptr_ const char* dict,
_global_ptr_ char* featvec,
int64_t emb_dim,
int64_t dict_idx_len,
int64_t idx_len,
int64_t padding_idx,
emb_idx_type start_index) {
int cid = core_id();
int ncores = core_num();
int tid = cid * cluster_num() + cluster_id();
int nthreads = ncores * cluster_num();
int64_t row_start = -1;
int64_t row_end = -1;
partition(tid, nthreads, idx_len, 1, &row_start, &row_end);
// 1. Pre-allocate local memory: total size = 6 KB.
const int TOTAL_LM_SIZE = 6144; // 6 KB
__local__ char lm[TOTAL_LM_SIZE];
// 2. Load dict from Global Memory to Local memory only once.
GM2LM(dict, lm, dict_idx_len * emb_dim);
// residual_lm_space = index + result
int residual_lm_space = TOTAL_LM_SIZE - dict_idx_len * emb_dim;
// The maximum count that can be processed in one iteration.
int idx_cnt = residual_lm_space / (sizeof(emb_idx_type) + emb_dim);
int index_lm_offset = dict_idx_len * emb_dim;
int result_lm_offset = index_lm_offset + idx_cnt * sizeof(emb_idx_type);
// 3. Loop Calc
for (int64_t i = row_start; i < row_end; i += idx_cnt) {
int curr_idx_len = idx_cnt;
if (i + idx_cnt >= row_end) {
curr_idx_len = row_end - i;
}
// 3.1 Load idx to Local Memory
GM2LM(idx + i, lm + index_lm_offset, curr_idx_len * sizeof(emb_idx_type));
// 3.2 Save result into result memory buffer.
for (int j = 0; j < curr_idx_len; j++) {
emb_idx_type real_index =
*((emb_idx_type*)(lm + index_lm_offset + j * sizeof(emb_idx_type))) -
start_index;
if (real_index == padding_idx) {
for (int k = 0; k < emb_dim; k++) {
lm[result_lm_offset + j * emb_dim + k] = 0;
}
} else {
if (real_index >= 0 && real_index < dict_idx_len) {
for (int k = 0; k < emb_dim; k++) {
lm[result_lm_offset + j * emb_dim + k] =
lm[real_index * emb_dim + k];
}
} else {
for (int k = 0; k < emb_dim; k++) {
lm[result_lm_offset + j * emb_dim + k] = 0;
}
}
}
mfence_lm();
}
// 3.3 Save result into global memory buffer.
LM2GM(lm + result_lm_offset,
(_global_ptr_ char*)(featvec + i * emb_dim),
curr_idx_len * emb_dim);
}
}
template <typename emb_idx_type>
__global__ void embedding_fwd_kl2_tiny_dict(const emb_idx_type* idx,
const char* dict,
char* featvec,
int64_t emb_dim,
int64_t dict_idx_len,
int64_t idx_len,
int64_t padding_idx,
emb_idx_type start_index) {
if (emb_dim % 64 == 0) {
embedding_fwd_kl2_tiny_dict_align64<emb_idx_type>(idx,
dict,
featvec,
emb_dim,
dict_idx_len,
idx_len,
padding_idx,
start_index);
} else {
embedding_fwd_kl2_tiny_dict_not_align64<emb_idx_type>(idx,
dict,
featvec,
emb_dim,
dict_idx_len,
idx_len,
padding_idx,
start_index);
}
}
#define _XPU_DEF__EMBEDDING_FWD_KL2_TINY_DICT_(EMB_IDX_TYPE) \
template __global__ void embedding_fwd_kl2_tiny_dict<EMB_IDX_TYPE>( \
const EMB_IDX_TYPE* idx, \
const char* dict, \
char* featvec, \
int64_t emb_dim, \
int64_t dict_idx_len, \
int64_t idx_len, \
int64_t padding_idx, \
EMB_IDX_TYPE start_index);
_XPU_DEF__EMBEDDING_FWD_KL2_TINY_DICT_(int);
_XPU_DEF__EMBEDDING_FWD_KL2_TINY_DICT_(int64_t);
} // namespace plugin
} // namespace xpu2
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
#include "xpu/refactor/util/vector_util.h"
namespace xpu2 {
namespace plugin {
template <typename emb_idx_type>
__attribute__((global)) void embedding_fwd_kl2_tiny_dict(
const emb_idx_type* idx,
const char* dict,
char* featvec,
int64_t emb_dim,
int64_t dict_idx_len,
int64_t idx_len,
int64_t padding_idx,
emb_idx_type start_index);
} // namespace plugin
} // namespace xpu2
namespace baidu {
namespace xpu {
namespace api {
namespace plugin {
// CPU implementation
template <typename T, typename TID>
static int cpu_wrapper(Context* ctx,
const T* x,
const TID* indices,
T* y,
int64_t xm,
int64_t n,
int64_t ym,
int64_t padding_idx,
TID start_index) {
for (int64_t i = 0; i < ym; i++) {
TID real_index = indices[i] - start_index;  // subtract start_index before the checks below
if (real_index == padding_idx) {
::memset(y + i * n, 0, sizeof(T) * n);
} else {
if (real_index >= 0 && real_index < xm) {
std::memcpy(y + i * n, x + real_index * n, sizeof(T) * n);
} else {
// set zeros
for (int64_t k = 0; k < n; ++k) {
y[i * n + k] = 0;
}
}
}
}
return api::SUCCESS;
}
template <typename T, typename TID>
static int xpu2_wrapper(Context* ctx,
const T* x,
const TID* indices,
T* y,
int64_t xm,
int64_t n,
int64_t ym,
int64_t padding_idx,
TID start_index) {
const int TOTAL_LM_SIZE = 6144; // 6 KB
int total_emb_dict_size = xm * n * sizeof(T);
// residual_lm_space = index + result
int residual_lm_space = TOTAL_LM_SIZE - total_emb_dict_size - 64;
// The maximum count that can be processed in one iteration.
int idx_cnt = residual_lm_space / (sizeof(TID) + n * sizeof(T));
bool plugin_entry_condition = idx_cnt >= 16;
// This plugin targets scenarios with a relatively small dictionary: the whole
// dictionary must fit into local memory at once, and at least 16 ids must be
// processable per iteration so that enough local memory remains to hold the
// results.
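// Illustrative example (numbers assumed, not from the original source): a
// fp32 table of shape [24, 33] occupies 24 * 33 * 4 = 3168 bytes, leaving
// 6144 - 3168 - 64 = 2912 bytes of local memory, i.e.
// 2912 / (4 + 33 * 4) = 21 int32 ids per iteration, so the plugin kernel is
// taken; larger tables fall back to the generic embedding below.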
if (plugin_entry_condition) {
using XPU_TID = typename XPUIndexType<TID>::type;
const XPU_TID* casted_indices =
static_cast<const XPU_TID*>(static_cast<const void*>(indices));
XPU_TID casted_start_index = static_cast<XPU_TID>(start_index);
if (ctx->dev().type() == api::kXPU2) {
xpu2::plugin::embedding_fwd_kl2_tiny_dict<XPU_TID>
<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
casted_indices,
reinterpret_cast<const char*>(x),
reinterpret_cast<char*>(y),
n * sizeof(T),
xm,
ym,
padding_idx,
casted_start_index);
}
} else {
embedding<T, TID>(ctx, x, indices, y, xm, n, ym, padding_idx, start_index);
}
return api::SUCCESS;
}
template <typename T, typename TID>
int fast_embedding(Context* ctx,
const T* x,
const TID* indices,
T* y,
int64_t xm,
int64_t n,
int64_t ym,
int64_t padding_idx,
TID start_index) {
WRAPPER_CHECK_CTX(ctx);
WRAPPER_DUMP_FUNCTION_T2(ctx, "fast_embedding", T, TID);
WRAPPER_DUMP_PARAM6(ctx, x, indices, y, xm, n, ym);
WRAPPER_DUMP_PARAM3(ctx, padding_idx, start_index, ctx->_l3_mgr.get_size());
WRAPPER_DUMP(ctx);
int64_t xlen = -1;
int64_t ylen = -1;
WRAPPER_CHECK_SHAPE(ctx, &xlen, {xm, n});
WRAPPER_CHECK_SHAPE(ctx, &ylen, {ym, n});
WRAPPER_CHECK_PTR(ctx, T, xlen, x);
WRAPPER_CHECK_PTR(ctx, T, ylen, y);
WRAPPER_CHECK_PTR(ctx, TID, ym, indices);
if (ctx->dev().type() == api::kCPU) {
return cpu_wrapper<T>(
ctx, x, indices, y, xm, n, ym, padding_idx, start_index);
}
if (ctx->dev().type() == api::kXPU2) {
return xpu2_wrapper<T, TID>(
ctx, x, indices, y, xm, n, ym, padding_idx, start_index);
}
WRAPPER_UNIMPLEMENTED(ctx);
}
template int fast_embedding(Context*,
const float*,
const int*,
float*,
int64_t,
int64_t,
int64_t,
int64_t,
int);
template int fast_embedding(Context*,
const float*,
const int64_t*,
float*,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t);
template int fast_embedding(Context*,
const float16*,
const int*,
float16*,
int64_t,
int64_t,
int64_t,
int64_t,
int);
template int fast_embedding(Context*,
const float16*,
const int64_t*,
float16*,
int64_t,
int64_t,
int64_t,
int64_t,
int64_t);
} // namespace plugin
} // namespace api
} // namespace xpu
} // namespace baidu
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestXpuCastEmbeddingTransIdsToInt32Pass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ["cast", "lookup_table_v2"], (1e-5, 1e-5)
def sample_program_config(self, draw):
ids_shape = draw(st.integers(min_value=1, max_value=128))
w_shape = draw(
st.sampled_from([[20, 64], [32, 32], [23, 15], [24, 33]])
)
padding_idx = draw(st.sampled_from([-1]))
cast_op = OpConfig(
"cast",
inputs={
"X": ["cast_input"],
},
outputs={"Out": ["cast_out"]},
in_dtype=5,
out_dtype=3,
)
lookup_table_op = OpConfig(
"lookup_table_v2",
inputs={
"Ids": ["cast_out"],
"W": ["lookup_table_w"],
},
outputs={"Out": ["lookup_table_out"]},
padding_idx=padding_idx,
)
def gen_lookup_table_weights_data():
weights = {}
w_name = "lookup_table_w"
weights[w_name] = TensorConfig(shape=w_shape)
return weights
def generate_cast_input(*args, **kwargs):
return np.random.randint(0, w_shape[0], ids_shape).astype(
np.float32
)
def gen_input_data(*args, **kwargs):
inputs = {}
input_name = "cast_input"
inputs[input_name] = TensorConfig(
data_gen=partial(generate_cast_input)
)
return inputs
inputs = gen_input_data()
weights = gen_lookup_table_weights_data()
program_config = ProgramConfig(
ops=[cast_op, lookup_table_op],
weights=weights,
inputs=inputs,
outputs=["lookup_table_out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
passes=["cast_embedding_trans_ids_to_int32_pass"],
)
if __name__ == "__main__":
unittest.main()