Unverified commit 95aab366 authored by Z zhupengyang, committed by GitHub

fix embedding_with_eltwise_add_xpu (#55354)

Parent fce4c2de
@@ -19,20 +19,67 @@
namespace phi {
namespace fusion {
template <typename T, typename Context>
void EmbeddingWithEltwiseAddXpuKernel(
const Context& ctx,
const std::vector<const DenseTensor*>& ids,
const std::vector<const DenseTensor*>& tables,
const paddle::optional<DenseTensor>& mask,
int64_t padding_idx,
DenseTensor* out,
DenseTensor* seq_lod,
DenseTensor* max_seq_len) {
using XPUType = typename XPUTypeTrait<T>::Type;
int emb_dim = tables[0]->dims()[1];
std::vector<int> table_lens;
std::vector<const float*> arg_tables;
namespace {
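// FillSeqLod scans each batch's row of the mask, counts the leading positions
// whose mask value is positive, and accumulates the counts into seq_lod as a
// LoD-style prefix sum (the caller initializes seq_lod[0] = 0).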
template <typename T>
void FillSeqLod(int batch_size, int max_seq_len, const T* mask, int* seq_lod) {
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
int cur_batch_seq_len = 0;
for (int seq_idx = 0; seq_idx < max_seq_len; seq_idx++) {
int mask_idx = batch_idx * max_seq_len + seq_idx;
if (mask[mask_idx] > 0) {
cur_batch_seq_len++;
} else {
break;
}
}
PADDLE_ENFORCE_GT(
cur_batch_seq_len,
0,
errors::PreconditionNotMet(
"cur_batch_seq_len should be greater than 0, but got %d.",
cur_batch_seq_len));
seq_lod[batch_idx + 1] = seq_lod[batch_idx] + cur_batch_seq_len;
}
}
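// Specialization for float32 masks: a position counts as valid when
// |mask| > 1e-5, i.e. the mask is compared against zero with a tolerance.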
template <>
void FillSeqLod<float>(int batch_size,
int max_seq_len,
const float* mask,
int* seq_lod) {
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
int cur_batch_seq_len = 0;
for (int seq_idx = 0; seq_idx < max_seq_len; seq_idx++) {
int mask_idx = batch_idx * max_seq_len + seq_idx;
if (fabs(mask[mask_idx]) > 1e-5) {
cur_batch_seq_len++;
} else {
break;
}
}
PADDLE_ENFORCE_GT(
cur_batch_seq_len,
0,
errors::PreconditionNotMet(
"cur_batch_seq_len should be greater than 0, but got %d.",
cur_batch_seq_len));
seq_lod[batch_idx + 1] = seq_lod[batch_idx] + cur_batch_seq_len;
}
}
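// MultiEmbeddingKernel is templated on both the table/output dtype (TT) and
// the ids dtype (TID), so tables and ids can be handed to
// xpu::multi_embedding_fusion in their native types without host-side casts.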
template <typename TT, typename TID, typename Context>
void MultiEmbeddingKernel(const Context& ctx,
const std::vector<const DenseTensor*>& ids,
const std::vector<const DenseTensor*>& tables,
const paddle::optional<DenseTensor>& mask,
int64_t padding_idx,
DenseTensor* out,
DenseTensor* seq_lod,
DenseTensor* max_seq_len) {
using XPUType = typename XPUTypeTrait<TT>::Type;
int64_t emb_dim = tables[0]->dims()[1];
std::vector<TID> table_lens;
std::vector<const XPUType*> arg_tables;
for (auto* table : tables) {
auto& table_dims = table->dims();
PADDLE_ENFORCE_EQ(
@@ -49,20 +96,7 @@ void EmbeddingWithEltwiseAddXpuKernel(
table_dims[1],
emb_dim));
table_lens.push_back(table_dims[0]);
if (std::is_same<T, phi::dtype::float16>::value) {
DenseTensor table_data_fp32_t;
ctx.template Alloc<float>(&table_data_fp32_t,
table->numel() * sizeof(float));
int r = xpu::cast<XPUType, float>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(table->data<T>()),
table_data_fp32_t.data<float>(),
table->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
arg_tables.push_back(table_data_fp32_t.data<float>());
} else {
arg_tables.push_back(table->data<float>());
}
arg_tables.push_back(reinterpret_cast<const XPUType*>(table->data<TT>()));
}
int emb_layer_num = ids.size();
@@ -78,118 +112,81 @@ void EmbeddingWithEltwiseAddXpuKernel(
auto& id_dims = ids[0]->dims();
int batch_size = id_dims[0];
int max_seq_len_value = id_dims[1];
int ids_len = id_dims[0] * id_dims[1];
std::vector<std::vector<int>> int_ids(emb_layer_num,
std::vector<int>(ids_len, 0));
std::vector<xpu::VectorParam<int>> arg_ids;
auto* mask_tensor = mask.get_ptr();
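// If a mask is provided, derive per-batch valid sequence lengths from it and
// expose them through seq_lod and max_seq_len.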
if (mask_tensor != nullptr) {
auto mask_dtype = mask_tensor->dtype();
PADDLE_ENFORCE(
mask_dtype == phi::DataType::INT64 ||
mask_dtype == phi::DataType::FLOAT32,
errors::InvalidArgument(
"The data type of mask should be int64 or float32, but got %s.",
DataTypeToString(mask_dtype)));
max_seq_len->Resize({1});
ctx.template HostAlloc<int>(max_seq_len)[0] = max_seq_len_value;
seq_lod->Resize({batch_size + 1});
int* seq_lod_data = ctx.template HostAlloc<int>(seq_lod);
seq_lod_data[0] = 0;
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
int cur_batch_seq_len = 0;
for (int seq_idx = 0; seq_idx < max_seq_len_value; seq_idx++) {
int mask_idx = batch_idx * max_seq_len_value + seq_idx;
if ((mask_dtype == phi::DataType::INT64 &&
mask->data<int64_t>()[mask_idx] > 0) ||
(mask_dtype == phi::DataType::FLOAT32 &&
fabs(mask->data<float>()[mask_idx]) > 1e-5)) {
cur_batch_seq_len++;
} else {
break;
}
}
PADDLE_ENFORCE_GT(
cur_batch_seq_len,
0,
errors::PreconditionNotMet(
"cur_batch_seq_len should be greater than 0, but got %d.",
cur_batch_seq_len));
seq_lod_data[batch_idx + 1] = seq_lod_data[batch_idx] + cur_batch_seq_len;
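// Dispatch on the mask dtype and let FillSeqLod fill seq_lod_data[1..batch_size]
// with the prefix sum of per-batch sequence lengths.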
switch (mask_tensor->dtype()) {
case DataType::FLOAT32:
FillSeqLod(batch_size,
max_seq_len_value,
mask_tensor->data<float>(),
seq_lod_data);
break;
case DataType::INT64:
FillSeqLod(batch_size,
max_seq_len_value,
mask_tensor->data<int64_t>(),
seq_lod_data);
break;
default:
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported mask data type %s; only int64 and float32 are supported.",
DataTypeToString(mask_tensor->dtype())));
break;
}
out->Resize({batch_size, seq_lod_data[batch_size], emb_dim});
}
for (int i = 0; i < emb_layer_num; i++) {
if (ids[i]->dtype() == DataType::INT64) {
auto* ids_data = ids[i]->data<int64_t>();
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
for (int j = 0;
j < seq_lod_data[batch_idx + 1] - seq_lod_data[batch_idx];
j++) {
int_ids[i][seq_lod_data[batch_idx] + j] =
ids_data[batch_idx * max_seq_len_value + j];
}
}
} else {
auto* ids_data = ids[i]->data<int>();
for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
for (int j = 0;
j < seq_lod_data[batch_idx + 1] - seq_lod_data[batch_idx];
j++) {
int_ids[i][seq_lod_data[batch_idx] + j] =
ids_data[batch_idx * max_seq_len_value + j];
}
}
}
arg_ids.push_back(
xpu::VectorParam<int>{int_ids[i].data(), ids_len, nullptr});
}
} else {
for (int i = 0; i < emb_layer_num; i++) {
for (int j = 0; j < ids_len; j++) {
if (ids[i]->dtype() == phi::DataType::INT64) {
int_ids[i][j] = static_cast<int>(ids[i]->data<int64_t>()[j]);
} else if (ids[i]->dtype() == phi::DataType::INT32) {
int_ids[i][j] = ids[i]->data<int>()[j];
}
}
arg_ids.push_back(
xpu::VectorParam<int>{int_ids[i].data(), ids_len, nullptr});
}
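// ids are passed to the XPU op directly in their original dtype TID; no
// host-side copy into intermediate int buffers is required.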
int ids_len = id_dims[0] * id_dims[1];
std::vector<xpu::VectorParam<TID>> arg_ids;
for (int i = 0; i < emb_layer_num; i++) {
arg_ids.push_back(
xpu::VectorParam<TID>{ids[i]->data<TID>(), ids_len, nullptr});
}
ctx.template Alloc<T>(out);
if (std::is_same<T, phi::dtype::float16>::value) {
DenseTensor out_fp32_t;
ctx.template Alloc<float>(&out_fp32_t, out->numel() * sizeof(float));
int r = xpu::multi_embedding_fusion<float, float, int>(
ctx.x_context(),
arg_tables, /* tables */
out_fp32_t.data<float>(),
arg_ids,
table_lens,
emb_dim,
std::vector<float>(table_lens.size(), 1.0f),
std::vector<int>(table_lens.size(), padding_idx));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_with_eltwise_add_xpu");
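// A single fused XPU call performs the table lookups and the element-wise add
// across all embedding layers, writing directly into the output tensor.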
int r = xpu::multi_embedding_fusion<XPUType, XPUType, TID>(
ctx.x_context(),
arg_tables,
reinterpret_cast<XPUType*>(ctx.template Alloc<TT>(out)),
arg_ids,
table_lens,
emb_dim,
std::vector<float>(table_lens.size(), 1.0f),
std::vector<TID>(table_lens.size(), padding_idx));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_with_eltwise_add_xpu");
}
} // namespace
r = xpu::cast(ctx.x_context(),
out_fp32_t.data<float>(),
reinterpret_cast<XPUTypeFP16*>(out->data<T>()),
out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
} else {
int r = xpu::multi_embedding_fusion<float, float, int>(
ctx.x_context(),
arg_tables, /* tables */
out->data<float>(),
arg_ids,
table_lens,
emb_dim,
std::vector<float>(table_lens.size(), 1.0f),
std::vector<int>(table_lens.size(), padding_idx));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_with_eltwise_add_xpu");
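// Public kernel entry: dispatches to MultiEmbeddingKernel according to the
// ids dtype (int32 or int64).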
template <typename T, typename Context>
void EmbeddingWithEltwiseAddXpuKernel(
const Context& ctx,
const std::vector<const DenseTensor*>& ids,
const std::vector<const DenseTensor*>& tables,
const paddle::optional<DenseTensor>& mask,
int64_t padding_idx,
DenseTensor* out,
DenseTensor* seq_lod,
DenseTensor* max_seq_len) {
switch (ids[0]->dtype()) {
case DataType::INT32:
MultiEmbeddingKernel<T, int, Context>(
ctx, ids, tables, mask, padding_idx, out, seq_lod, max_seq_len);
break;
case DataType::INT64:
MultiEmbeddingKernel<T, int64_t, Context>(
ctx, ids, tables, mask, padding_idx, out, seq_lod, max_seq_len);
break;
default:
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported ids data type %s; only int64 and int32 are supported.",
DataTypeToString(ids[0]->dtype())));
break;
}
}