From 7c9abfb2059a4b737f0038d5283a04e52a48859e Mon Sep 17 00:00:00 2001 From: Ruibin Cheung Date: Wed, 16 Aug 2023 15:42:27 +0800 Subject: [PATCH] [Fluid] NO.1 Migrate c_embedding to PHI (#56129) * [Fluid] Migrate c_embedding to PHI * fix * add python_api * fix ut * migrate xpu kernel * fix windows compile error --- .../operators/collective/c_embedding_op.cc | 15 - .../operators/collective/c_embedding_op.cu | 258 ------------------ .../operators/collective/c_embedding_op.h | 136 +-------- .../collective/c_embedding_op_xpu.cc | 119 +------- paddle/phi/kernels/c_embedding_grad_kernel.h | 28 ++ paddle/phi/kernels/c_embedding_kernel.h | 27 ++ .../kernels/cpu/c_embedding_grad_kernel.cc | 99 +++++++ paddle/phi/kernels/cpu/c_embedding_kernel.cc | 87 ++++++ .../kernels/gpu/c_embedding_grad_kernel.cu | 160 +++++++++++ paddle/phi/kernels/gpu/c_embedding_kernel.cu | 128 +++++++++ paddle/phi/kernels/xpu/c_embedding_kernel.cc | 72 +++++ .../kernels/xpu/c_embedding_kernel_grad.cc | 82 ++++++ paddle/phi/ops/compat/c_embedding_sig.cc | 30 ++ test/legacy_test/c_embedding_op_base.py | 13 +- 14 files changed, 728 insertions(+), 526 deletions(-) delete mode 100644 paddle/fluid/operators/collective/c_embedding_op.cu create mode 100644 paddle/phi/kernels/c_embedding_grad_kernel.h create mode 100644 paddle/phi/kernels/c_embedding_kernel.h create mode 100644 paddle/phi/kernels/cpu/c_embedding_grad_kernel.cc create mode 100644 paddle/phi/kernels/cpu/c_embedding_kernel.cc create mode 100644 paddle/phi/kernels/gpu/c_embedding_grad_kernel.cu create mode 100644 paddle/phi/kernels/gpu/c_embedding_kernel.cu create mode 100644 paddle/phi/kernels/xpu/c_embedding_kernel.cc create mode 100644 paddle/phi/kernels/xpu/c_embedding_kernel_grad.cc create mode 100644 paddle/phi/ops/compat/c_embedding_sig.cc diff --git a/paddle/fluid/operators/collective/c_embedding_op.cc b/paddle/fluid/operators/collective/c_embedding_op.cc index 7eb21ae9a46..637490e59b2 100644 --- a/paddle/fluid/operators/collective/c_embedding_op.cc +++ b/paddle/fluid/operators/collective/c_embedding_op.cc @@ -183,18 +183,3 @@ REGISTER_OPERATOR(c_embedding_grad, ops::CEmbeddingOpGrad, ops::CEmbeddingGradOpNoBufferVarsInferer, ops::CEmbeddingOpGradVarTypeInference); - -PD_REGISTER_STRUCT_KERNEL(c_embedding, - CPU, - ALL_LAYOUT, - ops::CEmbeddingOpCPUKernel, - float, - double, - plat::float16) {} -PD_REGISTER_STRUCT_KERNEL(c_embedding_grad, - CPU, - ALL_LAYOUT, - ops::CEmbeddingGradOpCPUKernel, - float, - double, - plat::float16) {} diff --git a/paddle/fluid/operators/collective/c_embedding_op.cu b/paddle/fluid/operators/collective/c_embedding_op.cu deleted file mode 100644 index 758734ada66..00000000000 --- a/paddle/fluid/operators/collective/c_embedding_op.cu +++ /dev/null @@ -1,258 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/collective/c_embedding_op.h" -#include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/float16.h" -#include "paddle/phi/backends/gpu/gpu_primitives.h" -#include "paddle/phi/kernels/funcs/embedding_grad.h" - -DECLARE_int64(embedding_deterministic); - -namespace paddle { -namespace operators { - -static constexpr int kNumCUDAThreads = 512; -static constexpr int kNumMaxinumNumBlocks = 4096; - -static inline int NumBlocks(const int N) { - return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, - kNumMaxinumNumBlocks); -} - -template -__global__ void CEmbedding(T *out, - const T *table, - const IndexT *ids, - const int rows, - const int columns, - const int64_t N, - const int64_t start_idx, - const int64_t end_idx, - const int64_t limit) { - CUDA_KERNEL_LOOP(i, limit) { - size_t row = i / columns; - size_t col = i % columns; - auto id = ids[row]; - - if (id >= start_idx && id < end_idx) { - auto real_idx = id - start_idx; - PADDLE_ENFORCE(real_idx < N, - "The index is out of bounds, " - "please check whether the dimensions of index and " - "input meet the requirements. It should " - "be less than [%d], but received [%d]", - N, - real_idx); - out[i] = table[real_idx * columns + col]; - } else { - out[i] = static_cast(0); - } - } -} - -template -__global__ void CEmbeddingGrad(T *table, - const T *output, - const IndexT *ids, - const int rows, - const int columns, - const int64_t N, - const int64_t start_idx, - const int64_t end_idx, - const int64_t limit) { - CUDA_KERNEL_LOOP(i, limit) { - size_t row = i / columns; - size_t col = i % columns; - auto id = ids[row]; - if (id >= start_idx && id < end_idx) { - auto real_idx = id - start_idx; - phi::CudaAtomicAdd(&table[real_idx * columns + col], output[i]); - } - } -} - -template -class CEmbeddingCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - auto *table_t = context.Input("W"); - auto *ids_t = context.Input("Ids"); - auto *output_t = context.Output("Out"); - - const auto &dev_ctx = context.template device_context(); - const int64_t start_idx = context.Attr("start_index"); - size_t N = table_t->dims()[0]; - size_t D = table_t->dims()[1]; - size_t K = ids_t->numel(); - - const int64_t end_idx = start_idx + N; - - auto *table = table_t->data(); - auto *output = output_t->mutable_data(context.GetPlace()); - - auto limit = K * D; - int blocks = NumBlocks(limit); - int threads = kNumCUDAThreads; - - const auto &index_type = framework::TransToProtoVarType(ids_t->dtype()); - if (index_type == framework::proto::VarType::INT32) { - CEmbedding - <<>>(output, - table, - ids_t->data(), - K, - D, - N, - start_idx, - end_idx, - limit); - - } else if (index_type == framework::proto::VarType::INT64) { - CEmbedding - <<>>(output, - table, - ids_t->data(), - K, - D, - N, - start_idx, - end_idx, - limit); - } else { - PADDLE_THROW(platform::errors::Unavailable( - "GPU c_embedding ids only support int32 or int64.")); - } - } -}; - -template -class CEmbeddingGradCUDAKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - const auto &dev_ctx = context.template device_context(); - const int64_t start_idx = context.Attr("start_index"); - auto ids_t = context.Input("Ids"); - auto d_output_t = - context.Input(framework::GradVarName("Out")); - auto d_table_t = - context.Output(framework::GradVarName("W")); - - int N = d_table_t->dims()[0]; - int D = d_table_t->dims()[1]; - int K = ids_t->numel(); - - auto limit = K * D; - int blocks = NumBlocks(limit); - int threads = kNumCUDAThreads; - - const T *d_output = d_output_t->data(); - T *d_table = d_table_t->mutable_data(context.GetPlace()); - - auto t = framework::EigenVector::Flatten(*d_table_t); - t.device(*dev_ctx.eigen_device()) = t.constant(static_cast(0)); - - const auto &index_type = framework::TransToProtoVarType(ids_t->dtype()); - if (FLAGS_embedding_deterministic == 1) { - if (index_type == framework::proto::VarType::INT32) { - phi::funcs::LaunchEmbeddingGradDeterministicKernel( - dev_ctx, - ids_t->data(), - d_output, - d_table, - N, - D, - K, - start_idx); - return; - } else if (index_type == framework::proto::VarType::INT64) { - phi::funcs::LaunchEmbeddingGradDeterministicKernel( - dev_ctx, - ids_t->data(), - d_output, - d_table, - N, - D, - K, - start_idx); - return; - } - } else { - if (FLAGS_embedding_deterministic > 1) { - VLOG(2) << "Run grad kernel of embedding with single thread."; - blocks = 1; - } - const int64_t end_idx = start_idx + N; - if (index_type == framework::proto::VarType::INT32) { - CEmbeddingGrad - <<>>(d_table, - d_output, - ids_t->data(), - K, - D, - N, - start_idx, - end_idx, - limit); - return; - } else if (index_type == framework::proto::VarType::INT64) { - CEmbeddingGrad - <<>>(d_table, - d_output, - ids_t->data(), - K, - D, - N, - start_idx, - end_idx, - limit); - return; - } - } - PADDLE_THROW(phi::errors::InvalidArgument( - "The data type of Input(Ids) must be int32 or int64.")); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -PD_REGISTER_STRUCT_KERNEL(c_embedding, - GPU, - ALL_LAYOUT, - ops::CEmbeddingCUDAKernel, - float, - double, -#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000 - plat::bfloat16, -#endif - plat::float16) { -} - -PD_REGISTER_STRUCT_KERNEL(c_embedding_grad, - GPU, - ALL_LAYOUT, - ops::CEmbeddingGradCUDAKernel, - float, - double, -#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000 - plat::bfloat16, -#endif - plat::float16) { -} diff --git a/paddle/fluid/operators/collective/c_embedding_op.h b/paddle/fluid/operators/collective/c_embedding_op.h index 4aae79f926c..dc278191152 100644 --- a/paddle/fluid/operators/collective/c_embedding_op.h +++ b/paddle/fluid/operators/collective/c_embedding_op.h @@ -25,148 +25,16 @@ limitations under the License. */ namespace paddle { namespace operators { -inline void CheckTableValid() {} - -template -void GetIdsEmbedding(const TIds* ids, - size_t ids_len, - int64_t start_idx, - const TData* table, - int64_t height, - int64_t width, - TData* out) { - for (size_t i = 0; i < ids_len; i++) { - TIds id = ids[i]; - int64_t local = id - start_idx; - - if (local >= 0 && local < height) { - // for (int64_t w = 0; w < width; w++) { - // out[i * width + w] = table[local * width + w]; - // } - - memcpy(out + i * width, table + local * width, width * sizeof(TData)); - } else { - memset(out + i * width, 0, width * sizeof(TData)); - } - } -} - template class CEmbeddingOpCPUKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* table_t = ctx.Input("W"); - auto* ids_t = ctx.Input("Ids"); - auto* output_t = ctx.Output("Out"); - const int64_t start_idx = ctx.Attr("start_index"); - - VLOG(10) << "table_dims:" << table_t->dims(); - - const T* table_data = table_t->data(); - T* output_data = output_t->mutable_data(ctx.GetPlace()); - - const int64_t height = table_t->dims()[0]; - const int64_t width = table_t->dims()[1]; - - const auto& index_type = framework::TransToProtoVarType(ids_t->dtype()); - if (index_type == framework::proto::VarType::INT32) { - GetIdsEmbedding(ids_t->data(), - ids_t->numel(), - start_idx, - table_data, - height, - width, - output_data); - } else if (index_type == framework::proto::VarType::INT64) { - GetIdsEmbedding(ids_t->data(), - ids_t->numel(), - start_idx, - table_data, - height, - width, - output_data); - } else { - PADDLE_THROW(platform::errors::Unavailable( - "CPU c_embedding ids only support int32 or int64.")); - } - } + void Compute(const framework::ExecutionContext& ctx) const override {} }; -template -void UpdateEmbedding(const TIds* ids, - size_t ids_len, - int64_t start_idx, - TData* table, - int64_t height, - int64_t width, - const TData* out) { - for (size_t i = 0; i < ids_len; i++) { - TIds id = ids[i]; - int64_t local = id - start_idx; - - if (local >= 0 && local < height) { - for (int64_t w = 0; w < width; w++) { - table[local * width + w] += out[i * width + w]; - } - } - } -} - template class CEmbeddingGradOpCPUKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { - const int64_t start_idx = context.Attr("start_index"); - auto ids_t = context.Input("Ids"); - auto d_output_t = - context.Input(framework::GradVarName("Out")); - auto table_t = context.Input("W"); - auto table_grad_t = - context.Output(framework::GradVarName("W")); - - T* table_grad_data = - table_grad_t->mutable_data(table_t->dims(), context.GetPlace()); - - size_t table_t_mem_size = - table_t->numel() * phi::SizeOf(table_grad_t->dtype()); - size_t table_grad_t_mem_size = - table_grad_t->numel() * - framework::SizeOfType( - framework::TransToProtoVarType(table_grad_t->dtype())); - - VLOG(10) << "table_dims:" << table_t->dims() - << ", table_t memory_size:" << table_t_mem_size - << ", table_grad_t memory_size:" << table_grad_t_mem_size - << ", start_index:" << start_idx; - - memset(table_grad_data, 0, table_grad_t_mem_size); - const T* d_output_data = d_output_t->data(); - - const int64_t height = table_t->dims()[0]; - const int64_t width = table_t->dims()[1]; - - const auto& index_type = framework::TransToProtoVarType(ids_t->dtype()); - if (index_type == framework::proto::VarType::INT32) { - UpdateEmbedding(ids_t->data(), - ids_t->numel(), - start_idx, - table_grad_data, - height, - width, - d_output_data); - } else if (index_type == framework::proto::VarType::INT64) { - UpdateEmbedding(ids_t->data(), - ids_t->numel(), - start_idx, - table_grad_data, - height, - width, - d_output_data); - } else { - PADDLE_THROW(platform::errors::Unavailable( - "CPU c_embedding ids only support int32 or int64.")); - } - } + void Compute(const framework::ExecutionContext& context) const override {} }; } // namespace operators diff --git a/paddle/fluid/operators/collective/c_embedding_op_xpu.cc b/paddle/fluid/operators/collective/c_embedding_op_xpu.cc index b46a561532d..5bf8d0fc898 100644 --- a/paddle/fluid/operators/collective/c_embedding_op_xpu.cc +++ b/paddle/fluid/operators/collective/c_embedding_op_xpu.cc @@ -18,129 +18,14 @@ namespace operators { template class CEmbeddingOpXPUKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* table_t = ctx.Input("W"); - auto* ids_t = ctx.Input("Ids"); - auto* output_t = ctx.Output("Out"); - const int64_t start_index = ctx.Attr("start_index"); - const T* table_data = table_t->data(); - T* output_data = output_t->mutable_data(ctx.GetPlace()); - - const int64_t height = table_t->dims()[0]; - const int64_t width = table_t->dims()[1]; - - // int embedding(Context* ctx, const T* x, const TID* indices, T* y, int xm, - // int n, int ym, int padding_idx, TID start_index = 0); - - // xm: table height: number of entries of table. - // n: embedding dim: number of float value within single entry. - // ym: number of elements of input ids. - - auto& dev_ctx = ctx.template device_context(); - - const auto& index_type = framework::TransToProtoVarType(ids_t->dtype()); - if (index_type == framework::proto::VarType::INT32) { - int r = xpu::embedding(dev_ctx.x_context(), - table_data, - ids_t->data(), - output_data, - height, - width, - ids_t->numel(), - -1, - static_cast(start_index)); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding"); - } else if (index_type == framework::proto::VarType::INT64) { - int r = xpu::embedding(dev_ctx.x_context(), - table_data, - ids_t->data(), - output_data, - height, - width, - ids_t->numel(), - -1, - static_cast(start_index)); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding"); - } else { - PADDLE_THROW(platform::errors::Unavailable( - "XPU c_embedding ids only support int32 or int64.")); - } - } + void Compute(const framework::ExecutionContext& ctx) const override {} }; template class CEmbeddingGradOpXPUKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { - const int64_t start_idx = context.Attr("start_index"); - auto ids_t = context.Input("Ids"); - auto d_output_t = - context.Input(framework::GradVarName("Out")); - auto table_t = context.Input("W"); - auto table_grad_t = - context.Output(framework::GradVarName("W")); - - auto& dev_ctx = context.template device_context(); - table_grad_t->Resize(table_t->dims()); - dev_ctx.template Alloc(table_grad_t, table_t->dtype()); - T* table_grad_data = static_cast(table_grad_t->data()); - - size_t table_t_mem_size = - table_t->numel() * phi::SizeOf(table_grad_t->dtype()); - size_t table_grad_t_mem_size = - table_grad_t->numel() * - framework::SizeOfType( - framework::TransToProtoVarType(table_grad_t->dtype())); - - VLOG(10) << "table_dims:" << table_t->dims() - << ", table_t memory_size:" << table_t_mem_size - << ", table_grad_t memory_size:" << table_grad_t_mem_size - << ", start_index:" << start_idx; - - int r = xpu::constant( - dev_ctx.x_context(), table_grad_data, table_grad_t->numel(), (T)0); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); - const T* d_output_data = d_output_t->data(); - - const int64_t height = table_t->dims()[0]; - const int64_t width = table_t->dims()[1]; - - const auto& index_type = framework::TransToProtoVarType(ids_t->dtype()); - if (index_type == framework::proto::VarType::INT32) { - r = xpu::embedding_grad(dev_ctx.x_context(), - d_output_data, - ids_t->data(), - table_grad_data, - height, - width, - ids_t->numel(), - -1, - static_cast(start_idx)); - } else if (index_type == framework::proto::VarType::INT64) { - r = xpu::embedding_grad(dev_ctx.x_context(), - d_output_data, - ids_t->data(), - table_grad_data, - height, - width, - ids_t->numel(), - -1, - static_cast(start_idx)); - } else { - PADDLE_THROW(platform::errors::Unavailable( - "XPU c_embedding ids only support int32 or int64.")); - } - PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_grad"); - } + void Compute(const framework::ExecutionContext& context) const override {} }; } // namespace operators } // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; - -PD_REGISTER_STRUCT_KERNEL( - c_embedding, XPU, ALL_LAYOUT, ops::CEmbeddingOpXPUKernel, float) {} -PD_REGISTER_STRUCT_KERNEL( - c_embedding_grad, XPU, ALL_LAYOUT, ops::CEmbeddingGradOpXPUKernel, float) {} diff --git a/paddle/phi/kernels/c_embedding_grad_kernel.h b/paddle/phi/kernels/c_embedding_grad_kernel.h new file mode 100644 index 00000000000..d53b1d980b5 --- /dev/null +++ b/paddle/phi/kernels/c_embedding_grad_kernel.h @@ -0,0 +1,28 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle//phi/core/dense_tensor.h" + +namespace phi { +template +void CEmbeddingGradKernel(const Context& ctx, + const DenseTensor& w, + const DenseTensor& ids, + const DenseTensor& out_grad, + int64_t start_index, + DenseTensor* w_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/c_embedding_kernel.h b/paddle/phi/kernels/c_embedding_kernel.h new file mode 100644 index 00000000000..ddd3cb45e13 --- /dev/null +++ b/paddle/phi/kernels/c_embedding_kernel.h @@ -0,0 +1,27 @@ +/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle//phi/core/dense_tensor.h" + +namespace phi { +template +void CEmbeddingKernel(const Context& ctx, + const DenseTensor& w, + const DenseTensor& ids, + int64_t start_index, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/c_embedding_grad_kernel.cc b/paddle/phi/kernels/cpu/c_embedding_grad_kernel.cc new file mode 100644 index 00000000000..1644f998503 --- /dev/null +++ b/paddle/phi/kernels/cpu/c_embedding_grad_kernel.cc @@ -0,0 +1,99 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/c_embedding_grad_kernel.h" + +#include "glog/logging.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void UpdateEmbedding(const TIds* ids, + size_t ids_len, + int64_t start_idx, + TData* table, + int64_t height, + int64_t width, + const TData* out) { + for (size_t i = 0; i < ids_len; i++) { + TIds id = ids[i]; + int64_t local = id - start_idx; + + if (local >= 0 && local < height) { + for (int64_t w = 0; w < width; w++) { + table[local * width + w] += out[i * width + w]; + } + } + } +} + +template +void CEmbeddingGradKernel(const Context& dev_ctx, + const DenseTensor& w, + const DenseTensor& ids, + const DenseTensor& out_grad, + int64_t start_index, + DenseTensor* w_grad) { + w_grad->Resize(w.dims()); + T* table_grad_data = dev_ctx.template Alloc(w_grad); + + size_t table_t_mem_size = w.numel() * sizeof(w_grad->dtype()); + size_t table_grad_t_mem_size = w_grad->numel() * sizeof(w_grad->dtype()); + + VLOG(10) << "table_dims:" << w.dims() + << ", table_t memory_size:" << table_t_mem_size + << ", table_grad_t memory_size:" << table_grad_t_mem_size + << ", start_index:" << start_index; + + memset(table_grad_data, 0, table_grad_t_mem_size); + const T* d_output_data = out_grad.data(); + + const int64_t height = w.dims()[0]; + const int64_t width = w.dims()[1]; + + const auto& index_type = ids.dtype(); + if (index_type == phi::DataType::INT32) { + UpdateEmbedding(ids.data(), + ids.numel(), + start_index, + table_grad_data, + height, + width, + d_output_data); + } else if (index_type == phi::DataType::INT64) { + UpdateEmbedding(ids.data(), + ids.numel(), + start_index, + table_grad_data, + height, + width, + d_output_data); + } else { + PADDLE_THROW(phi::errors::Unavailable( + "CPU c_embedding ids only support int32 or int64.")); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(c_embedding_grad, + CPU, + ALL_LAYOUT, + phi::CEmbeddingGradKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/cpu/c_embedding_kernel.cc b/paddle/phi/kernels/cpu/c_embedding_kernel.cc new file mode 100644 index 00000000000..7997eed7037 --- /dev/null +++ b/paddle/phi/kernels/cpu/c_embedding_kernel.cc @@ -0,0 +1,87 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/c_embedding_kernel.h" + +#include "glog/logging.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void GetIdsEmbedding(const TIds* ids, + size_t ids_len, + int64_t start_idx, + const TData* table, + int64_t height, + int64_t width, + TData* out) { + for (size_t i = 0; i < ids_len; i++) { + TIds id = ids[i]; + int64_t local = id - start_idx; + + if (local >= 0 && local < height) { + memcpy(out + i * width, table + local * width, width * sizeof(TData)); + } else { + memset(out + i * width, 0, width * sizeof(TData)); + } + } +} + +template +void CEmbeddingKernel(const Context& ctx, + const DenseTensor& w, + const DenseTensor& ids, + int64_t start_index, + DenseTensor* out) { + VLOG(10) << "table_dims:" << w.dims(); + const T* table_data = w.data(); + T* output_data = ctx.template Alloc(out); + + const int64_t height = w.dims()[0]; + const int64_t width = w.dims()[1]; + + const auto& index_type = ids.dtype(); + if (index_type == phi::DataType::INT32) { + GetIdsEmbedding(ids.data(), + ids.numel(), + start_index, + table_data, + height, + width, + output_data); + } else if (index_type == phi::DataType::INT64) { + GetIdsEmbedding(ids.data(), + ids.numel(), + start_index, + table_data, + height, + width, + output_data); + } else { + PADDLE_THROW(phi::errors::Unavailable( + "CPU c_embedding ids only support int32 or int64.")); + } +} +} // namespace phi + +PD_REGISTER_KERNEL(c_embedding, + CPU, + ALL_LAYOUT, + phi::CEmbeddingKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/c_embedding_grad_kernel.cu b/paddle/phi/kernels/gpu/c_embedding_grad_kernel.cu new file mode 100644 index 00000000000..aaa5c6865be --- /dev/null +++ b/paddle/phi/kernels/gpu/c_embedding_grad_kernel.cu @@ -0,0 +1,160 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/c_embedding_kernel.h" + +#include "gflags/gflags.h" +#include "glog/logging.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_primitives.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/embedding_grad.h" + +DECLARE_int64(embedding_deterministic); + +namespace phi { + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaxinumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaxinumNumBlocks); +} + +template +__global__ void CEmbeddingGrad(T* table, + const T* output, + const IndexT* ids, + const int rows, + const int columns, + const int64_t N, + const int64_t start_idx, + const int64_t end_idx, + const int64_t limit) { + CUDA_KERNEL_LOOP(i, limit) { + size_t row = i / columns; + size_t col = i % columns; + auto id = ids[row]; + if (id >= start_idx && id < end_idx) { + auto real_idx = id - start_idx; + phi::CudaAtomicAdd(&table[real_idx * columns + col], output[i]); + } + } +} + +template +void CEmbeddingGradKernel(const Context& dev_ctx, + const DenseTensor& w, + const DenseTensor& ids, + const DenseTensor& out_grad, + int64_t start_index, + DenseTensor* w_grad) { + int N = w_grad->dims()[0]; + int D = w_grad->dims()[1]; + int K = ids.numel(); + + auto limit = K * D; + int blocks = NumBlocks(limit); + int threads = kNumCUDAThreads; + + const T* d_output = out_grad.data(); + T* d_table = dev_ctx.template Alloc(w_grad); + + auto t = EigenVector::Flatten(*w_grad); + t.device(*dev_ctx.eigen_device()) = t.constant(static_cast(0)); + + const auto& index_type = ids.dtype(); + if (FLAGS_embedding_deterministic == 1) { + if (index_type == phi::DataType::INT32) { + phi::funcs::LaunchEmbeddingGradDeterministicKernel( + dev_ctx, + ids.data(), + d_output, + d_table, + N, + D, + K, + start_index); + return; + } else if (index_type == phi::DataType::INT64) { + phi::funcs::LaunchEmbeddingGradDeterministicKernel( + dev_ctx, + ids.data(), + d_output, + d_table, + N, + D, + K, + start_index); + return; + } + } else { + if (FLAGS_embedding_deterministic > 1) { + VLOG(2) << "Run grad kernel of embedding with single thread."; + blocks = 1; + } + const int64_t end_idx = start_index + N; + if (index_type == phi::DataType::INT32) { + CEmbeddingGrad + <<>>(d_table, + d_output, + ids.data(), + K, + D, + N, + start_index, + end_idx, + limit); + return; + } else if (index_type == phi::DataType::INT64) { + CEmbeddingGrad + <<>>(d_table, + d_output, + ids.data(), + K, + D, + N, + start_index, + end_idx, + limit); + return; + } + } + PADDLE_THROW(phi::errors::InvalidArgument( + "The data type of Input(Ids) must be int32 or int64.")); +} + +} // namespace phi + +#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000 +PD_REGISTER_KERNEL(c_embedding_grad, + GPU, + ALL_LAYOUT, + phi::CEmbeddingGradKernel, + float, + double, + phi::dtype::bfloat16, + phi::dtype::float16) {} +#else +PD_REGISTER_KERNEL(c_embedding_grad, + GPU, + ALL_LAYOUT, + phi::CEmbeddingGradKernel, + float, + double, + phi::dtype::float16) {} +#endif diff --git a/paddle/phi/kernels/gpu/c_embedding_kernel.cu b/paddle/phi/kernels/gpu/c_embedding_kernel.cu new file mode 100644 index 00000000000..b3306885437 --- /dev/null +++ b/paddle/phi/kernels/gpu/c_embedding_kernel.cu @@ -0,0 +1,128 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/c_embedding_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaxinumNumBlocks = 4096; + +static inline int NumBlocks(const int N) { + return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaxinumNumBlocks); +} + +template +__global__ void CEmbedding(T* out, + const T* table, + const IndexT* ids, + const int rows, + const int columns, + const int64_t N, + const int64_t start_idx, + const int64_t end_idx, + const int64_t limit) { + CUDA_KERNEL_LOOP(i, limit) { + size_t row = i / columns; + size_t col = i % columns; + auto id = ids[row]; + + if (id >= start_idx && id < end_idx) { + auto real_idx = id - start_idx; + PADDLE_ENFORCE(real_idx < N, + "The index is out of bounds, " + "please check whether the dimensions of index and " + "input meet the requirements. It should " + "be less than [%d], but received [%d]", + N, + real_idx); + out[i] = table[real_idx * columns + col]; + } else { + out[i] = static_cast(0); + } + } +} + +template +void CEmbeddingKernel(const Context& ctx, + const DenseTensor& w, + const DenseTensor& ids, + int64_t start_index, + DenseTensor* out) { + size_t N = w.dims()[0]; + size_t D = w.dims()[1]; + size_t K = ids.numel(); + + const int64_t end_idx = start_index + N; + + auto* table = w.data(); + auto* output = ctx.template Alloc(out); + + auto limit = K * D; + int blocks = NumBlocks(limit); + int threads = kNumCUDAThreads; + + const auto& index_type = ids.dtype(); + if (index_type == phi::DataType::INT32) { + CEmbedding + <<>>(output, + table, + ids.data(), + K, + D, + N, + start_index, + end_idx, + limit); + + } else if (index_type == phi::DataType::INT64) { + CEmbedding + <<>>(output, + table, + ids.data(), + K, + D, + N, + start_index, + end_idx, + limit); + } else { + PADDLE_THROW(phi::errors::Unavailable( + "GPU c_embedding ids only support int32 or int64.")); + } +} +} // namespace phi + +#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000 +PD_REGISTER_KERNEL(c_embedding, + GPU, + ALL_LAYOUT, + phi::CEmbeddingKernel, + float, + double, + phi::dtype::bfloat16, + phi::dtype::float16) {} +#else +PD_REGISTER_KERNEL(c_embedding, + GPU, + ALL_LAYOUT, + phi::CEmbeddingKernel, + float, + double, + phi::dtype::float16) {} +#endif diff --git a/paddle/phi/kernels/xpu/c_embedding_kernel.cc b/paddle/phi/kernels/xpu/c_embedding_kernel.cc new file mode 100644 index 00000000000..250a2af5279 --- /dev/null +++ b/paddle/phi/kernels/xpu/c_embedding_kernel.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/c_embedding_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void CEmbeddingKernel(const Context& dev_ctx, + const DenseTensor& w, + const DenseTensor& ids, + int64_t start_index, + DenseTensor* out) { + const T* table_data = w.data(); + T* output_data = dev_ctx.template Alloc(out); + + const int64_t height = w.dims()[0]; + const int64_t width = w.dims()[1]; + + // int embedding(Context* ctx, const T* x, const TID* indices, T* y, int xm, + // int n, int ym, int padding_idx, TID start_index = 0); + + // xm: table height: number of entries of table. + // n: embedding dim: number of float value within single entry. + // ym: number of elements of input ids. + + const auto& index_type = ids.dtype(); + if (index_type == phi::DataType::INT32) { + int r = xpu::embedding(dev_ctx.x_context(), + table_data, + ids.data(), + output_data, + height, + width, + ids.numel(), + -1, + static_cast(start_index)); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding"); + } else if (index_type == phi::DataType::INT64) { + int r = xpu::embedding(dev_ctx.x_context(), + table_data, + ids.data(), + output_data, + height, + width, + ids.numel(), + -1, + static_cast(start_index)); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding"); + } else { + PADDLE_THROW(phi::errors::Unavailable( + "XPU c_embedding ids only support int32 or int64.")); + } +} +} // namespace phi + +PD_REGISTER_KERNEL(c_embedding, XPU, ALL_LAYOUT, phi::CEmbeddingKernel, float) { +} diff --git a/paddle/phi/kernels/xpu/c_embedding_kernel_grad.cc b/paddle/phi/kernels/xpu/c_embedding_kernel_grad.cc new file mode 100644 index 00000000000..0ff1f55111a --- /dev/null +++ b/paddle/phi/kernels/xpu/c_embedding_kernel_grad.cc @@ -0,0 +1,82 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/c_embedding_grad_kernel.h" + +#include "glog/logging.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void CEmbeddingGradKernel(const Context& dev_ctx, + const DenseTensor& w, + const DenseTensor& ids, + const DenseTensor& out_grad, + int64_t start_index, + DenseTensor* w_grad) { + w_grad->Resize(w.dims()); + dev_ctx.template Alloc(w_grad, w.dtype()); + T* table_grad_data = static_cast(w_grad->data()); + + size_t table_t_mem_size = w.numel() * phi::SizeOf(w_grad->dtype()); + size_t table_grad_t_mem_size = w_grad->numel() * phi::SizeOf(w_grad->dtype()); + + VLOG(10) << "table_dims:" << w.dims() + << ", table_t memory_size:" << table_t_mem_size + << ", table_grad_t memory_size:" << table_grad_t_mem_size + << ", start_index:" << start_index; + + int r = xpu::constant( + dev_ctx.x_context(), table_grad_data, w_grad->numel(), (T)0); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant"); + const T* d_output_data = out_grad.data(); + + const int64_t height = w.dims()[0]; + const int64_t width = w.dims()[1]; + + const auto& index_type = ids.dtype(); + if (index_type == phi::DataType::INT32) { + r = xpu::embedding_grad(dev_ctx.x_context(), + d_output_data, + ids.data(), + table_grad_data, + height, + width, + ids.numel(), + -1, + static_cast(start_index)); + } else if (index_type == phi::DataType::INT64) { + r = xpu::embedding_grad(dev_ctx.x_context(), + d_output_data, + ids.data(), + table_grad_data, + height, + width, + ids.numel(), + -1, + static_cast(start_index)); + } else { + PADDLE_THROW(phi::errors::Unavailable( + "XPU c_embedding ids only support int32 or int64.")); + } + PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_grad"); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + c_embedding_grad, XPU, ALL_LAYOUT, phi::CEmbeddingGradKernel, float) {} diff --git a/paddle/phi/ops/compat/c_embedding_sig.cc b/paddle/phi/ops/compat/c_embedding_sig.cc new file mode 100644 index 00000000000..bed568433ca --- /dev/null +++ b/paddle/phi/ops/compat/c_embedding_sig.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature CEmbeddingGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("c_embedding_grad", + {"W", "Ids", "Out@GRAD"}, + {"start_index"}, + {"W@GRAD"}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(c_embedding_grad, + phi::CEmbeddingGradOpArgumentMapping); diff --git a/test/legacy_test/c_embedding_op_base.py b/test/legacy_test/c_embedding_op_base.py index 3eda046571e..891b3b82882 100644 --- a/test/legacy_test/c_embedding_op_base.py +++ b/test/legacy_test/c_embedding_op_base.py @@ -17,6 +17,7 @@ import unittest import numpy as np from eager_op_test import OpTest +import paddle from paddle.framework import core SEED = 2021 @@ -33,6 +34,12 @@ def get_c_embedding(start, end, table, ids): return output +def c_embedding_wrapper(table, index, start_index=0): + return paddle._legacy_C_ops.c_embedding( + table, index, "start_index", start_index + ) + + class TestCEmbeddingCPU(OpTest): def setUp(self): self.init_dtype() @@ -44,6 +51,7 @@ class TestCEmbeddingCPU(OpTest): def initcase(self): self.op_type = "c_embedding" + self.python_api = c_embedding_wrapper table = np.random.random((17, 64)).astype(self.dtype) ids = np.random.randint(low=0, high=17 * 2, size=(2, 4)).astype( self.ids_dtype @@ -58,10 +66,10 @@ class TestCEmbeddingCPU(OpTest): if core.is_compiled_with_xpu(): self.__class__.use_xpu = True - def test_check_cpu(self): + def test_check_output(self): self.check_output_with_place(core.CPUPlace()) - def test_check_cpu_grad(self): + def test_check_grad(self): self.check_grad_with_place(core.CPUPlace(), ['W'], 'Out') def init_dtype(self): @@ -102,6 +110,7 @@ class TestCEmbeddingOpFP32(TestCEmbeddingOpBase): def initcase(self): self.op_type = "c_embedding" + self.python_api = c_embedding_wrapper table = np.random.random((17, 64)).astype(self.dtype) ids = np.random.randint(low=0, high=17 * 2, size=(2, 4)).astype( self.ids_dtype -- GitLab