diff --git a/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc b/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc
index de704bec0f7e5db5182fda3b77dabcf474e46242..c92497b5b7ac012c09a6c396927dfbc105e1b478 100644
--- a/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc
+++ b/paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc
@@ -19,20 +19,67 @@
 namespace phi {
 namespace fusion {
 
-template <typename T, typename Context>
-void EmbeddingWithEltwiseAddXpuKernel(
-    const Context& ctx,
-    const std::vector<const DenseTensor*>& ids,
-    const std::vector<const DenseTensor*>& tables,
-    const paddle::optional<DenseTensor>& mask,
-    int64_t padding_idx,
-    DenseTensor* out,
-    DenseTensor* seq_lod,
-    DenseTensor* max_seq_len) {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-  int emb_dim = tables[0]->dims()[1];
-  std::vector<int64_t> table_lens;
-  std::vector<const float*> arg_tables;
+namespace {
+template <typename T>
+void FillSeqLod(int batch_size, int max_seq_len, const T* mask, int* seq_lod) {
+  for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
+    int cur_batch_seq_len = 0;
+    for (int seq_idx = 0; seq_idx < max_seq_len; seq_idx++) {
+      int mask_idx = batch_idx * max_seq_len + seq_idx;
+      if (mask[mask_idx] > 0) {
+        cur_batch_seq_len++;
+      } else {
+        break;
+      }
+    }
+    PADDLE_ENFORCE_GT(
+        cur_batch_seq_len,
+        0,
+        errors::PreconditionNotMet(
+            "cur_batch_seq_len should be greater than 0, but got %d.",
+            cur_batch_seq_len));
+    seq_lod[batch_idx + 1] = seq_lod[batch_idx] + cur_batch_seq_len;
+  }
+}
+
+template <>
+void FillSeqLod(int batch_size,
+                int max_seq_len,
+                const float* mask,
+                int* seq_lod) {
+  for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
+    int cur_batch_seq_len = 0;
+    for (int seq_idx = 0; seq_idx < max_seq_len; seq_idx++) {
+      int mask_idx = batch_idx * max_seq_len + seq_idx;
+      if (fabs(mask[mask_idx]) > 1e-5) {
+        cur_batch_seq_len++;
+      } else {
+        break;
+      }
+    }
+    PADDLE_ENFORCE_GT(
+        cur_batch_seq_len,
+        0,
+        errors::PreconditionNotMet(
+            "cur_batch_seq_len should be greater than 0, but got %d.",
+            cur_batch_seq_len));
+    seq_lod[batch_idx + 1] = seq_lod[batch_idx] + cur_batch_seq_len;
+  }
+}
+
+template <typename T, typename TID, typename Context>
+void MultiEmbeddingKernel(const Context& ctx,
+                          const std::vector<const DenseTensor*>& ids,
+                          const std::vector<const DenseTensor*>& tables,
+                          const paddle::optional<DenseTensor>& mask,
+                          int64_t padding_idx,
+                          DenseTensor* out,
+                          DenseTensor* seq_lod,
+                          DenseTensor* max_seq_len) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  int64_t emb_dim = tables[0]->dims()[1];
+  std::vector<int64_t> table_lens;
+  std::vector<const XPUType*> arg_tables;
   for (auto* table : tables) {
     auto& table_dims = table->dims();
     PADDLE_ENFORCE_EQ(
@@ -49,20 +96,7 @@ void EmbeddingWithEltwiseAddXpuKernel(
             table_dims[1],
             emb_dim));
     table_lens.push_back(table_dims[0]);
-    if (std::is_same<T, phi::dtype::float16>::value) {
-      DenseTensor table_data_fp32_t;
-      ctx.template Alloc<float>(&table_data_fp32_t,
-                                table->numel() * sizeof(float));
-      int r = xpu::cast(
-          ctx.x_context(),
-          reinterpret_cast<const XPUType*>(table->data<T>()),
-          table_data_fp32_t.data<float>(),
-          table->numel());
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
-      arg_tables.push_back(table_data_fp32_t.data<float>());
-    } else {
-      arg_tables.push_back(table->data<float>());
-    }
+    arg_tables.push_back(reinterpret_cast<const XPUType*>(table->data<T>()));
   }
 
   int emb_layer_num = ids.size();
@@ -78,118 +112,81 @@ void EmbeddingWithEltwiseAddXpuKernel(
   auto& id_dims = ids[0]->dims();
   int batch_size = id_dims[0];
   int max_seq_len_value = id_dims[1];
-  int ids_len = id_dims[0] * id_dims[1];
-  std::vector<std::vector<int>> int_ids(emb_layer_num,
-                                        std::vector<int>(ids_len, 0));
-  std::vector<xpu::VectorParam<int>> arg_ids;
   auto* mask_tensor = mask.get_ptr();
   if (mask_tensor != nullptr) {
-    auto mask_dtype = mask_tensor->dtype();
-    PADDLE_ENFORCE(
-        mask_dtype == phi::DataType::INT64 ||
-            mask_dtype == phi::DataType::FLOAT32,
-        errors::InvalidArgument(
-            "The data type of mask should be int64 or float32, but got %s.",
-            DataTypeToString(mask_dtype)));
     max_seq_len->Resize({1});
     ctx.template HostAlloc<int>(max_seq_len)[0] = max_seq_len_value;
     seq_lod->Resize({batch_size + 1});
     int* seq_lod_data = ctx.template HostAlloc<int>(seq_lod);
     seq_lod_data[0] = 0;
-    for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
-      int cur_batch_seq_len = 0;
-      for (int seq_idx = 0; seq_idx < max_seq_len_value; seq_idx++) {
-        int mask_idx = batch_idx * max_seq_len_value + seq_idx;
-        if ((mask_dtype == phi::DataType::INT64 &&
-             mask->data<int64_t>()[mask_idx] > 0) ||
-            (mask_dtype == phi::DataType::FLOAT32 &&
-             fabs(mask->data<float>()[mask_idx]) > 1e-5)) {
-          cur_batch_seq_len++;
-        } else {
-          break;
-        }
-      }
-      PADDLE_ENFORCE_GT(
-          cur_batch_seq_len,
-          0,
-          errors::PreconditionNotMet(
-              "cur_batch_seq_len should be greater than 0, but got %d.",
-              cur_batch_seq_len));
-      seq_lod_data[batch_idx + 1] = seq_lod_data[batch_idx] + cur_batch_seq_len;
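+    // Fill seq_lod with prefix sums of each batch's valid sequence length,
+    // taken from the leading non-zero entries of its mask row.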
+    switch (mask_tensor->dtype()) {
+      case DataType::FLOAT32:
+        FillSeqLod(batch_size,
+                   max_seq_len_value,
+                   mask_tensor->data<float>(),
+                   seq_lod_data);
+        break;
+      case DataType::INT64:
+        FillSeqLod(batch_size,
+                   max_seq_len_value,
+                   mask_tensor->data<int64_t>(),
+                   seq_lod_data);
+        break;
+      default:
+        PADDLE_THROW(phi::errors::Unimplemented(
+            "The data type of mask should be int64 or float32, but got %s.",
+            DataTypeToString(mask_tensor->dtype())));
+        break;
     }
     out->Resize({batch_size, seq_lod_data[batch_size], emb_dim});
+  }
 
-    for (int i = 0; i < emb_layer_num; i++) {
-      if (ids[i]->dtype() == DataType::INT64) {
-        auto* ids_data = ids[i]->data<int64_t>();
-        for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
-          for (int j = 0;
-               j < seq_lod_data[batch_idx + 1] - seq_lod_data[batch_idx];
-               j++) {
-            int_ids[i][seq_lod_data[batch_idx] + j] =
-                ids_data[batch_idx * max_seq_len_value + j];
-          }
-        }
-      } else {
-        auto* ids_data = ids[i]->data<int>();
-        for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) {
-          for (int j = 0;
-               j < seq_lod_data[batch_idx + 1] - seq_lod_data[batch_idx];
-               j++) {
-            int_ids[i][seq_lod_data[batch_idx] + j] =
-                ids_data[batch_idx * max_seq_len_value + j];
-          }
-        }
-      }
-      arg_ids.push_back(
-          xpu::VectorParam<int>{int_ids[i].data(), ids_len, nullptr});
-    }
-  } else {
-    for (int i = 0; i < emb_layer_num; i++) {
-      for (int j = 0; j < ids_len; j++) {
-        if (ids[i]->dtype() == phi::DataType::INT64) {
-          int_ids[i][j] = static_cast<int>(ids[i]->data<int64_t>()[j]);
-        } else if (ids[i]->dtype() == phi::DataType::INT32) {
-          int_ids[i][j] = ids[i]->data<int>()[j];
-        }
-      }
-      arg_ids.push_back(
-          xpu::VectorParam<int>{int_ids[i].data(), ids_len, nullptr});
-    }
+  int ids_len = id_dims[0] * id_dims[1];
+  std::vector<xpu::VectorParam<TID>> arg_ids;
+  for (int i = 0; i < emb_layer_num; i++) {
+    arg_ids.push_back(
+        xpu::VectorParam<TID>{ids[i]->data<TID>(), ids_len, nullptr});
   }
 
-  ctx.template Alloc<T>(out);
-  if (std::is_same<T, phi::dtype::float16>::value) {
-    DenseTensor out_fp32_t;
-    ctx.template Alloc<float>(&out_fp32_t, out->numel() * sizeof(float));
-    int r = xpu::multi_embedding_fusion(
-        ctx.x_context(),
-        arg_tables, /* tables */
-        out_fp32_t.data<float>(),
-        arg_ids,
-        table_lens,
-        emb_dim,
-        std::vector<float>(table_lens.size(), 1.0f),
-        std::vector<int>(table_lens.size(), padding_idx));
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_with_eltwise_add_xpu");
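+  // Single fused XPU call: look up each ids[i] in tables[i] and accumulate
+  // the embeddings elementwise into out, with no intermediate fp32 buffers.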
+  int r = xpu::multi_embedding_fusion(
+      ctx.x_context(),
+      arg_tables,
+      reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out)),
+      arg_ids,
+      table_lens,
+      emb_dim,
+      std::vector<float>(table_lens.size(), 1.0f),
+      std::vector<int>(table_lens.size(), padding_idx));
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_with_eltwise_add_xpu");
+}
+}  // namespace
 
-    r = xpu::cast(ctx.x_context(),
-                  out_fp32_t.data<float>(),
-                  reinterpret_cast<XPUType*>(out->data<T>()),
-                  out->numel());
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast");
-  } else {
-    int r = xpu::multi_embedding_fusion(
-        ctx.x_context(),
-        arg_tables, /* tables */
-        out->data<float>(),
-        arg_ids,
-        table_lens,
-        emb_dim,
-        std::vector<float>(table_lens.size(), 1.0f),
-        std::vector<int>(table_lens.size(), padding_idx));
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_with_eltwise_add_xpu");
+template <typename T, typename Context>
+void EmbeddingWithEltwiseAddXpuKernel(
+    const Context& ctx,
+    const std::vector<const DenseTensor*>& ids,
+    const std::vector<const DenseTensor*>& tables,
+    const paddle::optional<DenseTensor>& mask,
+    int64_t padding_idx,
+    DenseTensor* out,
+    DenseTensor* seq_lod,
+    DenseTensor* max_seq_len) {
+  switch (ids[0]->dtype()) {
+    case DataType::INT32:
+      MultiEmbeddingKernel<T, int, Context>(
+          ctx, ids, tables, mask, padding_idx, out, seq_lod, max_seq_len);
+      break;
+    case DataType::INT64:
+      MultiEmbeddingKernel<T, int64_t, Context>(
+          ctx, ids, tables, mask, padding_idx, out, seq_lod, max_seq_len);
+      break;
+    default:
+      PADDLE_THROW(phi::errors::Unimplemented(
+          "The data type of ids should be int32 or int64, but got %s.",
+          DataTypeToString(ids[0]->dtype())));
+      break;
   }
 }