Commit 867fc053 authored by phlrain

polish code
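
This change renames the legacy `LookupTableV2*` functors and kernels in phi to `Embedding*` names, replaces the fluid-side `paddle::framework::VisitIntDataType` / `TransToProtoVarType` dispatch with an explicit branch on `input.dtype()`, migrates deprecated `mutable_data<T>(place)` allocations to `dev_ctx.Alloc<T>(tensor)`, and updates error messages to reference `paddle.nn.functional.embedding` rather than the retired `fluid.layers.embedding` API.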

Parent 7ba14d74
@@ -15,16 +15,15 @@
#include "paddle/phi/kernels/embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
struct LookupTableV2GradCPUFunctor {
LookupTableV2GradCPUFunctor(const Context& dev_ctx,
struct EmbeddingGradCPUFunctor {
EmbeddingGradCPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
@@ -48,7 +47,6 @@ struct LookupTableV2GradCPUFunctor {
// paddings makes no sense and we don't deal with it in backward.
{
auto* d_output = &out_grad_;
// auto d_table = weight_grad_;
auto* ids_data = ids.data();
int64_t N = table_dim[0];
@@ -70,7 +68,8 @@ struct LookupTableV2GradCPUFunctor {
ids_data[i],
N,
phi::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) "
"Variable value (input) of "
"OP(paddle.nn.functional.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
N,
@@ -79,7 +78,8 @@ struct LookupTableV2GradCPUFunctor {
ids_data[i],
0,
phi::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) "
"Variable value (input) of "
"OP(paddle.nn.functional.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
N,
@@ -108,15 +108,20 @@ void EmbeddingGradKernel(const Context& ctx,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad) {
LookupTableV2GradCPUFunctor<T, Context> functor(
EmbeddingGradCPUFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad);
paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW("emebdding input only support int32 and int64");
}
}
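
This explicit dtype dispatch recurs in every kernel in this commit, so here is a minimal standalone sketch of the pattern. `DataType` and the functor below are mock stand-ins used for illustration; the real types live in phi:

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>

// Mock stand-in for phi::DataType.
enum class DataType { INT32, INT64, FLOAT32 };

struct EmbeddingGradCPUFunctorSketch {
  // In the real functor, apply<IdT>() reads the ids tensor as IdT and
  // scatters rows of out_grad back into weight_grad.
  template <typename IdT>
  void apply() const {
    std::cout << "dispatched with " << sizeof(IdT) << "-byte indices\n";
  }
};

void EmbeddingGradKernelSketch(DataType index_dtype) {
  EmbeddingGradCPUFunctorSketch functor;
  // Explicit dispatch on the index dtype, replacing VisitIntDataType.
  if (index_dtype == DataType::INT32) {
    functor.template apply<int32_t>();
  } else if (index_dtype == DataType::INT64) {
    functor.template apply<int64_t>();
  } else {
    throw std::invalid_argument(
        "embedding input only supports int32 and int64");
  }
}

int main() { EmbeddingGradKernelSketch(DataType::INT64); }
```

The gain over `VisitIntDataType` is that the kernel no longer pulls in fluid's proto data-type machinery; the cost is the repeated boilerplate visible in each kernel below.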
template <typename T, typename Context>
struct LookupTableV2SparseGradCPUFunctor {
LookupTableV2SparseGradCPUFunctor(const Context& dev_ctx,
struct EmbeddingSparseGradCPUFunctor {
EmbeddingSparseGradCPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
@@ -145,7 +150,7 @@ struct LookupTableV2SparseGradCPUFunctor {
auto* d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, table_dim[1]});
d_table_value->template mutable_data<T>(dev_ctx_.GetPlace());
dev_ctx_.template Alloc<T>(d_table_value);
d_table->set_height(table_dim[0]);
@@ -183,10 +188,15 @@ void EmbeddingSparseGradKernel(const Context& ctx,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad) {
LookupTableV2SparseGradCPUFunctor<T, Context> functor(
EmbeddingSparseGradCPUFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad);
paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW("emebdding input only support int32 and int64");
}
}
} // namespace phi
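
The sparse gradient path above only materializes the rows actually referenced by `ids`. A rough standalone illustration of that idea, with a mock SelectedRows (the real `phi::SelectedRows` API with `SetRows`, `mutable_value`, and `set_height` differs):

```cpp
#include <cstdint>
#include <vector>

struct MockSelectedRows {
  std::vector<int64_t> rows;  // which table rows the value rows map to
  std::vector<float> value;   // ids_num x width buffer, like mutable_value()
  int64_t height = 0;         // logical table height, like set_height()
};

void SparseEmbeddingGradSketch(const std::vector<int64_t>& ids,
                               const std::vector<float>& out_grad,
                               int64_t table_height, int64_t width,
                               MockSelectedRows* d_table) {
  const int64_t ids_num = static_cast<int64_t>(ids.size());
  d_table->rows = ids;                           // one entry per lookup
  d_table->value.assign(ids_num * width, 0.0f);  // Resize + Alloc<T>
  d_table->height = table_height;
  // The gradient of a gather is a copy: output-grad row i belongs to table
  // row ids[i]; duplicate ids are typically merged downstream.
  for (int64_t i = 0; i < ids_num; ++i) {
    for (int64_t j = 0; j < width; ++j) {
      d_table->value[i * width + j] = out_grad[i * width + j];
    }
  }
}
```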
@@ -15,16 +15,16 @@
#include "paddle/phi/kernels/embedding_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
template <typename T, typename Context>
struct LookupTableV2CPUFunctor {
LookupTableV2CPUFunctor(const Context& dev_ctx,
struct EmbeddingCPUFunctor {
EmbeddingCPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& weight,
int64_t padding_idx,
@@ -91,10 +91,15 @@ void EmbeddingKernel(const Context& ctx,
const DenseTensor& weight,
int64_t padding_idx,
DenseTensor* out) {
LookupTableV2CPUFunctor<T, Context> functor(
ctx, input, weight, padding_idx, out);
paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor);
EmbeddingCPUFunctor<T, Context> functor(ctx, input, weight, padding_idx, out);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW("emebdding input only support int32 and int64");
}
}
} // namespace phi
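
For reference, the lookup that `EmbeddingCPUFunctor` performs is essentially a row gather with an optional padding row. A simplified standalone sketch (the real functor operates on `DenseTensor` and also bounds-checks `ids`):

```cpp
#include <cstdint>
#include <cstring>

template <typename T, typename IdT>
void EmbeddingForwardSketch(const T* table, const IdT* ids, int64_t ids_num,
                            int64_t width, int64_t padding_idx, T* out) {
  for (int64_t i = 0; i < ids_num; ++i) {
    const int64_t id = static_cast<int64_t>(ids[i]);
    if (id == padding_idx) {
      // Rows that hit padding_idx are emitted as zeros.
      std::memset(out + i * width, 0, sizeof(T) * width);
    } else {
      std::memcpy(out + i * width, table + id * width, sizeof(T) * width);
    }
  }
}
```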
@@ -15,16 +15,16 @@
#include "paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
template <typename T, typename Context>
struct SparseWeightLookupTableV2GradCPUFunctor {
SparseWeightLookupTableV2GradCPUFunctor(const Context& dev_ctx,
struct SparseWeightEmbeddingGradCPUFunctor {
SparseWeightEmbeddingGradCPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const SelectedRows& weight,
const DenseTensor& out_grad,
@@ -70,7 +70,8 @@ struct SparseWeightLookupTableV2GradCPUFunctor {
ids_data[i],
N,
phi::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) "
"Variable value (input) of "
"OP(paddle.nn.functional.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
N,
@@ -79,7 +80,8 @@ struct SparseWeightLookupTableV2GradCPUFunctor {
ids_data[i],
0,
phi::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) "
"Variable value (input) of "
"OP(paddle.nn.functional.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
N,
@@ -102,8 +104,8 @@ struct SparseWeightLookupTableV2GradCPUFunctor {
};
template <typename T, typename Context>
struct SparseWeightLookupTableV2SparseGradCPUFunctor {
SparseWeightLookupTableV2SparseGradCPUFunctor(const Context& dev_ctx,
struct SparseWeightEmbeddingSparseGradCPUFunctor {
SparseWeightEmbeddingSparseGradCPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const SelectedRows& weight,
const DenseTensor& out_grad,
@@ -132,7 +134,7 @@ struct SparseWeightLookupTableV2SparseGradCPUFunctor {
auto* d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, table_dim[1]});
d_table_value->template mutable_data<T>(dev_ctx_.GetPlace());
dev_ctx_.template Alloc<T>(d_table_value);
d_table->set_height(table_dim[0]);
@@ -170,10 +172,16 @@ void SparseWeightEmbeddingGradKernel(const Context& ctx,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad) {
SparseWeightLookupTableV2GradCPUFunctor<T, Context> functor(
SparseWeightEmbeddingGradCPUFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad);
paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW("emebdding input only support int32 and int64");
}
}
template <typename T, typename Context>
@@ -183,10 +191,16 @@ void SparseWeightEmbeddingSparseGradKernel(const Context& ctx,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad) {
SparseWeightLookupTableV2SparseGradCPUFunctor<T, Context> functor(
SparseWeightEmbeddingSparseGradCPUFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad);
paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW("emebdding input only support int32 and int64");
}
}
} // namespace phi
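
Both dense and sparse gradient functors guard the table access with the `PADDLE_ENFORCE_LT` / `PADDLE_ENFORCE_GE` pair shown in the hunks above. Reduced to plain C++, the check amounts to the following; `CheckEmbeddingId` is a hypothetical helper for illustration, as the kernels inline the checks:

```cpp
#include <cstdint>
#include <sstream>
#include <stdexcept>

inline void CheckEmbeddingId(int64_t id, int64_t height) {
  if (id < 0 || id >= height) {
    // Mirrors the InvalidArgument message raised by the kernels.
    std::ostringstream msg;
    msg << "Variable value (input) of OP(paddle.nn.functional.embedding) "
           "expected >= 0 and < "
        << height << ", but got " << id << ". Please check input value.";
    throw std::out_of_range(msg.str());
  }
}
```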
@@ -15,17 +15,17 @@
#include "paddle/phi/kernels/embedding_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace phi {
template <typename T, typename Context>
struct LookupTableV2CPUSparseFunctor {
LookupTableV2CPUSparseFunctor(const Context& dev_ctx,
struct EmbeddingCPUSparseFunctor {
EmbeddingCPUSparseFunctor(const Context& dev_ctx,
const DenseTensor& input,
const SelectedRows& weight,
int64_t padding_idx,
@@ -45,7 +45,7 @@ struct LookupTableV2CPUSparseFunctor {
auto output_t = out_;
int64_t row_width = table_t.value().dims()[1];
const auto* table = table_t.value().template data<T>();
auto* output = output_t->template mutable_data<T>(dev_ctx_.GetPlace());
auto* output = dev_ctx_.template Alloc<T>(output_t);
auto input_data_type =
paddle::framework::TransToProtoVarType(table_t.value().dtype());
@@ -94,10 +94,16 @@ void SparseWeightEmbeddingKernel(const Context& ctx,
const SelectedRows& weight,
int64_t padding_idx,
DenseTensor* out) {
LookupTableV2CPUSparseFunctor<T, Context> functor(
EmbeddingCPUSparseFunctor<T, Context> functor(
ctx, input, weight, padding_idx, out);
paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW("emebdding input only support int32 and int64");
}
}
} // namespace phi
@@ -15,9 +15,9 @@
#include "paddle/phi/kernels/embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
@@ -36,7 +36,7 @@ __global__ void InputTypeConvert(const InT* in_ids,
}
template <typename T, typename IdT>
__global__ void LookupTableV2Grad(T* table,
__global__ void EmbeddingGrad(T* table,
const T* output,
const IdT* ids,
const int64_t N,
@@ -61,8 +61,8 @@ __global__ void LookupTableV2Grad(T* table,
}
template <typename T, typename Context>
struct LookupTableV2GradCUDAFunctor {
LookupTableV2GradCUDAFunctor(const Context& dev_ctx,
struct EmbeddingGradCUDAFunctor {
EmbeddingGradCUDAFunctor(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
@@ -89,7 +89,7 @@ struct LookupTableV2GradCUDAFunctor {
const T* d_output = d_output_t.template data<T>();
const auto* ids = input_.template data<IdT>();
T* d_table = d_table_t->mutable_data<T>(dev_ctx_.GetPlace());
T* d_table = dev_ctx_.template Alloc<T>(d_table_t);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(
@@ -102,7 +102,7 @@ struct LookupTableV2GradCUDAFunctor {
const int gridx = 2 * dev_ctx_.GetSMCount();
dim3 threads(128, 8);
dim3 grids(gridx, 1);
LookupTableV2Grad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
d_table, d_output, ids, N, K, D);
}
}
@@ -123,15 +123,21 @@ void EmbeddingGradKernel(const Context& ctx,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad) {
LookupTableV2GradCUDAFunctor<T, Context> functor(
EmbeddingGradCUDAFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad);
paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW("emebdding input only support int32 and int64");
}
}
template <typename T, typename Context>
struct LookupTableV2SparseGradCUDAFunctor {
LookupTableV2SparseGradCUDAFunctor(const Context& dev_ctx,
struct EmbeddingSparseGradCUDAFunctor {
EmbeddingSparseGradCUDAFunctor(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
@@ -179,7 +185,7 @@ struct LookupTableV2SparseGradCUDAFunctor {
auto* d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, table->dims()[1]});
d_table_value->template mutable_data<T>(gpu_place);
dev_ctx_.template Alloc<T>(d_table_value);
auto* d_table_data = d_table_value->template data<T>();
auto* d_output_data = d_output->template data<T>();
@@ -219,10 +225,16 @@ void EmbeddingSparseGradKernel(const Context& ctx,
const DenseTensor& out_grad,
int64_t padding_idx,
SelectedRows* weight_grad) {
LookupTableV2SparseGradCUDAFunctor<T, Context> functor(
EmbeddingSparseGradCUDAFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad);
paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW("emebdding input only support int32 and int64");
}
}
} // namespace phi
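
The diff elides most of the `EmbeddingGrad` kernel body. Below is a plausible reconstruction consistent with the visible signature and the `dim3 threads(128, 8)` / `2 * GetSMCount()` launch; treat it as an assumption about the body, not the verbatim source. The key detail is `atomicAdd`, since duplicate ids write to the same table row concurrently:

```cuda
#include <cstdint>

template <typename IdT>
__global__ void EmbeddingGradSketch(float* table, const float* output,
                                    const IdT* ids, const int64_t K,
                                    const int64_t D) {
  const int idx = threadIdx.x;
  // Each (blockIdx.x, threadIdx.y) pair strides over the K lookups.
  int idy = blockIdx.x + threadIdx.y * gridDim.x;
  while (idy < K) {
    const int64_t id = static_cast<int64_t>(ids[idy]);
    const float* out = output + idy * D;
    float* tab = table + id * D;
    // Threads along x cooperatively accumulate the D-wide gradient row.
    for (int i = idx; i < D; i += blockDim.x) {
      atomicAdd(&tab[i], out[i]);
    }
    idy += blockDim.y * gridDim.x;
  }
}

// Matching the launch in the functor:
//   dim3 threads(128, 8);
//   dim3 grids(2 * sm_count, 1);
//   EmbeddingGradSketch<<<grids, threads, 0, stream>>>(d_table, d_output,
//                                                      ids, K, D);
```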
@@ -15,16 +15,15 @@
#include "paddle/phi/kernels/embedding_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
namespace phi {
template <typename T, typename IdT, bool PaddingFlag>
__global__ void LookupTableV2(T *output,
__global__ void EmbeddingFW(T *output,
const T *table,
const IdT *ids,
const int64_t N,
@@ -53,8 +52,8 @@ __global__ void LookupTableV2(T *output,
}
template <typename T, typename Context>
struct LookupTableV2CUDAFunctor {
LookupTableV2CUDAFunctor(const Context &dev_ctx,
struct EmbeddingCUDAFunctor {
EmbeddingCUDAFunctor(const Context &dev_ctx,
const DenseTensor &input,
const DenseTensor &weight,
int64_t padding_idx,
@@ -77,14 +76,14 @@ struct LookupTableV2CUDAFunctor {
const T *table = weight_.template data<T>();
const IdT *ids = input_.template data<IdT>();
auto *output = out_->template mutable_data<T>(dev_ctx_.GetPlace());
auto *output = dev_ctx_.template Alloc<T>(out_);
auto stream = dev_ctx_.stream();
if (padding_idx_ == -1) {
LookupTableV2<T, IdT, false><<<grids, threads, 0, stream>>>(
EmbeddingFW<T, IdT, false><<<grids, threads, 0, stream>>>(
output, table, ids, N, K, D, padding_idx_);
} else {
LookupTableV2<T, IdT, true><<<grids, threads, 0, stream>>>(
EmbeddingFW<T, IdT, true><<<grids, threads, 0, stream>>>(
output, table, ids, N, K, D, padding_idx_);
}
}
@@ -103,10 +102,16 @@ void EmbeddingKernel(const Context &ctx,
const DenseTensor &weight,
int64_t padding_idx,
DenseTensor *out) {
LookupTableV2CUDAFunctor<T, Context> functor(
EmbeddingCUDAFunctor<T, Context> functor(
ctx, input, weight, padding_idx, out);
paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor);
if (input.dtype() == phi::DataType::INT32) {
functor.template apply<int32_t>();
} else if (input.dtype() == phi::DataType::INT64) {
functor.template apply<int64_t>();
} else {
PADDLE_THROW("emebdding input only support int32 and int64");
}
}
} // namespace phi
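
`EmbeddingFW` is specialized at compile time on `PaddingFlag`, and the host side picks the instantiation once based on `padding_idx == -1`, so the no-padding variant carries no padding comparison in its hot loop. The kernel body is elided by the diff; the sketch below is a plausible shape for it under that assumption:

```cuda
#include <cstdint>

template <typename T, typename IdT, bool PaddingFlag>
__global__ void EmbeddingFWSketch(T* output, const T* table, const IdT* ids,
                                  const int64_t K, const int64_t D,
                                  const int64_t padding_idx) {
  const int idx = threadIdx.x;
  int idy = blockIdx.x + threadIdx.y * gridDim.x;
  while (idy < K) {
    const int64_t id = static_cast<int64_t>(ids[idy]);
    T* out = output + idy * D;
    const T* tab = table + id * D;
    for (int i = idx; i < D; i += blockDim.x) {
      // PaddingFlag is a compile-time constant, so this branch folds away
      // in the <T, IdT, false> instantiation.
      if (PaddingFlag && id == padding_idx) {
        out[i] = static_cast<T>(0);
      } else {
        out[i] = tab[i];
      }
    }
    idy += blockDim.y * gridDim.x;
  }
}
```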