// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/phi/kernels/embedding_kernel.h" #include "paddle/phi/backends/xpu/enforce_xpu.h" #include "paddle/phi/core/kernel_registry.h" namespace phi { template void EmbeddingKernel(const Context &ctx, const DenseTensor &inputx, const DenseTensor &weight, int64_t padding_idx, DenseTensor *out) { auto *ids_t = &inputx; // int auto *output_t = out; // float PADDLE_ENFORCE_EQ( (std::is_same::value), true, phi::errors::PreconditionNotMet("Unsupported place! only support " "xpu place , please check your " "place.")); int64_t ids_numel = ids_t->numel(); auto *table_t = &weight; auto &dev_ctx = ctx; auto *table = table_t->data(); auto *output = dev_ctx.template Alloc(output_t); xpu::ctx_guard RAII_GUARD(ctx.x_context()); const int64_t *ids; if (ids_t->dtype() == phi::DataType::INT64) { ids = ids_t->data(); } else { int64_t *ids_tt = RAII_GUARD.alloc_l3_or_gm(ids_t->numel()); int r = xpu::cast( ctx.x_context(), ids_t->data(), ids_tt, ids_t->numel()); PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast"); ids = reinterpret_cast(ids_tt); } PADDLE_ENFORCE_EQ( ids_numel <= std::numeric_limits::max(), true, phi::errors::OutOfRange( "Number of ids greater than int32_t::max , please check " "number of ids in LookupTableV2XPUKernel.")); int ym = static_cast(ids_numel); size_t xm = table_t->dims()[0]; size_t n = table_t->dims()[1]; int r = xpu::embedding(dev_ctx.x_context(), table, ids, output, xm, n, ym, static_cast(padding_idx)); PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding"); } } // namespace phi PD_REGISTER_KERNEL(embedding, XPU, ALL_LAYOUT, phi::EmbeddingKernel, float) {}