Commit 867fc053 authored by phlrain

polish code

Parent 7ba14d74
--- a/paddle/phi/kernels/cpu/embedding_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/embedding_grad_kernel.cc
@@ -15,21 +15,20 @@
 #include "paddle/phi/kernels/embedding_grad_kernel.h"
 #include "paddle/phi/kernels/funcs/embedding_util.h"
 
-#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/framework/data_type.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
 
 namespace phi {
 
 template <typename T, typename Context>
-struct LookupTableV2GradCPUFunctor {
-  LookupTableV2GradCPUFunctor(const Context& dev_ctx,
+struct EmbeddingGradCPUFunctor {
+  EmbeddingGradCPUFunctor(const Context& dev_ctx,
                           const DenseTensor& input,
                           const DenseTensor& weight,
                           const DenseTensor& out_grad,
                           int64_t padding_idx,
                           DenseTensor* weight_grad)
       : dev_ctx_(dev_ctx),
         input_(input),
         weight_(weight),
@@ -48,7 +47,6 @@ struct LookupTableV2GradCPUFunctor {
     // paddings makes no sense and we don't deal with it in backward.
     {
       auto* d_output = &out_grad_;
-      // auto d_table = weight_grad_;
       auto* ids_data = ids.data();
 
       int64_t N = table_dim[0];
@@ -70,7 +68,8 @@ struct LookupTableV2GradCPUFunctor {
            ids_data[i],
            N,
            phi::errors::InvalidArgument(
-               "Variable value (input) of OP(fluid.layers.embedding) "
+               "Variable value (input) of "
+               "OP(paddle.nn.functional.embedding) "
                "expected >= 0 and < %ld, but got %ld. Please check input "
                "value.",
                N,
@@ -79,7 +78,8 @@ struct LookupTableV2GradCPUFunctor {
            ids_data[i],
            0,
            phi::errors::InvalidArgument(
-               "Variable value (input) of OP(fluid.layers.embedding) "
+               "Variable value (input) of "
+               "OP(paddle.nn.functional.embedding) "
                "expected >= 0 and < %ld, but got %ld. Please check input "
                "value.",
                N,
@@ -108,20 +108,25 @@ void EmbeddingGradKernel(const Context& ctx,
                          const DenseTensor& out_grad,
                          int64_t padding_idx,
                          DenseTensor* weight_grad) {
-  LookupTableV2GradCPUFunctor<T, Context> functor(
+  EmbeddingGradCPUFunctor<T, Context> functor(
       ctx, input, weight, out_grad, padding_idx, weight_grad);
-  paddle::framework::VisitIntDataType(
-      paddle::framework::TransToProtoVarType(input.dtype()), functor);
+
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW("embedding input only supports int32 and int64");
+  }
 }
 
 template <typename T, typename Context>
-struct LookupTableV2SparseGradCPUFunctor {
-  LookupTableV2SparseGradCPUFunctor(const Context& dev_ctx,
+struct EmbeddingSparseGradCPUFunctor {
+  EmbeddingSparseGradCPUFunctor(const Context& dev_ctx,
                                 const DenseTensor& input,
                                 const DenseTensor& weight,
                                 const DenseTensor& out_grad,
                                 int64_t padding_idx,
                                 SelectedRows* weight_grad)
       : dev_ctx_(dev_ctx),
         input_(input),
         weight_(weight),
@@ -145,7 +150,7 @@ struct LookupTableV2SparseGradCPUFunctor {
     auto* d_table_value = d_table->mutable_value();
     d_table_value->Resize({ids_num, table_dim[1]});
 
-    d_table_value->template mutable_data<T>(dev_ctx_.GetPlace());
+    dev_ctx_.template Alloc<T>(d_table_value);
 
     d_table->set_height(table_dim[0]);
 
@@ -183,10 +188,15 @@ void EmbeddingSparseGradKernel(const Context& ctx,
                                const DenseTensor& out_grad,
                                int64_t padding_idx,
                                SelectedRows* weight_grad) {
-  LookupTableV2SparseGradCPUFunctor<T, Context> functor(
+  EmbeddingSparseGradCPUFunctor<T, Context> functor(
       ctx, input, weight, out_grad, padding_idx, weight_grad);
-  paddle::framework::VisitIntDataType(
-      paddle::framework::TransToProtoVarType(input.dtype()), functor);
+
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW("embedding input only supports int32 and int64");
+  }
 }
 
 }  // namespace phi
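This hunk shows the core of the commit: each kernel entry point stops going through the fluid visitor paddle::framework::VisitIntDataType and dispatches on phi::DataType directly, which is what lets the fluid convert_utils.h/data_type.h includes go away. A minimal self-contained sketch of that functor-plus-dispatch shape (GradFunctor and DispatchOnIndexType are illustrative names, not phi API):

#include <cstdint>
#include <stdexcept>

// Illustrative stand-in for phi::DataType; only what the sketch needs.
enum class DataType { INT32, INT64, FLOAT32 };

// Mirrors the functor shape above: the element type T is fixed when the
// functor is built, the index type IdT is injected later via apply<IdT>().
template <typename T>
struct GradFunctor {
  template <typename IdT>
  void apply() const {
    // A real body would read ids as IdT and accumulate gradients of type T.
  }
};

// Runtime dtype -> compile-time index type, the same if/else ladder the
// commit introduces in each kernel entry point.
template <typename T>
void DispatchOnIndexType(DataType dtype, const GradFunctor<T>& functor) {
  if (dtype == DataType::INT32) {
    functor.template apply<int32_t>();
  } else if (dtype == DataType::INT64) {
    functor.template apply<int64_t>();
  } else {
    throw std::invalid_argument("embedding input only supports int32/int64");
  }
}

int main() {
  GradFunctor<float> functor;
  DispatchOnIndexType(DataType::INT64, functor);  // instantiates apply<int64_t>
}

The element type T stays a class template parameter chosen at kernel registration time; only the index type is resolved at run time, so each kernel gains just two extra instantiations.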
--- a/paddle/phi/kernels/cpu/embedding_kernel.cc
+++ b/paddle/phi/kernels/cpu/embedding_kernel.cc
@@ -15,20 +15,20 @@
 #include "paddle/phi/kernels/embedding_kernel.h"
 #include "paddle/phi/kernels/funcs/embedding_util.h"
 
-#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/framework/data_type.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
 
 namespace phi {
 
 template <typename T, typename Context>
-struct LookupTableV2CPUFunctor {
-  LookupTableV2CPUFunctor(const Context& dev_ctx,
+struct EmbeddingCPUFunctor {
+  EmbeddingCPUFunctor(const Context& dev_ctx,
                       const DenseTensor& input,
                       const DenseTensor& weight,
                       int64_t padding_idx,
                       DenseTensor* out)
       : dev_ctx_(dev_ctx),
         input_(input),
         weight_(weight),
@@ -91,10 +91,15 @@ void EmbeddingKernel(const Context& ctx,
                      const DenseTensor& weight,
                      int64_t padding_idx,
                      DenseTensor* out) {
-  LookupTableV2CPUFunctor<T, Context> functor(
-      ctx, input, weight, padding_idx, out);
-  paddle::framework::VisitIntDataType(
-      paddle::framework::TransToProtoVarType(input.dtype()), functor);
+  EmbeddingCPUFunctor<T, Context> functor(ctx, input, weight, padding_idx, out);
+
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW("embedding input only supports int32 and int64");
+  }
 }
 
 }  // namespace phi
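As context for the functor being renamed here, the forward computation is a plain row gather with an optional zeroed padding row. A serial reference sketch under assumed row-major layouts (this is not the phi implementation):

#include <cstdint>
#include <cstring>
#include <vector>

// Serial reference of the lookup: out[i] = table[ids[i]] for K ids into an
// N x D row-major table; rows whose id equals padding_idx are zero-filled.
template <typename T, typename IdT>
void EmbeddingRef(T* out, const T* table, const IdT* ids,
                  int64_t K, int64_t D, int64_t padding_idx) {
  for (int64_t i = 0; i < K; ++i) {
    const int64_t id = static_cast<int64_t>(ids[i]);
    if (padding_idx != -1 && id == padding_idx) {
      std::memset(out + i * D, 0, sizeof(T) * D);  // all-zero bits == T(0) for float types
    } else {
      std::memcpy(out + i * D, table + id * D, sizeof(T) * D);
    }
  }
}

int main() {
  std::vector<float> table = {0.f, 0.f, 1.f, 1.f, 2.f, 2.f};  // N=3, D=2
  std::vector<int64_t> ids = {2, 0, 1};                       // K=3
  std::vector<float> out(ids.size() * 2);
  EmbeddingRef(out.data(), table.data(), ids.data(), 3, 2, /*padding_idx=*/-1);
}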
--- a/paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc
@@ -15,21 +15,21 @@
 #include "paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h"
 #include "paddle/phi/kernels/funcs/embedding_util.h"
 
-#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/framework/data_type.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
 
 namespace phi {
 
 template <typename T, typename Context>
-struct SparseWeightLookupTableV2GradCPUFunctor {
-  SparseWeightLookupTableV2GradCPUFunctor(const Context& dev_ctx,
+struct SparseWeightEmbeddingGradCPUFunctor {
+  SparseWeightEmbeddingGradCPUFunctor(const Context& dev_ctx,
                                       const DenseTensor& input,
                                       const SelectedRows& weight,
                                       const DenseTensor& out_grad,
                                       int64_t padding_idx,
                                       DenseTensor* weight_grad)
       : dev_ctx_(dev_ctx),
         input_(input),
         weight_(weight),
@@ -70,7 +70,8 @@ struct SparseWeightLookupTableV2GradCPUFunctor {
           ids_data[i],
           N,
           phi::errors::InvalidArgument(
-              "Variable value (input) of OP(fluid.layers.embedding) "
+              "Variable value (input) of "
+              "OP(paddle.nn.functional.embedding) "
              "expected >= 0 and < %ld, but got %ld. Please check input "
              "value.",
              N,
@@ -79,7 +80,8 @@ struct SparseWeightLookupTableV2GradCPUFunctor {
           ids_data[i],
           0,
           phi::errors::InvalidArgument(
-              "Variable value (input) of OP(fluid.layers.embedding) "
+              "Variable value (input) of "
+              "OP(paddle.nn.functional.embedding) "
              "expected >= 0 and < %ld, but got %ld. Please check input "
              "value.",
              N,
@@ -102,13 +104,13 @@ struct SparseWeightLookupTableV2GradCPUFunctor {
 };
 
 template <typename T, typename Context>
-struct SparseWeightLookupTableV2SparseGradCPUFunctor {
-  SparseWeightLookupTableV2SparseGradCPUFunctor(const Context& dev_ctx,
+struct SparseWeightEmbeddingSparseGradCPUFunctor {
+  SparseWeightEmbeddingSparseGradCPUFunctor(const Context& dev_ctx,
                                             const DenseTensor& input,
                                             const SelectedRows& weight,
                                             const DenseTensor& out_grad,
                                             int64_t padding_idx,
                                             SelectedRows* weight_grad)
       : dev_ctx_(dev_ctx),
         input_(input),
         weight_(weight),
@@ -132,7 +134,7 @@ struct SparseWeightLookupTableV2SparseGradCPUFunctor {
     auto* d_table_value = d_table->mutable_value();
     d_table_value->Resize({ids_num, table_dim[1]});
 
-    d_table_value->template mutable_data<T>(dev_ctx_.GetPlace());
+    dev_ctx_.template Alloc<T>(d_table_value);
 
     d_table->set_height(table_dim[0]);
 
@@ -170,10 +172,16 @@ void SparseWeightEmbeddingGradKernel(const Context& ctx,
                                      const DenseTensor& out_grad,
                                      int64_t padding_idx,
                                      DenseTensor* weight_grad) {
-  SparseWeightLookupTableV2GradCPUFunctor<T, Context> functor(
+  SparseWeightEmbeddingGradCPUFunctor<T, Context> functor(
      ctx, input, weight, out_grad, padding_idx, weight_grad);
-  paddle::framework::VisitIntDataType(
-      paddle::framework::TransToProtoVarType(input.dtype()), functor);
+
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW("embedding input only supports int32 and int64");
+  }
 }
 
 template <typename T, typename Context>
@@ -183,10 +191,16 @@ void SparseWeightEmbeddingSparseGradKernel(const Context& ctx,
                                            const DenseTensor& out_grad,
                                            int64_t padding_idx,
                                            SelectedRows* weight_grad) {
-  SparseWeightLookupTableV2SparseGradCPUFunctor<T, Context> functor(
+  SparseWeightEmbeddingSparseGradCPUFunctor<T, Context> functor(
      ctx, input, weight, out_grad, padding_idx, weight_grad);
-  paddle::framework::VisitIntDataType(
-      paddle::framework::TransToProtoVarType(input.dtype()), functor);
+
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW("embedding input only supports int32 and int64");
+  }
 }
 
 }  // namespace phi
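The Alloc change above sits in the sparse gradient path, which never materializes the dense N x D gradient: it stores the looked-up row ids, a compact {ids_num, width} value tensor, and the table height. A stand-in sketch of that representation (the SparseRowsGrad struct is hypothetical, not phi::SelectedRows):

#include <cstdint>
#include <cstring>
#include <vector>

// Minimal stand-in for a SelectedRows-style gradient: only the looked-up rows
// are materialized, as (row ids, compact value block, conceptual height).
struct SparseRowsGrad {
  std::vector<int64_t> rows;   // one entry per lookup; duplicates allowed
  std::vector<float> value;    // rows.size() x width, row-major
  int64_t height = 0;          // row count of the full dense table
};

// d_output has shape ids_num x width; each of its rows becomes one value row.
SparseRowsGrad EmbeddingSparseGradRef(const int64_t* ids, int64_t ids_num,
                                      const float* d_output, int64_t width,
                                      int64_t table_height) {
  SparseRowsGrad grad;
  grad.rows.assign(ids, ids + ids_num);
  grad.value.resize(ids_num * width);
  std::memcpy(grad.value.data(), d_output, sizeof(float) * ids_num * width);
  grad.height = table_height;
  return grad;
}

int main() {
  std::vector<int64_t> ids = {1, 3, 1};                 // duplicate id kept
  std::vector<float> d_out(3 * 2, 0.5f);                // ids_num=3, width=2
  SparseRowsGrad g = EmbeddingSparseGradRef(ids.data(), 3, d_out.data(), 2, 10);
  (void)g;
}

Duplicate ids stay as duplicate rows; merging them is typically left to whatever consumes the gradient.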
--- a/paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc
+++ b/paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc
@@ -15,21 +15,21 @@
 #include "paddle/phi/kernels/embedding_kernel.h"
 #include "paddle/phi/kernels/funcs/embedding_util.h"
 
-#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/framework/data_type.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 
 namespace phi {
 
 template <typename T, typename Context>
-struct LookupTableV2CPUSparseFunctor {
-  LookupTableV2CPUSparseFunctor(const Context& dev_ctx,
+struct EmbeddingCPUSparseFunctor {
+  EmbeddingCPUSparseFunctor(const Context& dev_ctx,
                             const DenseTensor& input,
                             const SelectedRows& weight,
                             int64_t padding_idx,
                             DenseTensor* out)
       : dev_ctx_(dev_ctx),
         input_(input),
         weight_(weight),
@@ -45,7 +45,7 @@ struct LookupTableV2CPUSparseFunctor {
     auto output_t = out_;
     int64_t row_width = table_t.value().dims()[1];
     const auto* table = table_t.value().template data<T>();
-    auto* output = output_t->template mutable_data<T>(dev_ctx_.GetPlace());
+    auto* output = dev_ctx_.template Alloc<T>(output_t);
     auto input_data_type =
         paddle::framework::TransToProtoVarType(table_t.value().dtype());
@@ -94,10 +94,16 @@ void SparseWeightEmbeddingKernel(const Context& ctx,
                                  const SelectedRows& weight,
                                  int64_t padding_idx,
                                  DenseTensor* out) {
-  LookupTableV2CPUSparseFunctor<T, Context> functor(
+  EmbeddingCPUSparseFunctor<T, Context> functor(
      ctx, input, weight, padding_idx, out);
-  paddle::framework::VisitIntDataType(
-      paddle::framework::TransToProtoVarType(input.dtype()), functor);
+
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW("embedding input only supports int32 and int64");
+  }
 }
 
 }  // namespace phi
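Here the weight itself is a SelectedRows, so an id first has to be translated from dense-table coordinates to its row in the compacted value block before the copy (the real code uses a SelectedRows index lookup plus a BLAS row copy; the hash map below is a hypothetical stand-in for that index):

#include <cstdint>
#include <cstring>
#include <unordered_map>
#include <vector>

// When the table is itself sparse, ids index the conceptual N x D table, so
// each id must be mapped to its row in the compacted value block first.
void SparseWeightLookupRef(const std::vector<int64_t>& table_rows,
                           const std::vector<float>& table_value,  // rows x D
                           const std::vector<int64_t>& ids, int64_t D,
                           std::vector<float>* out) {
  std::unordered_map<int64_t, int64_t> index;
  for (int64_t r = 0; r < static_cast<int64_t>(table_rows.size()); ++r) {
    index[table_rows[r]] = r;
  }
  out->resize(ids.size() * D);
  for (size_t i = 0; i < ids.size(); ++i) {
    const int64_t r = index.at(ids[i]);  // throws if the id has no row
    std::memcpy(out->data() + i * D, table_value.data() + r * D,
                sizeof(float) * D);
  }
}

int main() {
  std::vector<int64_t> rows = {7, 2};               // table holds rows 7 and 2
  std::vector<float> value = {7.f, 7.f, 2.f, 2.f};  // 2 x D, D=2
  std::vector<int64_t> ids = {2, 7};
  std::vector<float> out;
  SparseWeightLookupRef(rows, value, ids, 2, &out);  // out = {2,2,7,7}
}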
--- a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
@@ -15,9 +15,9 @@
 #include "paddle/phi/kernels/embedding_grad_kernel.h"
 #include "paddle/phi/kernels/funcs/embedding_util.h"
 
-#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/framework/data_type.h"
+#include "paddle/fluid/memory/memcpy.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
@@ -36,12 +36,12 @@ __global__ void InputTypeConvert(const InT* in_ids,
 }
 
 template <typename T, typename IdT>
-__global__ void LookupTableV2Grad(T* table,
-                                  const T* output,
-                                  const IdT* ids,
-                                  const int64_t N,
-                                  const int64_t K,
-                                  const int64_t D) {
+__global__ void EmbeddingGrad(T* table,
+                              const T* output,
+                              const IdT* ids,
+                              const int64_t N,
+                              const int64_t K,
+                              const int64_t D) {
   int idx = threadIdx.x;
   int idy = blockIdx.x + threadIdx.y * gridDim.x;
@@ -61,13 +61,13 @@ __global__ void LookupTableV2Grad(T* table,
 }
 
 template <typename T, typename Context>
-struct LookupTableV2GradCUDAFunctor {
-  LookupTableV2GradCUDAFunctor(const Context& dev_ctx,
+struct EmbeddingGradCUDAFunctor {
+  EmbeddingGradCUDAFunctor(const Context& dev_ctx,
                            const DenseTensor& input,
                            const DenseTensor& weight,
                            const DenseTensor& out_grad,
                            int64_t padding_idx,
                            DenseTensor* weight_grad)
       : dev_ctx_(dev_ctx),
         input_(input),
         weight_(weight),
@@ -89,7 +89,7 @@ struct LookupTableV2GradCUDAFunctor {
     const T* d_output = d_output_t.template data<T>();
     const auto* ids = input_.template data<IdT>();
-    T* d_table = d_table_t->mutable_data<T>(dev_ctx_.GetPlace());
+    T* d_table = dev_ctx_.template Alloc<T>(d_table_t);
 
 #ifdef PADDLE_WITH_HIP
     PADDLE_ENFORCE_GPU_SUCCESS(
@@ -102,7 +102,7 @@ struct LookupTableV2GradCUDAFunctor {
       const int gridx = 2 * dev_ctx_.GetSMCount();
       dim3 threads(128, 8);
       dim3 grids(gridx, 1);
-      LookupTableV2Grad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
+      EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
          d_table, d_output, ids, N, K, D);
     }
   }
@@ -123,20 +123,26 @@ void EmbeddingGradKernel(const Context& ctx,
                         const DenseTensor& out_grad,
                         int64_t padding_idx,
                         DenseTensor* weight_grad) {
-  LookupTableV2GradCUDAFunctor<T, Context> functor(
+  EmbeddingGradCUDAFunctor<T, Context> functor(
      ctx, input, weight, out_grad, padding_idx, weight_grad);
-  paddle::framework::VisitIntDataType(
-      paddle::framework::TransToProtoVarType(input.dtype()), functor);
+
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW("embedding input only supports int32 and int64");
+  }
 }
 
 template <typename T, typename Context>
-struct LookupTableV2SparseGradCUDAFunctor {
-  LookupTableV2SparseGradCUDAFunctor(const Context& dev_ctx,
+struct EmbeddingSparseGradCUDAFunctor {
+  EmbeddingSparseGradCUDAFunctor(const Context& dev_ctx,
                                  const DenseTensor& input,
                                  const DenseTensor& weight,
                                  const DenseTensor& out_grad,
                                  int64_t padding_idx,
                                  SelectedRows* weight_grad)
       : dev_ctx_(dev_ctx),
         input_(input),
         weight_(weight),
@@ -179,7 +185,7 @@ struct LookupTableV2SparseGradCUDAFunctor {
     auto* d_table_value = d_table->mutable_value();
     d_table_value->Resize({ids_num, table->dims()[1]});
-    d_table_value->template mutable_data<T>(gpu_place);
+    dev_ctx_.template Alloc<T>(d_table_value);
 
     auto* d_table_data = d_table_value->template data<T>();
     auto* d_output_data = d_output->template data<T>();
@@ -219,10 +225,16 @@ void EmbeddingSparseGradKernel(const Context& ctx,
                               const DenseTensor& out_grad,
                               int64_t padding_idx,
                               SelectedRows* weight_grad) {
-  LookupTableV2SparseGradCUDAFunctor<T, Context> functor(
+  EmbeddingSparseGradCUDAFunctor<T, Context> functor(
      ctx, input, weight, out_grad, padding_idx, weight_grad);
-  paddle::framework::VisitIntDataType(
-      paddle::framework::TransToProtoVarType(input.dtype()), functor);
+
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW("embedding input only supports int32 and int64");
+  }
 }
 
 }  // namespace phi
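The renamed EmbeddingGrad kernel is a parallel scatter-add of out_grad rows into the zero-initialized table gradient (the memset guarded by PADDLE_WITH_HIP above does the zeroing), and duplicate ids are why the GPU version has to accumulate atomically. A serial reference sketch of the same computation (assumed shapes, not the CUDA code):

#include <algorithm>
#include <cstdint>
#include <vector>

// Serial reference of the dense embedding gradient: zero the N x D table
// gradient, then scatter-add each row of d_output (K x D) into the row
// picked by ids[i]. In a parallel version several threads can hit the same
// table row when ids repeat, hence the atomic adds on the GPU.
template <typename T, typename IdT>
void EmbeddingGradRef(T* d_table, const T* d_output, const IdT* ids,
                      int64_t N, int64_t K, int64_t D) {
  std::fill(d_table, d_table + N * D, T(0));
  for (int64_t i = 0; i < K; ++i) {
    const int64_t id = static_cast<int64_t>(ids[i]);
    for (int64_t j = 0; j < D; ++j) {
      d_table[id * D + j] += d_output[i * D + j];
    }
  }
}

int main() {
  std::vector<float> d_table(3 * 2);                // N=3, D=2
  std::vector<float> d_out = {1.f, 1.f, 2.f, 2.f};  // K=2
  std::vector<int64_t> ids = {1, 1};                // duplicate id
  EmbeddingGradRef(d_table.data(), d_out.data(), ids.data(), 3, 2, 2);
  // d_table row 1 is now {3, 3}
}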
--- a/paddle/phi/kernels/gpu/embedding_kernel.cu
+++ b/paddle/phi/kernels/gpu/embedding_kernel.cu
@@ -15,22 +15,21 @@
 #include "paddle/phi/kernels/embedding_kernel.h"
 #include "paddle/phi/kernels/funcs/embedding_util.h"
 
-#include "paddle/fluid/framework/convert_utils.h"
-#include "paddle/fluid/framework/data_type.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
 
 namespace phi {
 
 template <typename T, typename IdT, bool PaddingFlag>
-__global__ void LookupTableV2(T *output,
-                              const T *table,
-                              const IdT *ids,
-                              const int64_t N,
-                              const int64_t K,
-                              const int64_t D,
-                              const int64_t padding_idx) {
+__global__ void EmbeddingFW(T *output,
+                            const T *table,
+                            const IdT *ids,
+                            const int64_t N,
+                            const int64_t K,
+                            const int64_t D,
+                            const int64_t padding_idx) {
   int idx = threadIdx.x;
   int idy = blockIdx.x + threadIdx.y * gridDim.x;
@@ -53,12 +52,12 @@ __global__ void LookupTableV2(T *output,
 }
 
 template <typename T, typename Context>
-struct LookupTableV2CUDAFunctor {
-  LookupTableV2CUDAFunctor(const Context &dev_ctx,
+struct EmbeddingCUDAFunctor {
+  EmbeddingCUDAFunctor(const Context &dev_ctx,
                        const DenseTensor &input,
                        const DenseTensor &weight,
                        int64_t padding_idx,
                        DenseTensor *out)
       : dev_ctx_(dev_ctx),
         input_(input),
         weight_(weight),
@@ -77,14 +76,14 @@ struct LookupTableV2CUDAFunctor {
     const T *table = weight_.template data<T>();
     const IdT *ids = input_.template data<IdT>();
-    auto *output = out_->template mutable_data<T>(dev_ctx_.GetPlace());
+    auto *output = dev_ctx_.template Alloc<T>(out_);
     auto stream = dev_ctx_.stream();
 
     if (padding_idx_ == -1) {
-      LookupTableV2<T, IdT, false><<<grids, threads, 0, stream>>>(
+      EmbeddingFW<T, IdT, false><<<grids, threads, 0, stream>>>(
          output, table, ids, N, K, D, padding_idx_);
     } else {
-      LookupTableV2<T, IdT, true><<<grids, threads, 0, stream>>>(
+      EmbeddingFW<T, IdT, true><<<grids, threads, 0, stream>>>(
          output, table, ids, N, K, D, padding_idx_);
     }
   }
@@ -103,10 +102,16 @@ void EmbeddingKernel(const Context &ctx,
                     const DenseTensor &weight,
                     int64_t padding_idx,
                     DenseTensor *out) {
-  LookupTableV2CUDAFunctor<T, Context> functor(
+  EmbeddingCUDAFunctor<T, Context> functor(
      ctx, input, weight, padding_idx, out);
-  paddle::framework::VisitIntDataType(
-      paddle::framework::TransToProtoVarType(input.dtype()), functor);
+
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int32_t>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW("embedding input only supports int32 and int64");
+  }
 }
 
 }  // namespace phi
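A detail worth noting in EmbeddingFW: the padding check rides on a bool PaddingFlag template parameter, so the launch site picks one of two instantiations and the no-padding path carries no per-element branch. The same trick in miniature, as a hedged host-side sketch (LookupRow/Lookup are illustrative names):

#include <cstdint>

// Lifting a runtime flag to a template parameter lets the padding check
// compile away entirely in the no-padding instantiation instead of running
// once per element.
template <bool PaddingFlag>
void LookupRow(float* out, const float* table, int64_t id, int64_t D,
               int64_t padding_idx) {
  if constexpr (PaddingFlag) {
    if (id == padding_idx) {
      for (int64_t j = 0; j < D; ++j) out[j] = 0.0f;
      return;
    }
  }
  for (int64_t j = 0; j < D; ++j) out[j] = table[id * D + j];
}

// The caller resolves the flag once, exactly like the two launch branches:
void Lookup(float* out, const float* table, int64_t id, int64_t D,
            int64_t padding_idx) {
  if (padding_idx == -1) {
    LookupRow<false>(out, table, id, D, padding_idx);
  } else {
    LookupRow<true>(out, table, id, D, padding_idx);
  }
}

int main() {
  float table[4] = {1.f, 2.f, 3.f, 4.f};  // N=2, D=2
  float out[2];
  Lookup(out, table, /*id=*/1, /*D=*/2, /*padding_idx=*/-1);  // out = {3, 4}
}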