diff --git a/paddle/phi/kernels/cpu/embedding_grad_kernel.cc b/paddle/phi/kernels/cpu/embedding_grad_kernel.cc
index 4a6b1014277d70d013788fcb84ce5cc47eb2d8ec..8402931ed89d26588fa8f95597a2bb52ddb13ba6 100644
--- a/paddle/phi/kernels/cpu/embedding_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/embedding_grad_kernel.cc
@@ -15,21 +15,20 @@
 #include "paddle/phi/kernels/embedding_grad_kernel.h"
 #include "paddle/phi/kernels/funcs/embedding_util.h"
 
-#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/framework/data_type.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
 
 namespace phi {
 
 template <typename T, typename Context>
-struct LookupTableV2GradCPUFunctor {
-  LookupTableV2GradCPUFunctor(const Context& dev_ctx,
-                              const DenseTensor& input,
-                              const DenseTensor& weight,
-                              const DenseTensor& out_grad,
-                              int64_t padding_idx,
-                              DenseTensor* weight_grad)
+struct EmbeddingGradCPUFunctor {
+  EmbeddingGradCPUFunctor(const Context& dev_ctx,
+                          const DenseTensor& input,
+                          const DenseTensor& weight,
+                          const DenseTensor& out_grad,
+                          int64_t padding_idx,
+                          DenseTensor* weight_grad)
       : dev_ctx_(dev_ctx),
         input_(input),
         weight_(weight),
@@ -48,7 +47,6 @@ struct LookupTableV2GradCPUFunctor {
     // paddings makes no sense and we don't deal with it in backward.
     {
       auto* d_output = &out_grad_;
-      // auto d_table = weight_grad_;
       auto* ids_data = ids.data();
 
       int64_t N = table_dim[0];
@@ -70,7 +68,8 @@ struct LookupTableV2GradCPUFunctor {
             ids_data[i],
             N,
             phi::errors::InvalidArgument(
-                "Variable value (input) of OP(fluid.layers.embedding) "
+                "Variable value (input) of "
+                "OP(paddle.nn.functional.embedding) "
                 "expected >= 0 and < %ld, but got %ld. Please check input "
                 "value.",
                 N,
@@ -79,7 +78,8 @@ struct LookupTableV2GradCPUFunctor {
             ids_data[i],
             0,
             phi::errors::InvalidArgument(
-                "Variable value (input) of OP(fluid.layers.embedding) "
+                "Variable value (input) of "
+                "OP(paddle.nn.functional.embedding) "
                 "expected >= 0 and < %ld, but got %ld. Please check input "
                 "value.",
                 N,
@@ -108,20 +108,25 @@ void EmbeddingGradKernel(const Context& ctx,
                          const DenseTensor& out_grad,
                          int64_t padding_idx,
                          DenseTensor* weight_grad) {
-  LookupTableV2GradCPUFunctor<T, Context> functor(
+  EmbeddingGradCPUFunctor<T, Context> functor(
       ctx, input, weight, out_grad, padding_idx, weight_grad);
-  paddle::framework::VisitIntDataType(
-      paddle::framework::TransToProtoVarType(input.dtype()), functor);
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW("embedding input only supports int32 and int64");
+  }
 }
 
 template <typename T, typename Context>
-struct LookupTableV2SparseGradCPUFunctor {
-  LookupTableV2SparseGradCPUFunctor(const Context& dev_ctx,
-                                    const DenseTensor& input,
-                                    const DenseTensor& weight,
-                                    const DenseTensor& out_grad,
-                                    int64_t padding_idx,
-                                    SelectedRows* weight_grad)
+struct EmbeddingSparseGradCPUFunctor {
+  EmbeddingSparseGradCPUFunctor(const Context& dev_ctx,
+                                const DenseTensor& input,
+                                const DenseTensor& weight,
+                                const DenseTensor& out_grad,
+                                int64_t padding_idx,
+                                SelectedRows* weight_grad)
       : dev_ctx_(dev_ctx),
         input_(input),
         weight_(weight),
@@ -145,7 +150,7 @@ struct LookupTableV2SparseGradCPUFunctor {
     auto* d_table_value = d_table->mutable_value();
     d_table_value->Resize({ids_num, table_dim[1]});
 
-    d_table_value->template mutable_data<T>(dev_ctx_.GetPlace());
+    dev_ctx_.template Alloc<T>(d_table_value);
 
     d_table->set_height(table_dim[0]);
 
@@ -183,10 +188,15 @@ void EmbeddingSparseGradKernel(const Context& ctx,
                                const DenseTensor& out_grad,
                                int64_t padding_idx,
                                SelectedRows* weight_grad) {
-  LookupTableV2SparseGradCPUFunctor<T, Context> functor(
+  EmbeddingSparseGradCPUFunctor<T, Context> functor(
       ctx, input, weight, out_grad, padding_idx, weight_grad);
-  paddle::framework::VisitIntDataType(
-      paddle::framework::TransToProtoVarType(input.dtype()), functor);
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW("embedding input only supports int32 and int64");
+  }
 }
 
 }  // namespace phi
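Note: across all six files this patch replaces the paddle::framework::VisitIntDataType visitor with an explicit switch on input.dtype() that instantiates functor.template apply<IdT>() per supported index type, removing the fluid dependency from phi. A minimal standalone sketch of the pattern (DataType, GradFunctor, and Dispatch here are illustrative stand-ins, not the phi API):

    #include <cstdint>
    #include <stdexcept>

    enum class DataType { INT32, INT64, FLOAT32 };

    struct GradFunctor {
      // In the kernels above, apply<IdT>() re-reads the ids tensor as IdT
      // once the runtime dtype has been matched to a static type.
      template <typename IdT>
      void apply() { /* gather/scatter using the IdT* ids */ }
    };

    void Dispatch(DataType dtype, GradFunctor& functor) {
      if (dtype == DataType::INT32) {
        functor.template apply<int>();
      } else if (dtype == DataType::INT64) {
        functor.template apply<int64_t>();
      } else {
        throw std::invalid_argument(
            "embedding input only supports int32 and int64");
      }
    }

The explicit if/else trades the generic visitor for two instantiations written out by hand, which is easy to audit and keeps the dispatch entirely inside phi.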
diff --git a/paddle/phi/kernels/cpu/embedding_kernel.cc b/paddle/phi/kernels/cpu/embedding_kernel.cc
index 6c92e9a660a7e50bb9cbe91a350b88cb884290f6..2973f426aa5ef24ced571a14906f5765a01b361c 100644
--- a/paddle/phi/kernels/cpu/embedding_kernel.cc
+++ b/paddle/phi/kernels/cpu/embedding_kernel.cc
@@ -15,20 +15,20 @@
 #include "paddle/phi/kernels/embedding_kernel.h"
 #include "paddle/phi/kernels/funcs/embedding_util.h"
 
-#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/framework/data_type.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
 
 namespace phi {
 
 template <typename T, typename Context>
-struct LookupTableV2CPUFunctor {
-  LookupTableV2CPUFunctor(const Context& dev_ctx,
-                          const DenseTensor& input,
-                          const DenseTensor& weight,
-                          int64_t padding_idx,
-                          DenseTensor* out)
+struct EmbeddingCPUFunctor {
+  EmbeddingCPUFunctor(const Context& dev_ctx,
+                      const DenseTensor& input,
+                      const DenseTensor& weight,
+                      int64_t padding_idx,
+                      DenseTensor* out)
       : dev_ctx_(dev_ctx),
         input_(input),
         weight_(weight),
@@ -91,10 +91,15 @@ void EmbeddingKernel(const Context& ctx,
                      const DenseTensor& weight,
                      int64_t padding_idx,
                      DenseTensor* out) {
-  LookupTableV2CPUFunctor<T, Context> functor(
-      ctx, input, weight, padding_idx, out);
-  paddle::framework::VisitIntDataType(
-      paddle::framework::TransToProtoVarType(input.dtype()), functor);
+  EmbeddingCPUFunctor<T, Context> functor(ctx, input, weight, padding_idx, out);
+
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW("embedding input only supports int32 and int64");
+  }
 }
 
 }  // namespace phi
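Note: EmbeddingCPUFunctor::apply<IdT>() is a row gather from the weight table. For reference, the semantics preserved by this rename, including the padding_idx contract (a row whose id equals padding_idx is filled with zeros), can be sketched as follows (EmbeddingForwardRef is an illustrative name, not phi code):

    #include <cstdint>
    #include <cstring>

    // ids: [K] row indices into table; table: [N x D]; output: [K x D].
    template <typename T, typename IdT>
    void EmbeddingForwardRef(const IdT* ids, int64_t K,
                             const T* table, int64_t N, int64_t D,
                             int64_t padding_idx, T* output) {
      for (int64_t i = 0; i < K; ++i) {
        const int64_t id = static_cast<int64_t>(ids[i]);
        if (id == padding_idx) {
          // Padding rows produce zeros instead of a table lookup.
          std::memset(output + i * D, 0, sizeof(T) * D);
        } else {
          // Copy one row of the table into the i-th output row.
          std::memcpy(output + i * D, table + id * D, sizeof(T) * D);
        }
      }
    }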
diff --git a/paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc b/paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc
index 89237d3f6e87696c0efc59b0818a41913bd147d0..4846f209660bab20cfa66b91a5b4ae56342f372b 100644
--- a/paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc
@@ -15,21 +15,21 @@
 #include "paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h"
 #include "paddle/phi/kernels/funcs/embedding_util.h"
 
-#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/framework/data_type.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
 
 namespace phi {
 
 template <typename T, typename Context>
-struct SparseWeightLookupTableV2GradCPUFunctor {
-  SparseWeightLookupTableV2GradCPUFunctor(const Context& dev_ctx,
-                                          const DenseTensor& input,
-                                          const SelectedRows& weight,
-                                          const DenseTensor& out_grad,
-                                          int64_t padding_idx,
-                                          DenseTensor* weight_grad)
+struct SparseWeightEmbeddingGradCPUFunctor {
+  SparseWeightEmbeddingGradCPUFunctor(const Context& dev_ctx,
+                                      const DenseTensor& input,
+                                      const SelectedRows& weight,
+                                      const DenseTensor& out_grad,
+                                      int64_t padding_idx,
+                                      DenseTensor* weight_grad)
       : dev_ctx_(dev_ctx),
         input_(input),
         weight_(weight),
@@ -70,7 +70,8 @@ struct SparseWeightLookupTableV2GradCPUFunctor {
             ids_data[i],
             N,
             phi::errors::InvalidArgument(
-                "Variable value (input) of OP(fluid.layers.embedding) "
+                "Variable value (input) of "
+                "OP(paddle.nn.functional.embedding) "
                 "expected >= 0 and < %ld, but got %ld. Please check input "
                 "value.",
                 N,
@@ -79,7 +80,8 @@ struct SparseWeightLookupTableV2GradCPUFunctor {
             ids_data[i],
             0,
             phi::errors::InvalidArgument(
-                "Variable value (input) of OP(fluid.layers.embedding) "
+                "Variable value (input) of "
+                "OP(paddle.nn.functional.embedding) "
                 "expected >= 0 and < %ld, but got %ld. Please check input "
                 "value.",
                 N,
@@ -102,13 +104,13 @@ struct SparseWeightLookupTableV2GradCPUFunctor {
 };
 
 template <typename T, typename Context>
-struct SparseWeightLookupTableV2SparseGradCPUFunctor {
-  SparseWeightLookupTableV2SparseGradCPUFunctor(const Context& dev_ctx,
-                                                const DenseTensor& input,
-                                                const SelectedRows& weight,
-                                                const DenseTensor& out_grad,
-                                                int64_t padding_idx,
-                                                SelectedRows* weight_grad)
+struct SparseWeightEmbeddingSparseGradCPUFunctor {
+  SparseWeightEmbeddingSparseGradCPUFunctor(const Context& dev_ctx,
+                                            const DenseTensor& input,
+                                            const SelectedRows& weight,
+                                            const DenseTensor& out_grad,
+                                            int64_t padding_idx,
+                                            SelectedRows* weight_grad)
       : dev_ctx_(dev_ctx),
         input_(input),
         weight_(weight),
@@ -132,7 +134,7 @@ struct SparseWeightLookupTableV2SparseGradCPUFunctor {
     auto* d_table_value = d_table->mutable_value();
     d_table_value->Resize({ids_num, table_dim[1]});
 
-    d_table_value->template mutable_data<T>(dev_ctx_.GetPlace());
+    dev_ctx_.template Alloc<T>(d_table_value);
 
     d_table->set_height(table_dim[0]);
 
@@ -170,10 +172,16 @@ void SparseWeightEmbeddingGradKernel(const Context& ctx,
                                      const DenseTensor& out_grad,
                                      int64_t padding_idx,
                                      DenseTensor* weight_grad) {
-  SparseWeightLookupTableV2GradCPUFunctor<T, Context> functor(
+  SparseWeightEmbeddingGradCPUFunctor<T, Context> functor(
      ctx, input, weight, out_grad, padding_idx, weight_grad);
-  paddle::framework::VisitIntDataType(
-      paddle::framework::TransToProtoVarType(input.dtype()), functor);
+
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW("embedding input only supports int32 and int64");
+  }
 }
 
 template <typename T, typename Context>
@@ -183,10 +191,16 @@ void SparseWeightEmbeddingSparseGradKernel(const Context& ctx,
                                            const DenseTensor& out_grad,
                                            int64_t padding_idx,
                                            SelectedRows* weight_grad) {
-  SparseWeightLookupTableV2SparseGradCPUFunctor<T, Context> functor(
+  SparseWeightEmbeddingSparseGradCPUFunctor<T, Context> functor(
      ctx, input, weight, out_grad, padding_idx, weight_grad);
-  paddle::framework::VisitIntDataType(
-      paddle::framework::TransToProtoVarType(input.dtype()), functor);
+
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW("embedding input only supports int32 and int64");
+  }
 }
 
 }  // namespace phi
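Note: both grad functors above fold the upstream gradient back onto the rows that were looked up. Duplicate ids accumulate into the same row, which is why the dense weight gradient is a scatter-add over a zero-initialized table. A reference sketch of that accumulation (EmbeddingGradRef is an illustrative name):

    #include <cstdint>

    // d_output: [K x D] upstream gradient; d_table: [N x D], pre-zeroed.
    template <typename T, typename IdT>
    void EmbeddingGradRef(const IdT* ids, int64_t K,
                          const T* d_output, int64_t D, T* d_table) {
      for (int64_t i = 0; i < K; ++i) {
        const int64_t id = static_cast<int64_t>(ids[i]);
        for (int64_t j = 0; j < D; ++j) {
          // += rather than =: repeated ids sum their contributions.
          d_table[id * D + j] += d_output[i * D + j];
        }
      }
    }

The SelectedRows variant avoids the full [N x D] buffer by emitting only the touched rows (rows = ids, value = a copy of d_output), which is what the Alloc call on d_table_value above sizes as {ids_num, table_dim[1]}.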
diff --git a/paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc b/paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc
index d8a53f42f606d3c964033862fe4f2896754eb6c1..fb29feee24ad5b6328070f8b20991610e37d412d 100644
--- a/paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc
+++ b/paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc
@@ -15,21 +15,21 @@
 #include "paddle/phi/kernels/embedding_kernel.h"
 #include "paddle/phi/kernels/funcs/embedding_util.h"
 
-#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/framework/data_type.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 
 namespace phi {
 
 template <typename T, typename Context>
-struct LookupTableV2CPUSparseFunctor {
-  LookupTableV2CPUSparseFunctor(const Context& dev_ctx,
-                                const DenseTensor& input,
-                                const SelectedRows& weight,
-                                int64_t padding_idx,
-                                DenseTensor* out)
+struct EmbeddingCPUSparseFunctor {
+  EmbeddingCPUSparseFunctor(const Context& dev_ctx,
+                            const DenseTensor& input,
+                            const SelectedRows& weight,
+                            int64_t padding_idx,
+                            DenseTensor* out)
       : dev_ctx_(dev_ctx),
         input_(input),
         weight_(weight),
@@ -45,7 +45,7 @@ struct LookupTableV2CPUSparseFunctor {
     auto output_t = out_;
     int64_t row_width = table_t.value().dims()[1];
     const auto* table = table_t.value().template data<T>();
-    auto* output = output_t->template mutable_data<T>(dev_ctx_.GetPlace());
+    auto* output = dev_ctx_.template Alloc<T>(output_t);
     auto input_data_type =
         paddle::framework::TransToProtoVarType(table_t.value().dtype());
 
@@ -94,10 +94,16 @@ void SparseWeightEmbeddingKernel(const Context& ctx,
                                  const SelectedRows& weight,
                                  int64_t padding_idx,
                                  DenseTensor* out) {
-  LookupTableV2CPUSparseFunctor<T, Context> functor(
+  EmbeddingCPUSparseFunctor<T, Context> functor(
      ctx, input, weight, padding_idx, out);
-  paddle::framework::VisitIntDataType(
-      paddle::framework::TransToProtoVarType(input.dtype()), functor);
+
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW("embedding input only supports int32 and int64");
+  }
 }
 
 }  // namespace phi
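Note: in the sparse-weight path the weight is a SelectedRows rather than a dense table: only a subset of rows is materialized, rows() maps local row i to its global id, and each input id must be translated to a local row offset before the gather (Paddle's SelectedRows exposes an Index() lookup for this). A hypothetical flat-array sketch of that translation, with SparseWeightLookupRef as an illustrative name:

    #include <cstdint>
    #include <cstring>
    #include <unordered_map>
    #include <vector>

    template <typename T, typename IdT>
    void SparseWeightLookupRef(const IdT* ids, int64_t K,
                               const std::vector<int64_t>& rows,
                               const T* value, int64_t D, T* output) {
      // rows[i] is the global id whose parameters live in row i of
      // `value` ([rows.size() x D]); build the reverse map once.
      std::unordered_map<int64_t, int64_t> index;
      for (int64_t i = 0; i < static_cast<int64_t>(rows.size()); ++i) {
        index[rows[i]] = i;
      }
      for (int64_t i = 0; i < K; ++i) {
        // at() throws for an id with no materialized row, standing in
        // for the enforce checks the real kernels perform.
        const int64_t local = index.at(static_cast<int64_t>(ids[i]));
        std::memcpy(output + i * D, value + local * D, sizeof(T) * D);
      }
    }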
diff --git a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
index 5a3715f454980e56e06f697c7335104eb761d28d..def7988ba519950c64753b58676870d54bc792cd 100644
--- a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
@@ -15,9 +15,9 @@
 #include "paddle/phi/kernels/embedding_grad_kernel.h"
 #include "paddle/phi/kernels/funcs/embedding_util.h"
 
-#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/memory/memcpy.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 
@@ -36,12 +36,12 @@ __global__ void InputTypeConvert(const InT* in_ids,
 }
 
 template <typename T, typename IdT>
-__global__ void LookupTableV2Grad(T* table,
-                                  const T* output,
-                                  const IdT* ids,
-                                  const int64_t N,
-                                  const int64_t K,
-                                  const int64_t D) {
+__global__ void EmbeddingGrad(T* table,
+                              const T* output,
+                              const IdT* ids,
+                              const int64_t N,
+                              const int64_t K,
+                              const int64_t D) {
   int idx = threadIdx.x;
   int idy = blockIdx.x + threadIdx.y * gridDim.x;
 
@@ -61,13 +61,13 @@ __global__ void LookupTableV2Grad(T* table,
 }
 
 template <typename T, typename Context>
-struct LookupTableV2GradCUDAFunctor {
-  LookupTableV2GradCUDAFunctor(const Context& dev_ctx,
-                               const DenseTensor& input,
-                               const DenseTensor& weight,
-                               const DenseTensor& out_grad,
-                               int64_t padding_idx,
-                               DenseTensor* weight_grad)
+struct EmbeddingGradCUDAFunctor {
+  EmbeddingGradCUDAFunctor(const Context& dev_ctx,
+                           const DenseTensor& input,
+                           const DenseTensor& weight,
+                           const DenseTensor& out_grad,
+                           int64_t padding_idx,
+                           DenseTensor* weight_grad)
       : dev_ctx_(dev_ctx),
         input_(input),
         weight_(weight),
@@ -89,7 +89,7 @@ struct LookupTableV2GradCUDAFunctor {
     const T* d_output = d_output_t.template data<T>();
     const auto* ids = input_.template data<IdT>();
-    T* d_table = d_table_t->mutable_data<T>(dev_ctx_.GetPlace());
+    T* d_table = dev_ctx_.template Alloc<T>(d_table_t);
 
 #ifdef PADDLE_WITH_HIP
     PADDLE_ENFORCE_GPU_SUCCESS(
@@ -102,7 +102,7 @@ struct LookupTableV2GradCUDAFunctor {
       const int gridx = 2 * dev_ctx_.GetSMCount();
       dim3 threads(128, 8);
       dim3 grids(gridx, 1);
-      LookupTableV2Grad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
+      EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
          d_table, d_output, ids, N, K, D);
     }
   }
@@ -123,20 +123,26 @@ void EmbeddingGradKernel(const Context& ctx,
                          const DenseTensor& out_grad,
                          int64_t padding_idx,
                          DenseTensor* weight_grad) {
-  LookupTableV2GradCUDAFunctor<T, Context> functor(
+  EmbeddingGradCUDAFunctor<T, Context> functor(
      ctx, input, weight, out_grad, padding_idx, weight_grad);
-  paddle::framework::VisitIntDataType(
-      paddle::framework::TransToProtoVarType(input.dtype()), functor);
+
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW("embedding input only supports int32 and int64");
+  }
 }
 
 template <typename T, typename Context>
-struct LookupTableV2SparseGradCUDAFunctor {
-  LookupTableV2SparseGradCUDAFunctor(const Context& dev_ctx,
-                                     const DenseTensor& input,
-                                     const DenseTensor& weight,
-                                     const DenseTensor& out_grad,
-                                     int64_t padding_idx,
-                                     SelectedRows* weight_grad)
+struct EmbeddingSparseGradCUDAFunctor {
+  EmbeddingSparseGradCUDAFunctor(const Context& dev_ctx,
+                                 const DenseTensor& input,
+                                 const DenseTensor& weight,
+                                 const DenseTensor& out_grad,
+                                 int64_t padding_idx,
+                                 SelectedRows* weight_grad)
       : dev_ctx_(dev_ctx),
         input_(input),
         weight_(weight),
@@ -179,7 +185,7 @@ struct LookupTableV2SparseGradCUDAFunctor {
     auto* d_table_value = d_table->mutable_value();
     d_table_value->Resize({ids_num, table->dims()[1]});
-    d_table_value->template mutable_data<T>(gpu_place);
+    dev_ctx_.template Alloc<T>(d_table_value);
 
     auto* d_table_data = d_table_value->template data<T>();
     auto* d_output_data = d_output->template data<T>();
@@ -219,10 +225,16 @@ void EmbeddingSparseGradKernel(const Context& ctx,
                                const DenseTensor& out_grad,
                                int64_t padding_idx,
                                SelectedRows* weight_grad) {
-  LookupTableV2SparseGradCUDAFunctor<T, Context> functor(
+  EmbeddingSparseGradCUDAFunctor<T, Context> functor(
      ctx, input, weight, out_grad, padding_idx, weight_grad);
-  paddle::framework::VisitIntDataType(
-      paddle::framework::TransToProtoVarType(input.dtype()), functor);
+
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW("embedding input only supports int32 and int64");
+  }
 }
 
 }  // namespace phi
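Note: the EmbeddingGrad hunks above show only the thread indexing (threadIdx.x walks the embedding width D while blockIdx.x and threadIdx.y stride over the K ids); the accumulation body lies outside the hunk. A sketch of the standard technique, assuming atomicAdd support for T (true for float; double and half require sufficiently new GPU architectures) and with EmbeddingGradSketch as an illustrative name, not the exact phi kernel body:

    #include <cstdint>

    template <typename T, typename IdT>
    __global__ void EmbeddingGradSketch(T* table, const T* output,
                                        const IdT* ids, int64_t K,
                                        int64_t D) {
      int idx = threadIdx.x;
      int idy = blockIdx.x + threadIdx.y * gridDim.x;
      while (idy < K) {
        const int64_t id = static_cast<int64_t>(ids[idy]);
        const T* out = output + idy * D;
        T* tab = table + id * D;
        for (int i = idx; i < D; i += blockDim.x) {
          // Atomic: duplicate ids from different threads hit the same row.
          atomicAdd(&tab[i], out[i]);
        }
        idy += blockDim.y * gridDim.x;
      }
    }

The 2 * GetSMCount() grid sizing above follows the same idea: launch enough row-striding blocks to keep every SM busy without oversubscribing.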
diff --git a/paddle/phi/kernels/gpu/embedding_kernel.cu b/paddle/phi/kernels/gpu/embedding_kernel.cu
index 0f66dbf59151b8991eb7e70a064e62b16287466a..4d349c6ab5cc3806ba97e076a81329277e6a3cdc 100644
--- a/paddle/phi/kernels/gpu/embedding_kernel.cu
+++ b/paddle/phi/kernels/gpu/embedding_kernel.cu
@@ -15,22 +15,21 @@
 #include "paddle/phi/kernels/embedding_kernel.h"
 #include "paddle/phi/kernels/funcs/embedding_util.h"
 
-#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/framework/data_type.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
 
 namespace phi {
 
 template <typename T, typename IdT, bool PaddingFlag>
-__global__ void LookupTableV2(T *output,
-                              const T *table,
-                              const IdT *ids,
-                              const int64_t N,
-                              const int64_t K,
-                              const int64_t D,
-                              const int64_t padding_idx) {
+__global__ void EmbeddingFW(T *output,
+                            const T *table,
+                            const IdT *ids,
+                            const int64_t N,
+                            const int64_t K,
+                            const int64_t D,
+                            const int64_t padding_idx) {
   int idx = threadIdx.x;
   int idy = blockIdx.x + threadIdx.y * gridDim.x;
 
@@ -53,12 +52,12 @@ __global__ void LookupTableV2(T *output,
 }
 
 template <typename T, typename Context>
-struct LookupTableV2CUDAFunctor {
-  LookupTableV2CUDAFunctor(const Context &dev_ctx,
-                           const DenseTensor &input,
-                           const DenseTensor &weight,
-                           int64_t padding_idx,
-                           DenseTensor *out)
+struct EmbeddingCUDAFunctor {
+  EmbeddingCUDAFunctor(const Context &dev_ctx,
+                       const DenseTensor &input,
+                       const DenseTensor &weight,
+                       int64_t padding_idx,
+                       DenseTensor *out)
       : dev_ctx_(dev_ctx),
         input_(input),
         weight_(weight),
@@ -77,14 +76,14 @@ struct LookupTableV2CUDAFunctor {
     const T *table = weight_.template data<T>();
     const IdT *ids = input_.template data<IdT>();
-    auto *output = out_->template mutable_data<T>(dev_ctx_.GetPlace());
+    auto *output = dev_ctx_.template Alloc<T>(out_);
     auto stream = dev_ctx_.stream();
 
     if (padding_idx_ == -1) {
-      LookupTableV2<T, IdT, false><<<grids, threads, 0, stream>>>(
+      EmbeddingFW<T, IdT, false><<<grids, threads, 0, stream>>>(
          output, table, ids, N, K, D, padding_idx_);
     } else {
-      LookupTableV2<T, IdT, true><<<grids, threads, 0, stream>>>(
+      EmbeddingFW<T, IdT, true><<<grids, threads, 0, stream>>>(
          output, table, ids, N, K, D, padding_idx_);
     }
   }
@@ -103,10 +102,16 @@ void EmbeddingKernel(const Context &ctx,
                     const DenseTensor &weight,
                     int64_t padding_idx,
                     DenseTensor *out) {
-  LookupTableV2CUDAFunctor<T, Context> functor(
+  EmbeddingCUDAFunctor<T, Context> functor(
      ctx, input, weight, padding_idx, out);
-  paddle::framework::VisitIntDataType(
-      paddle::framework::TransToProtoVarType(input.dtype()), functor);
+
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW("embedding input only supports int32 and int64");
+  }
 }
 
 }  // namespace phi
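Note: the launch site picks EmbeddingFW<T, IdT, false> or EmbeddingFW<T, IdT, true> based on padding_idx_ == -1, so the padding comparison becomes a compile-time constant and the no-padding instantiation carries no per-element branch. A sketch of that specialization (EmbeddingFWSketch is an illustrative name; the real kernel body sits in the unmodified part of the file):

    #include <cstdint>

    template <typename T, typename IdT, bool PaddingFlag>
    __global__ void EmbeddingFWSketch(T *output, const T *table,
                                      const IdT *ids, int64_t K, int64_t D,
                                      int64_t padding_idx) {
      int idx = threadIdx.x;
      int idy = blockIdx.x + threadIdx.y * gridDim.x;
      while (idy < K) {
        const int64_t id = static_cast<int64_t>(ids[idy]);
        T *out = output + idy * D;
        const T *tab = table + id * D;
        for (int i = idx; i < D; i += blockDim.x) {
          // With PaddingFlag == false the compiler removes this branch,
          // leaving a pure row copy on the common path.
          if (PaddingFlag && id == padding_idx) {
            out[i] = static_cast<T>(0);
          } else {
            out[i] = tab[i];
          }
        }
        idy += blockDim.y * gridDim.x;
      }
    }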