未验证 提交 7c9abfb2 编写于 作者: R Ruibin Cheung 提交者: GitHub

[Fluid] NO.1 Migrate c_embedding to PHI (#56129)

* [Fluid] Migrate c_embedding to PHI

* fix

* add python_api

* fix ut

* migrate xpu kernel

* fix windows compile error
上级 14b1374f
......@@ -183,18 +183,3 @@ REGISTER_OPERATOR(c_embedding_grad,
ops::CEmbeddingOpGrad,
ops::CEmbeddingGradOpNoBufferVarsInferer,
ops::CEmbeddingOpGradVarTypeInference);
PD_REGISTER_STRUCT_KERNEL(c_embedding,
CPU,
ALL_LAYOUT,
ops::CEmbeddingOpCPUKernel,
float,
double,
plat::float16) {}
PD_REGISTER_STRUCT_KERNEL(c_embedding_grad,
CPU,
ALL_LAYOUT,
ops::CEmbeddingGradOpCPUKernel,
float,
double,
plat::float16) {}
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_embedding_op.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/funcs/embedding_grad.h"
DECLARE_int64(embedding_deterministic);
namespace paddle {
namespace operators {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
template <typename T, typename IndexT>
__global__ void CEmbedding(T *out,
const T *table,
const IndexT *ids,
const int rows,
const int columns,
const int64_t N,
const int64_t start_idx,
const int64_t end_idx,
const int64_t limit) {
CUDA_KERNEL_LOOP(i, limit) {
size_t row = i / columns;
size_t col = i % columns;
auto id = ids[row];
if (id >= start_idx && id < end_idx) {
auto real_idx = id - start_idx;
PADDLE_ENFORCE(real_idx < N,
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be less than [%d], but received [%d]",
N,
real_idx);
out[i] = table[real_idx * columns + col];
} else {
out[i] = static_cast<T>(0);
}
}
}
template <typename T, typename IndexT>
__global__ void CEmbeddingGrad(T *table,
const T *output,
const IndexT *ids,
const int rows,
const int columns,
const int64_t N,
const int64_t start_idx,
const int64_t end_idx,
const int64_t limit) {
CUDA_KERNEL_LOOP(i, limit) {
size_t row = i / columns;
size_t col = i % columns;
auto id = ids[row];
if (id >= start_idx && id < end_idx) {
auto real_idx = id - start_idx;
phi::CudaAtomicAdd(&table[real_idx * columns + col], output[i]);
}
}
}
template <typename T, typename DeviceContext>
class CEmbeddingCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *table_t = context.Input<phi::DenseTensor>("W");
auto *ids_t = context.Input<phi::DenseTensor>("Ids");
auto *output_t = context.Output<phi::DenseTensor>("Out");
const auto &dev_ctx = context.template device_context<phi::GPUContext>();
const int64_t start_idx = context.Attr<int64_t>("start_index");
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
size_t K = ids_t->numel();
const int64_t end_idx = start_idx + N;
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
auto limit = K * D;
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;
const auto &index_type = framework::TransToProtoVarType(ids_t->dtype());
if (index_type == framework::proto::VarType::INT32) {
CEmbedding<T, int32_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(output,
table,
ids_t->data<int32_t>(),
K,
D,
N,
start_idx,
end_idx,
limit);
} else if (index_type == framework::proto::VarType::INT64) {
CEmbedding<T, int64_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(output,
table,
ids_t->data<int64_t>(),
K,
D,
N,
start_idx,
end_idx,
limit);
} else {
PADDLE_THROW(platform::errors::Unavailable(
"GPU c_embedding ids only support int32 or int64."));
}
}
};
template <typename T, typename DeviceContext>
class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const auto &dev_ctx = context.template device_context<phi::GPUContext>();
const int64_t start_idx = context.Attr<int64_t>("start_index");
auto ids_t = context.Input<phi::DenseTensor>("Ids");
auto d_output_t =
context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto d_table_t =
context.Output<phi::DenseTensor>(framework::GradVarName("W"));
int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1];
int K = ids_t->numel();
auto limit = K * D;
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;
const T *d_output = d_output_t->data<T>();
T *d_table = d_table_t->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*d_table_t);
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
const auto &index_type = framework::TransToProtoVarType(ids_t->dtype());
if (FLAGS_embedding_deterministic == 1) {
if (index_type == framework::proto::VarType::INT32) {
phi::funcs::LaunchEmbeddingGradDeterministicKernel<T, int32_t>(
dev_ctx,
ids_t->data<int32_t>(),
d_output,
d_table,
N,
D,
K,
start_idx);
return;
} else if (index_type == framework::proto::VarType::INT64) {
phi::funcs::LaunchEmbeddingGradDeterministicKernel<T, int64_t>(
dev_ctx,
ids_t->data<int64_t>(),
d_output,
d_table,
N,
D,
K,
start_idx);
return;
}
} else {
if (FLAGS_embedding_deterministic > 1) {
VLOG(2) << "Run grad kernel of embedding with single thread.";
blocks = 1;
}
const int64_t end_idx = start_idx + N;
if (index_type == framework::proto::VarType::INT32) {
CEmbeddingGrad<T, int32_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
d_output,
ids_t->data<int32_t>(),
K,
D,
N,
start_idx,
end_idx,
limit);
return;
} else if (index_type == framework::proto::VarType::INT64) {
CEmbeddingGrad<T, int64_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
d_output,
ids_t->data<int64_t>(),
K,
D,
N,
start_idx,
end_idx,
limit);
return;
}
}
PADDLE_THROW(phi::errors::InvalidArgument(
"The data type of Input(Ids) must be int32 or int64."));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
PD_REGISTER_STRUCT_KERNEL(c_embedding,
GPU,
ALL_LAYOUT,
ops::CEmbeddingCUDAKernel,
float,
double,
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16,
#endif
plat::float16) {
}
PD_REGISTER_STRUCT_KERNEL(c_embedding_grad,
GPU,
ALL_LAYOUT,
ops::CEmbeddingGradCUDAKernel,
float,
double,
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
plat::bfloat16,
#endif
plat::float16) {
}
......@@ -25,148 +25,16 @@ limitations under the License. */
namespace paddle {
namespace operators {
inline void CheckTableValid() {}
template <typename TIds, typename TData>
void GetIdsEmbedding(const TIds* ids,
size_t ids_len,
int64_t start_idx,
const TData* table,
int64_t height,
int64_t width,
TData* out) {
for (size_t i = 0; i < ids_len; i++) {
TIds id = ids[i];
int64_t local = id - start_idx;
if (local >= 0 && local < height) {
// for (int64_t w = 0; w < width; w++) {
// out[i * width + w] = table[local * width + w];
// }
memcpy(out + i * width, table + local * width, width * sizeof(TData));
} else {
memset(out + i * width, 0, width * sizeof(TData));
}
}
}
template <typename T, typename DeviceContext>
class CEmbeddingOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* table_t = ctx.Input<phi::DenseTensor>("W");
auto* ids_t = ctx.Input<phi::DenseTensor>("Ids");
auto* output_t = ctx.Output<phi::DenseTensor>("Out");
const int64_t start_idx = ctx.Attr<int64_t>("start_index");
VLOG(10) << "table_dims:" << table_t->dims();
const T* table_data = table_t->data<T>();
T* output_data = output_t->mutable_data<T>(ctx.GetPlace());
const int64_t height = table_t->dims()[0];
const int64_t width = table_t->dims()[1];
const auto& index_type = framework::TransToProtoVarType(ids_t->dtype());
if (index_type == framework::proto::VarType::INT32) {
GetIdsEmbedding(ids_t->data<int32_t>(),
ids_t->numel(),
start_idx,
table_data,
height,
width,
output_data);
} else if (index_type == framework::proto::VarType::INT64) {
GetIdsEmbedding(ids_t->data<int64_t>(),
ids_t->numel(),
start_idx,
table_data,
height,
width,
output_data);
} else {
PADDLE_THROW(platform::errors::Unavailable(
"CPU c_embedding ids only support int32 or int64."));
}
}
void Compute(const framework::ExecutionContext& ctx) const override {}
};
template <typename TIds, typename TData>
void UpdateEmbedding(const TIds* ids,
size_t ids_len,
int64_t start_idx,
TData* table,
int64_t height,
int64_t width,
const TData* out) {
for (size_t i = 0; i < ids_len; i++) {
TIds id = ids[i];
int64_t local = id - start_idx;
if (local >= 0 && local < height) {
for (int64_t w = 0; w < width; w++) {
table[local * width + w] += out[i * width + w];
}
}
}
}
template <typename T, typename DeviceContext>
class CEmbeddingGradOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const int64_t start_idx = context.Attr<int64_t>("start_index");
auto ids_t = context.Input<phi::DenseTensor>("Ids");
auto d_output_t =
context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto table_t = context.Input<phi::DenseTensor>("W");
auto table_grad_t =
context.Output<phi::DenseTensor>(framework::GradVarName("W"));
T* table_grad_data =
table_grad_t->mutable_data<T>(table_t->dims(), context.GetPlace());
size_t table_t_mem_size =
table_t->numel() * phi::SizeOf(table_grad_t->dtype());
size_t table_grad_t_mem_size =
table_grad_t->numel() *
framework::SizeOfType(
framework::TransToProtoVarType(table_grad_t->dtype()));
VLOG(10) << "table_dims:" << table_t->dims()
<< ", table_t memory_size:" << table_t_mem_size
<< ", table_grad_t memory_size:" << table_grad_t_mem_size
<< ", start_index:" << start_idx;
memset(table_grad_data, 0, table_grad_t_mem_size);
const T* d_output_data = d_output_t->data<T>();
const int64_t height = table_t->dims()[0];
const int64_t width = table_t->dims()[1];
const auto& index_type = framework::TransToProtoVarType(ids_t->dtype());
if (index_type == framework::proto::VarType::INT32) {
UpdateEmbedding(ids_t->data<int32_t>(),
ids_t->numel(),
start_idx,
table_grad_data,
height,
width,
d_output_data);
} else if (index_type == framework::proto::VarType::INT64) {
UpdateEmbedding(ids_t->data<int64_t>(),
ids_t->numel(),
start_idx,
table_grad_data,
height,
width,
d_output_data);
} else {
PADDLE_THROW(platform::errors::Unavailable(
"CPU c_embedding ids only support int32 or int64."));
}
}
void Compute(const framework::ExecutionContext& context) const override {}
};
} // namespace operators
......
......@@ -18,129 +18,14 @@ namespace operators {
template <typename T, typename DeviceContext>
class CEmbeddingOpXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* table_t = ctx.Input<phi::DenseTensor>("W");
auto* ids_t = ctx.Input<phi::DenseTensor>("Ids");
auto* output_t = ctx.Output<phi::DenseTensor>("Out");
const int64_t start_index = ctx.Attr<int64_t>("start_index");
const T* table_data = table_t->data<T>();
T* output_data = output_t->mutable_data<T>(ctx.GetPlace());
const int64_t height = table_t->dims()[0];
const int64_t width = table_t->dims()[1];
// int embedding(Context* ctx, const T* x, const TID* indices, T* y, int xm,
// int n, int ym, int padding_idx, TID start_index = 0);
// xm: table height: number of entries of table.
// n: embedding dim: number of float value within single entry.
// ym: number of elements of input ids.
auto& dev_ctx = ctx.template device_context<DeviceContext>();
const auto& index_type = framework::TransToProtoVarType(ids_t->dtype());
if (index_type == framework::proto::VarType::INT32) {
int r = xpu::embedding(dev_ctx.x_context(),
table_data,
ids_t->data<int32_t>(),
output_data,
height,
width,
ids_t->numel(),
-1,
static_cast<int32_t>(start_index));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding");
} else if (index_type == framework::proto::VarType::INT64) {
int r = xpu::embedding(dev_ctx.x_context(),
table_data,
ids_t->data<int64_t>(),
output_data,
height,
width,
ids_t->numel(),
-1,
static_cast<int64_t>(start_index));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding");
} else {
PADDLE_THROW(platform::errors::Unavailable(
"XPU c_embedding ids only support int32 or int64."));
}
}
void Compute(const framework::ExecutionContext& ctx) const override {}
};
template <typename T, typename DeviceContext>
class CEmbeddingGradOpXPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const int64_t start_idx = context.Attr<int64_t>("start_index");
auto ids_t = context.Input<phi::DenseTensor>("Ids");
auto d_output_t =
context.Input<phi::DenseTensor>(framework::GradVarName("Out"));
auto table_t = context.Input<phi::DenseTensor>("W");
auto table_grad_t =
context.Output<phi::DenseTensor>(framework::GradVarName("W"));
auto& dev_ctx = context.template device_context<phi::XPUContext>();
table_grad_t->Resize(table_t->dims());
dev_ctx.template Alloc(table_grad_t, table_t->dtype());
T* table_grad_data = static_cast<T*>(table_grad_t->data());
size_t table_t_mem_size =
table_t->numel() * phi::SizeOf(table_grad_t->dtype());
size_t table_grad_t_mem_size =
table_grad_t->numel() *
framework::SizeOfType(
framework::TransToProtoVarType(table_grad_t->dtype()));
VLOG(10) << "table_dims:" << table_t->dims()
<< ", table_t memory_size:" << table_t_mem_size
<< ", table_grad_t memory_size:" << table_grad_t_mem_size
<< ", start_index:" << start_idx;
int r = xpu::constant(
dev_ctx.x_context(), table_grad_data, table_grad_t->numel(), (T)0);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
const T* d_output_data = d_output_t->data<T>();
const int64_t height = table_t->dims()[0];
const int64_t width = table_t->dims()[1];
const auto& index_type = framework::TransToProtoVarType(ids_t->dtype());
if (index_type == framework::proto::VarType::INT32) {
r = xpu::embedding_grad(dev_ctx.x_context(),
d_output_data,
ids_t->data<int32_t>(),
table_grad_data,
height,
width,
ids_t->numel(),
-1,
static_cast<int32_t>(start_idx));
} else if (index_type == framework::proto::VarType::INT64) {
r = xpu::embedding_grad(dev_ctx.x_context(),
d_output_data,
ids_t->data<int64_t>(),
table_grad_data,
height,
width,
ids_t->numel(),
-1,
static_cast<int64_t>(start_idx));
} else {
PADDLE_THROW(platform::errors::Unavailable(
"XPU c_embedding ids only support int32 or int64."));
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_grad");
}
void Compute(const framework::ExecutionContext& context) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
PD_REGISTER_STRUCT_KERNEL(
c_embedding, XPU, ALL_LAYOUT, ops::CEmbeddingOpXPUKernel, float) {}
PD_REGISTER_STRUCT_KERNEL(
c_embedding_grad, XPU, ALL_LAYOUT, ops::CEmbeddingGradOpXPUKernel, float) {}
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle//phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void CEmbeddingGradKernel(const Context& ctx,
const DenseTensor& w,
const DenseTensor& ids,
const DenseTensor& out_grad,
int64_t start_index,
DenseTensor* w_grad);
} // namespace phi
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle//phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void CEmbeddingKernel(const Context& ctx,
const DenseTensor& w,
const DenseTensor& ids,
int64_t start_index,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/c_embedding_grad_kernel.h"
#include "glog/logging.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename TIds, typename TData>
void UpdateEmbedding(const TIds* ids,
size_t ids_len,
int64_t start_idx,
TData* table,
int64_t height,
int64_t width,
const TData* out) {
for (size_t i = 0; i < ids_len; i++) {
TIds id = ids[i];
int64_t local = id - start_idx;
if (local >= 0 && local < height) {
for (int64_t w = 0; w < width; w++) {
table[local * width + w] += out[i * width + w];
}
}
}
}
template <typename T, typename Context>
void CEmbeddingGradKernel(const Context& dev_ctx,
const DenseTensor& w,
const DenseTensor& ids,
const DenseTensor& out_grad,
int64_t start_index,
DenseTensor* w_grad) {
w_grad->Resize(w.dims());
T* table_grad_data = dev_ctx.template Alloc<T>(w_grad);
size_t table_t_mem_size = w.numel() * sizeof(w_grad->dtype());
size_t table_grad_t_mem_size = w_grad->numel() * sizeof(w_grad->dtype());
VLOG(10) << "table_dims:" << w.dims()
<< ", table_t memory_size:" << table_t_mem_size
<< ", table_grad_t memory_size:" << table_grad_t_mem_size
<< ", start_index:" << start_index;
memset(table_grad_data, 0, table_grad_t_mem_size);
const T* d_output_data = out_grad.data<T>();
const int64_t height = w.dims()[0];
const int64_t width = w.dims()[1];
const auto& index_type = ids.dtype();
if (index_type == phi::DataType::INT32) {
UpdateEmbedding(ids.data<int32_t>(),
ids.numel(),
start_index,
table_grad_data,
height,
width,
d_output_data);
} else if (index_type == phi::DataType::INT64) {
UpdateEmbedding(ids.data<int64_t>(),
ids.numel(),
start_index,
table_grad_data,
height,
width,
d_output_data);
} else {
PADDLE_THROW(phi::errors::Unavailable(
"CPU c_embedding ids only support int32 or int64."));
}
}
} // namespace phi
PD_REGISTER_KERNEL(c_embedding_grad,
CPU,
ALL_LAYOUT,
phi::CEmbeddingGradKernel,
float,
double,
phi::dtype::float16) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/c_embedding_kernel.h"
#include "glog/logging.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename TIds, typename TData>
void GetIdsEmbedding(const TIds* ids,
size_t ids_len,
int64_t start_idx,
const TData* table,
int64_t height,
int64_t width,
TData* out) {
for (size_t i = 0; i < ids_len; i++) {
TIds id = ids[i];
int64_t local = id - start_idx;
if (local >= 0 && local < height) {
memcpy(out + i * width, table + local * width, width * sizeof(TData));
} else {
memset(out + i * width, 0, width * sizeof(TData));
}
}
}
template <typename T, typename Context>
void CEmbeddingKernel(const Context& ctx,
const DenseTensor& w,
const DenseTensor& ids,
int64_t start_index,
DenseTensor* out) {
VLOG(10) << "table_dims:" << w.dims();
const T* table_data = w.data<T>();
T* output_data = ctx.template Alloc<T>(out);
const int64_t height = w.dims()[0];
const int64_t width = w.dims()[1];
const auto& index_type = ids.dtype();
if (index_type == phi::DataType::INT32) {
GetIdsEmbedding(ids.data<int32_t>(),
ids.numel(),
start_index,
table_data,
height,
width,
output_data);
} else if (index_type == phi::DataType::INT64) {
GetIdsEmbedding(ids.data<int64_t>(),
ids.numel(),
start_index,
table_data,
height,
width,
output_data);
} else {
PADDLE_THROW(phi::errors::Unavailable(
"CPU c_embedding ids only support int32 or int64."));
}
}
} // namespace phi
PD_REGISTER_KERNEL(c_embedding,
CPU,
ALL_LAYOUT,
phi::CEmbeddingKernel,
float,
double,
phi::dtype::float16) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/c_embedding_kernel.h"
#include "gflags/gflags.h"
#include "glog/logging.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/embedding_grad.h"
DECLARE_int64(embedding_deterministic);
namespace phi {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
template <typename T, typename IndexT>
__global__ void CEmbeddingGrad(T* table,
const T* output,
const IndexT* ids,
const int rows,
const int columns,
const int64_t N,
const int64_t start_idx,
const int64_t end_idx,
const int64_t limit) {
CUDA_KERNEL_LOOP(i, limit) {
size_t row = i / columns;
size_t col = i % columns;
auto id = ids[row];
if (id >= start_idx && id < end_idx) {
auto real_idx = id - start_idx;
phi::CudaAtomicAdd(&table[real_idx * columns + col], output[i]);
}
}
}
template <typename T, typename Context>
void CEmbeddingGradKernel(const Context& dev_ctx,
const DenseTensor& w,
const DenseTensor& ids,
const DenseTensor& out_grad,
int64_t start_index,
DenseTensor* w_grad) {
int N = w_grad->dims()[0];
int D = w_grad->dims()[1];
int K = ids.numel();
auto limit = K * D;
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;
const T* d_output = out_grad.data<T>();
T* d_table = dev_ctx.template Alloc<T>(w_grad);
auto t = EigenVector<T>::Flatten(*w_grad);
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
const auto& index_type = ids.dtype();
if (FLAGS_embedding_deterministic == 1) {
if (index_type == phi::DataType::INT32) {
phi::funcs::LaunchEmbeddingGradDeterministicKernel<T, int32_t>(
dev_ctx,
ids.data<int32_t>(),
d_output,
d_table,
N,
D,
K,
start_index);
return;
} else if (index_type == phi::DataType::INT64) {
phi::funcs::LaunchEmbeddingGradDeterministicKernel<T, int64_t>(
dev_ctx,
ids.data<int64_t>(),
d_output,
d_table,
N,
D,
K,
start_index);
return;
}
} else {
if (FLAGS_embedding_deterministic > 1) {
VLOG(2) << "Run grad kernel of embedding with single thread.";
blocks = 1;
}
const int64_t end_idx = start_index + N;
if (index_type == phi::DataType::INT32) {
CEmbeddingGrad<T, int32_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
d_output,
ids.data<int32_t>(),
K,
D,
N,
start_index,
end_idx,
limit);
return;
} else if (index_type == phi::DataType::INT64) {
CEmbeddingGrad<T, int64_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
d_output,
ids.data<int64_t>(),
K,
D,
N,
start_index,
end_idx,
limit);
return;
}
}
PADDLE_THROW(phi::errors::InvalidArgument(
"The data type of Input(Ids) must be int32 or int64."));
}
} // namespace phi
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
PD_REGISTER_KERNEL(c_embedding_grad,
GPU,
ALL_LAYOUT,
phi::CEmbeddingGradKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(c_embedding_grad,
GPU,
ALL_LAYOUT,
phi::CEmbeddingGradKernel,
float,
double,
phi::dtype::float16) {}
#endif
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/c_embedding_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;
static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}
template <typename T, typename IndexT>
__global__ void CEmbedding(T* out,
const T* table,
const IndexT* ids,
const int rows,
const int columns,
const int64_t N,
const int64_t start_idx,
const int64_t end_idx,
const int64_t limit) {
CUDA_KERNEL_LOOP(i, limit) {
size_t row = i / columns;
size_t col = i % columns;
auto id = ids[row];
if (id >= start_idx && id < end_idx) {
auto real_idx = id - start_idx;
PADDLE_ENFORCE(real_idx < N,
"The index is out of bounds, "
"please check whether the dimensions of index and "
"input meet the requirements. It should "
"be less than [%d], but received [%d]",
N,
real_idx);
out[i] = table[real_idx * columns + col];
} else {
out[i] = static_cast<T>(0);
}
}
}
template <typename T, typename Context>
void CEmbeddingKernel(const Context& ctx,
const DenseTensor& w,
const DenseTensor& ids,
int64_t start_index,
DenseTensor* out) {
size_t N = w.dims()[0];
size_t D = w.dims()[1];
size_t K = ids.numel();
const int64_t end_idx = start_index + N;
auto* table = w.data<T>();
auto* output = ctx.template Alloc<T>(out);
auto limit = K * D;
int blocks = NumBlocks(limit);
int threads = kNumCUDAThreads;
const auto& index_type = ids.dtype();
if (index_type == phi::DataType::INT32) {
CEmbedding<T, int32_t>
<<<blocks, threads, 0, ctx.stream()>>>(output,
table,
ids.data<int32_t>(),
K,
D,
N,
start_index,
end_idx,
limit);
} else if (index_type == phi::DataType::INT64) {
CEmbedding<T, int64_t>
<<<blocks, threads, 0, ctx.stream()>>>(output,
table,
ids.data<int64_t>(),
K,
D,
N,
start_index,
end_idx,
limit);
} else {
PADDLE_THROW(phi::errors::Unavailable(
"GPU c_embedding ids only support int32 or int64."));
}
}
} // namespace phi
#if NCCL_VERSION_CODE >= 21000 && CUDA_VERSION >= 11000
PD_REGISTER_KERNEL(c_embedding,
GPU,
ALL_LAYOUT,
phi::CEmbeddingKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
#else
PD_REGISTER_KERNEL(c_embedding,
GPU,
ALL_LAYOUT,
phi::CEmbeddingKernel,
float,
double,
phi::dtype::float16) {}
#endif
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/c_embedding_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void CEmbeddingKernel(const Context& dev_ctx,
const DenseTensor& w,
const DenseTensor& ids,
int64_t start_index,
DenseTensor* out) {
const T* table_data = w.data<T>();
T* output_data = dev_ctx.template Alloc<T>(out);
const int64_t height = w.dims()[0];
const int64_t width = w.dims()[1];
// int embedding(Context* ctx, const T* x, const TID* indices, T* y, int xm,
// int n, int ym, int padding_idx, TID start_index = 0);
// xm: table height: number of entries of table.
// n: embedding dim: number of float value within single entry.
// ym: number of elements of input ids.
const auto& index_type = ids.dtype();
if (index_type == phi::DataType::INT32) {
int r = xpu::embedding(dev_ctx.x_context(),
table_data,
ids.data<int32_t>(),
output_data,
height,
width,
ids.numel(),
-1,
static_cast<int32_t>(start_index));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding");
} else if (index_type == phi::DataType::INT64) {
int r = xpu::embedding(dev_ctx.x_context(),
table_data,
ids.data<int64_t>(),
output_data,
height,
width,
ids.numel(),
-1,
static_cast<int64_t>(start_index));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding");
} else {
PADDLE_THROW(phi::errors::Unavailable(
"XPU c_embedding ids only support int32 or int64."));
}
}
} // namespace phi
PD_REGISTER_KERNEL(c_embedding, XPU, ALL_LAYOUT, phi::CEmbeddingKernel, float) {
}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/c_embedding_grad_kernel.h"
#include "glog/logging.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void CEmbeddingGradKernel(const Context& dev_ctx,
const DenseTensor& w,
const DenseTensor& ids,
const DenseTensor& out_grad,
int64_t start_index,
DenseTensor* w_grad) {
w_grad->Resize(w.dims());
dev_ctx.template Alloc(w_grad, w.dtype());
T* table_grad_data = static_cast<T*>(w_grad->data());
size_t table_t_mem_size = w.numel() * phi::SizeOf(w_grad->dtype());
size_t table_grad_t_mem_size = w_grad->numel() * phi::SizeOf(w_grad->dtype());
VLOG(10) << "table_dims:" << w.dims()
<< ", table_t memory_size:" << table_t_mem_size
<< ", table_grad_t memory_size:" << table_grad_t_mem_size
<< ", start_index:" << start_index;
int r = xpu::constant(
dev_ctx.x_context(), table_grad_data, w_grad->numel(), (T)0);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
const T* d_output_data = out_grad.data<T>();
const int64_t height = w.dims()[0];
const int64_t width = w.dims()[1];
const auto& index_type = ids.dtype();
if (index_type == phi::DataType::INT32) {
r = xpu::embedding_grad(dev_ctx.x_context(),
d_output_data,
ids.data<int32_t>(),
table_grad_data,
height,
width,
ids.numel(),
-1,
static_cast<int32_t>(start_index));
} else if (index_type == phi::DataType::INT64) {
r = xpu::embedding_grad(dev_ctx.x_context(),
d_output_data,
ids.data<int64_t>(),
table_grad_data,
height,
width,
ids.numel(),
-1,
static_cast<int64_t>(start_index));
} else {
PADDLE_THROW(phi::errors::Unavailable(
"XPU c_embedding ids only support int32 or int64."));
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding_grad");
}
} // namespace phi
PD_REGISTER_KERNEL(
c_embedding_grad, XPU, ALL_LAYOUT, phi::CEmbeddingGradKernel, float) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature CEmbeddingGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("c_embedding_grad",
{"W", "Ids", "Out@GRAD"},
{"start_index"},
{"W@GRAD"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(c_embedding_grad,
phi::CEmbeddingGradOpArgumentMapping);
......@@ -17,6 +17,7 @@ import unittest
import numpy as np
from eager_op_test import OpTest
import paddle
from paddle.framework import core
SEED = 2021
......@@ -33,6 +34,12 @@ def get_c_embedding(start, end, table, ids):
return output
def c_embedding_wrapper(table, index, start_index=0):
return paddle._legacy_C_ops.c_embedding(
table, index, "start_index", start_index
)
class TestCEmbeddingCPU(OpTest):
def setUp(self):
self.init_dtype()
......@@ -44,6 +51,7 @@ class TestCEmbeddingCPU(OpTest):
def initcase(self):
self.op_type = "c_embedding"
self.python_api = c_embedding_wrapper
table = np.random.random((17, 64)).astype(self.dtype)
ids = np.random.randint(low=0, high=17 * 2, size=(2, 4)).astype(
self.ids_dtype
......@@ -58,10 +66,10 @@ class TestCEmbeddingCPU(OpTest):
if core.is_compiled_with_xpu():
self.__class__.use_xpu = True
def test_check_cpu(self):
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
def test_check_cpu_grad(self):
def test_check_grad(self):
self.check_grad_with_place(core.CPUPlace(), ['W'], 'Out')
def init_dtype(self):
......@@ -102,6 +110,7 @@ class TestCEmbeddingOpFP32(TestCEmbeddingOpBase):
def initcase(self):
self.op_type = "c_embedding"
self.python_api = c_embedding_wrapper
table = np.random.random((17, 64)).astype(self.dtype)
ids = np.random.randint(low=0, high=17 * 2, size=(2, 4)).astype(
self.ids_dtype
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册