Unverified commit 95aab366, authored by zhupengyang, committed by GitHub

fix embedding_with_eltwise_add_xpu (#55354)

Parent: fce4c2de
@@ -19,9 +19,56 @@
 namespace phi {
 namespace fusion {
-template <typename T, typename Context>
-void EmbeddingWithEltwiseAddXpuKernel(
-    const Context& ctx,
-    const std::vector<const DenseTensor*>& ids,
-    const std::vector<const DenseTensor*>& tables,
-    const paddle::optional<DenseTensor>& mask,
+namespace {
+
+template <typename T>
+void FillSeqLod(int batch_size, int max_seq_len, const T* mask, int* seq_lod) {
+  for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
+    int cur_batch_seq_len = 0;
+    for (int seq_idx = 0; seq_idx < max_seq_len; seq_idx++) {
+      int mask_idx = batch_idx * max_seq_len + seq_idx;
+      if (mask[mask_idx] > 0) {
+        cur_batch_seq_len++;
+      } else {
+        break;
+      }
+    }
+    PADDLE_ENFORCE_GT(
+        cur_batch_seq_len,
+        0,
+        errors::PreconditionNotMet(
+            "cur_batch_seq_len should be greater than 0, but got %d.",
+            cur_batch_seq_len));
+    seq_lod[batch_idx + 1] = seq_lod[batch_idx] + cur_batch_seq_len;
+  }
+}
+
+template <>
+void FillSeqLod<float>(int batch_size,
+                       int max_seq_len,
+                       const float* mask,
+                       int* seq_lod) {
+  for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
+    int cur_batch_seq_len = 0;
+    for (int seq_idx = 0; seq_idx < max_seq_len; seq_idx++) {
+      int mask_idx = batch_idx * max_seq_len + seq_idx;
+      if (fabs(mask[mask_idx]) > 1e-5) {
+        cur_batch_seq_len++;
+      } else {
+        break;
+      }
+    }
+    PADDLE_ENFORCE_GT(
+        cur_batch_seq_len,
+        0,
+        errors::PreconditionNotMet(
+            "cur_batch_seq_len should be greater than 0, but got %d.",
+            cur_batch_seq_len));
+    seq_lod[batch_idx + 1] = seq_lod[batch_idx] + cur_batch_seq_len;
+  }
+}
+
+template <typename TT, typename TID, typename Context>
+void MultiEmbeddingKernel(const Context& ctx,
+                          const std::vector<const DenseTensor*>& ids,
+                          const std::vector<const DenseTensor*>& tables,
+                          const paddle::optional<DenseTensor>& mask,
@@ -29,10 +76,10 @@ void EmbeddingWithEltwiseAddXpuKernel(
-    DenseTensor* out,
-    DenseTensor* seq_lod,
-    DenseTensor* max_seq_len) {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-  int emb_dim = tables[0]->dims()[1];
-  std::vector<int> table_lens;
-  std::vector<const float*> arg_tables;
+                          DenseTensor* out,
+                          DenseTensor* seq_lod,
+                          DenseTensor* max_seq_len) {
+  using XPUType = typename XPUTypeTrait<TT>::Type;
+  int64_t emb_dim = tables[0]->dims()[1];
+  std::vector<TID> table_lens;
+  std::vector<const XPUType*> arg_tables;
   for (auto* table : tables) {
     auto& table_dims = table->dims();
     PADDLE_ENFORCE_EQ(
@@ -49,20 +96,7 @@ void EmbeddingWithEltwiseAddXpuKernel(
             table_dims[1],
             emb_dim));
     table_lens.push_back(table_dims[0]);
-    if (std::is_same<T, phi::dtype::float16>::value) {
-      DenseTensor table_data_fp32_t;
-      ctx.template Alloc<float>(&table_data_fp32_t,
-                                table->numel() * sizeof(float));
-      int r = xpu::cast<XPUType, float>(
-          ctx.x_context(),
-          reinterpret_cast<const XPUType*>(table->data<T>()),
-          table_data_fp32_t.data<float>(),
-          table->numel());
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
-      arg_tables.push_back(table_data_fp32_t.data<float>());
-    } else {
-      arg_tables.push_back(table->data<float>());
-    }
+    arg_tables.push_back(reinterpret_cast<const XPUType*>(table->data<TT>()));
   }
   int emb_layer_num = ids.size();
@@ -78,118 +112,81 @@
   auto& id_dims = ids[0]->dims();
   int batch_size = id_dims[0];
   int max_seq_len_value = id_dims[1];
-  int ids_len = id_dims[0] * id_dims[1];
-  std::vector<std::vector<int>> int_ids(emb_layer_num,
-                                        std::vector<int>(ids_len, 0));
-  std::vector<xpu::VectorParam<int>> arg_ids;
   auto* mask_tensor = mask.get_ptr();
   if (mask_tensor != nullptr) {
-    auto mask_dtype = mask_tensor->dtype();
-    PADDLE_ENFORCE(
-        mask_dtype == phi::DataType::INT64 ||
-            mask_dtype == phi::DataType::FLOAT32,
-        errors::InvalidArgument(
-            "The data type of mask should be int64 or float32, but got %s.",
-            DataTypeToString(mask_dtype)));
     max_seq_len->Resize({1});
     ctx.template HostAlloc<int>(max_seq_len)[0] = max_seq_len_value;
     seq_lod->Resize({batch_size + 1});
     int* seq_lod_data = ctx.template HostAlloc<int>(seq_lod);
     seq_lod_data[0] = 0;
-    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
-      int cur_batch_seq_len = 0;
-      for (int seq_idx = 0; seq_idx < max_seq_len_value; seq_idx++) {
-        int mask_idx = batch_idx * max_seq_len_value + seq_idx;
-        if ((mask_dtype == phi::DataType::INT64 &&
-             mask->data<int64_t>()[mask_idx] > 0) ||
-            (mask_dtype == phi::DataType::FLOAT32 &&
-             fabs(mask->data<float>()[mask_idx]) > 1e-5)) {
-          cur_batch_seq_len++;
-        } else {
-          break;
-        }
-      }
-      PADDLE_ENFORCE_GT(
-          cur_batch_seq_len,
-          0,
-          errors::PreconditionNotMet(
-              "cur_batch_seq_len should be greater than 0, but got %d.",
-              cur_batch_seq_len));
-      seq_lod_data[batch_idx + 1] = seq_lod_data[batch_idx] + cur_batch_seq_len;
-    }
+    switch (mask_tensor->dtype()) {
+      case DataType::FLOAT32:
+        FillSeqLod(batch_size,
+                   max_seq_len_value,
+                   mask_tensor->data<float>(),
+                   seq_lod_data);
+        break;
+      case DataType::INT64:
+        FillSeqLod(batch_size,
+                   max_seq_len_value,
+                   mask_tensor->data<int64_t>(),
+                   seq_lod_data);
+        break;
+      default:
+        PADDLE_THROW(
+            phi::errors::Unimplemented("Only support mask data type is int64 "
+                                       "or float, not support %s now.",
+                                       DataTypeToString(mask_tensor->dtype())));
+        break;
+    }
     out->Resize({batch_size, seq_lod_data[batch_size], emb_dim});
-    for (int i = 0; i < emb_layer_num; i++) {
-      if (ids[i]->dtype() == DataType::INT64) {
-        auto* ids_data = ids[i]->data<int64_t>();
-        for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
-          for (int j = 0;
-               j < seq_lod_data[batch_idx + 1] - seq_lod_data[batch_idx];
-               j++) {
-            int_ids[i][seq_lod_data[batch_idx] + j] =
-                ids_data[batch_idx * max_seq_len_value + j];
-          }
-        }
-      } else {
-        auto* ids_data = ids[i]->data<int>();
-        for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
-          for (int j = 0;
-               j < seq_lod_data[batch_idx + 1] - seq_lod_data[batch_idx];
-               j++) {
-            int_ids[i][seq_lod_data[batch_idx] + j] =
-                ids_data[batch_idx * max_seq_len_value + j];
-          }
-        }
-      }
-      arg_ids.push_back(
-          xpu::VectorParam<int>{int_ids[i].data(), ids_len, nullptr});
-    }
-  } else {
-    for (int i = 0; i < emb_layer_num; i++) {
-      for (int j = 0; j < ids_len; j++) {
-        if (ids[i]->dtype() == phi::DataType::INT64) {
-          int_ids[i][j] = static_cast<int>(ids[i]->data<int64_t>()[j]);
-        } else if (ids[i]->dtype() == phi::DataType::INT32) {
-          int_ids[i][j] = ids[i]->data<int>()[j];
-        }
-      }
-      arg_ids.push_back(
-          xpu::VectorParam<int>{int_ids[i].data(), ids_len, nullptr});
-    }
   }
-  ctx.template Alloc<T>(out);
-  if (std::is_same<T, phi::dtype::float16>::value) {
-    DenseTensor out_fp32_t;
-    ctx.template Alloc<float>(&out_fp32_t, out->numel() * sizeof(float));
-    int r = xpu::multi_embedding_fusion<float, float, int>(
-        ctx.x_context(),
-        arg_tables, /* tables */
-        out_fp32_t.data<float>(),
-        arg_ids,
-        table_lens,
-        emb_dim,
-        std::vector<float>(table_lens.size(), 1.0f),
-        std::vector<int>(table_lens.size(), padding_idx));
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_with_eltwise_add_xpu");
-    r = xpu::cast(ctx.x_context(),
-                  out_fp32_t.data<float>(),
-                  reinterpret_cast<XPUTypeFP16*>(out->data<T>()),
-                  out->numel());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
-  } else {
-    int r = xpu::multi_embedding_fusion<float, float, int>(
-        ctx.x_context(),
-        arg_tables, /* tables */
-        out->data<float>(),
-        arg_ids,
-        table_lens,
-        emb_dim,
-        std::vector<float>(table_lens.size(), 1.0f),
-        std::vector<int>(table_lens.size(), padding_idx));
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_with_eltwise_add_xpu");
+
+  int ids_len = id_dims[0] * id_dims[1];
+  std::vector<xpu::VectorParam<TID>> arg_ids;
+  for (int i = 0; i < emb_layer_num; i++) {
+    arg_ids.push_back(
+        xpu::VectorParam<TID>{ids[i]->data<TID>(), ids_len, nullptr});
+  }
+
+  int r = xpu::multi_embedding_fusion<XPUType, XPUType, TID>(
+      ctx.x_context(),
+      arg_tables,
+      reinterpret_cast<XPUType*>(ctx.template Alloc<TT>(out)),
+      arg_ids,
+      table_lens,
+      emb_dim,
+      std::vector<float>(table_lens.size(), 1.0f),
+      std::vector<TID>(table_lens.size(), padding_idx));
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_with_eltwise_add_xpu");
+}
+}  // namespace
+
+template <typename T, typename Context>
+void EmbeddingWithEltwiseAddXpuKernel(
+    const Context& ctx,
+    const std::vector<const DenseTensor*>& ids,
+    const std::vector<const DenseTensor*>& tables,
+    const paddle::optional<DenseTensor>& mask,
+    int64_t padding_idx,
+    DenseTensor* out,
+    DenseTensor* seq_lod,
+    DenseTensor* max_seq_len) {
+  switch (ids[0]->dtype()) {
+    case DataType::INT32:
+      MultiEmbeddingKernel<T, int, Context>(
+          ctx, ids, tables, mask, padding_idx, out, seq_lod, max_seq_len);
+      break;
+    case DataType::INT64:
+      MultiEmbeddingKernel<T, int64_t, Context>(
+          ctx, ids, tables, mask, padding_idx, out, seq_lod, max_seq_len);
+      break;
+    default:
+      PADDLE_THROW(phi::errors::Unimplemented(
+          "Only support ids data type is int64 or int32, not support %s now.",
+          DataTypeToString(ids[0]->dtype())));
+      break;
   }
 }
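For context, the sketch below is a minimal, standalone illustration (not Paddle code) of what the `FillSeqLod` helper introduced by this commit computes: each row of the mask is scanned until the first padding element, and the per-batch valid lengths are accumulated into LoD-style offsets. The `main` driver, the `assert` in place of `PADDLE_ENFORCE_GT`, and the concrete mask values are assumptions for illustration only.

```cpp
// Standalone sketch of the seq_lod computation: a mask of shape
// [batch_size, max_seq_len] marks valid tokens with nonzero values,
// and seq_lod receives the cumulative per-batch valid lengths.
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

template <typename T>
void FillSeqLod(int batch_size, int max_seq_len, const T* mask, int* seq_lod) {
  for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
    int cur_batch_seq_len = 0;
    for (int seq_idx = 0; seq_idx < max_seq_len; seq_idx++) {
      if (mask[batch_idx * max_seq_len + seq_idx] > 0) {
        cur_batch_seq_len++;
      } else {
        break;  // padding starts here; the rest of the row is ignored
      }
    }
    assert(cur_batch_seq_len > 0);  // stands in for PADDLE_ENFORCE_GT
    seq_lod[batch_idx + 1] = seq_lod[batch_idx] + cur_batch_seq_len;
  }
}

int main() {
  // Two sequences padded to max_seq_len = 4, with valid lengths 3 and 2.
  const int batch_size = 2, max_seq_len = 4;
  std::vector<int64_t> mask = {1, 1, 1, 0,
                               1, 1, 0, 0};
  std::vector<int> seq_lod(batch_size + 1, 0);
  FillSeqLod(batch_size, max_seq_len, mask.data(), seq_lod.data());
  for (int v : seq_lod) std::printf("%d ", v);  // prints: 0 3 5
  return 0;
}
```

The resulting offsets ({0, 3, 5} in this example) correspond to what the kernel writes into `seq_lod_data`; the last entry is then used to resize `out` to {batch_size, seq_lod_data[batch_size], emb_dim}.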