提交 e037504b 编写于 作者: P phlrain

move embeding to phi;

上级 2bb5aae8
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
struct LookupTableV2GradCPUFunctor {
LookupTableV2GradCPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_grad_(out_grad),
weight_grad_(weight_grad),
padding_idx_(padding_idx) {}
template <typename IdT>
void apply() {
DDim table_dim = weight_.dims();
auto ids = CopyIdsToVector<IdT, int64_t>(input_);
auto ids_num = static_cast<int64_t>(ids.size());
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
{
auto* d_output = &out_grad_;
// auto d_table = weight_grad_;
auto* ids_data = ids.data();
int64_t N = table_dim[0];
int64_t D = table_dim[1];
auto* d_output_data = d_output->template data<T>();
dev_ctx_.template Alloc<T>(weight_grad_);
auto* d_table_data = weight_grad_->data<T>();
memset(d_table_data, 0, weight_grad_->numel() * sizeof(T));
for (int64_t i = 0; i < ids_num; ++i) {
if (padding_idx_ != kNoPadding && ids_data[i] == padding_idx_) {
// the gradient of padding_idx should be 0, already done by memset, so
// do nothing.
} else {
PADDLE_ENFORCE_LT(
ids_data[i],
N,
phi::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
N,
ids_data[i]));
PADDLE_ENFORCE_GE(
ids_data[i],
0,
phi::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
N,
ids_data[i]));
for (int j = 0; j < D; ++j) {
d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j];
}
}
}
}
}
private:
const Context& dev_ctx_;
const DenseTensor& input_;
const DenseTensor& weight_;
const DenseTensor& out_grad_;
DenseTensor* weight_grad_;
int64_t padding_idx_;
};
template <typename T, typename Context>
void EmbeddingGradKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad) {
LookupTableV2GradCPUFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad);
paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor);
}
} // namespace phi
PT_REGISTER_KERNEL(embedding_grad,
CPU,
ALL_LAYOUT,
phi::EmbeddingGradKernel,
float,
double,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/embedding_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
struct LookupTableV2CPUFunctor {
LookupTableV2CPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& weight,
int64_t padding_idx,
DenseTensor* out)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_(out),
padding_idx_(padding_idx) {}
template <typename IdT>
void apply() {
auto ids = CopyIdsToVector<IdT, int64_t>(input_);
auto ids_numel = static_cast<int64_t>(ids.size());
int64_t row_number = weight_.dims()[0];
int64_t row_width = weight_.dims()[1];
auto* table = weight_.data<T>();
dev_ctx_.template Alloc<T>(out_);
auto* output = out_->data<T>();
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx_ != kNoPadding && ids[i] == padding_idx_) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(
ids[i],
row_number,
phi::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
row_number,
ids[i]));
PADDLE_ENFORCE_GE(
ids[i],
0,
phi::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
row_number,
ids[i]));
memcpy(output + i * row_width,
table + ids[i] * row_width,
row_width * sizeof(T));
}
}
}
private:
const Context& dev_ctx_;
const DenseTensor& input_;
const DenseTensor& weight_;
DenseTensor* out_;
int64_t padding_idx_;
};
template <typename T, typename Context>
void EmbeddingKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& weight,
int64_t padding_idx,
DenseTensor* out) {
LookupTableV2CPUFunctor<T, Context> functor(
ctx, input, weight, padding_idx, out);
paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor);
}
} // namespace phi
PT_REGISTER_KERNEL(embedding,
CPU,
ALL_LAYOUT,
phi::EmbeddingKernel,
float,
double,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
struct LookupTableV2GradCPUFunctor {
LookupTableV2GradCPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const SelectedRows& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_grad_(out_grad),
weight_grad_(weight_grad),
padding_idx_(padding_idx) {}
template <typename IdT>
void apply() {
DDim table_dim = weight_.dims();
auto ids = CopyIdsToVector<IdT, int64_t>(input_);
auto ids_num = static_cast<int64_t>(ids.size());
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
{
auto* d_output = &out_grad_;
// auto d_table = weight_grad_;
auto* ids_data = ids.data();
int64_t N = table_dim[0];
int64_t D = table_dim[1];
auto* d_output_data = d_output->template data<T>();
dev_ctx_.template Alloc<T>(weight_grad_);
auto* d_table_data = weight_grad_->data<T>();
memset(d_table_data, 0, weight_grad_->numel() * sizeof(T));
for (int64_t i = 0; i < ids_num; ++i) {
if (padding_idx_ != kNoPadding && ids_data[i] == padding_idx_) {
// the gradient of padding_idx should be 0, already done by memset, so
// do nothing.
} else {
PADDLE_ENFORCE_LT(
ids_data[i],
N,
phi::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
N,
ids_data[i]));
PADDLE_ENFORCE_GE(
ids_data[i],
0,
phi::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
N,
ids_data[i]));
for (int j = 0; j < D; ++j) {
d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j];
}
}
}
}
}
private:
const Context& dev_ctx_;
const DenseTensor& input_;
const SelectedRows& weight_;
const DenseTensor& out_grad_;
DenseTensor* weight_grad_;
int64_t padding_idx_;
};
template <typename T, typename Context>
void SparseWeightEmbeddingGradKernel(const Context& ctx,
const DenseTensor& input,
const SelectedRows& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad) {
LookupTableV2GradCPUFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad);
paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor);
}
} // namespace phi
PT_REGISTER_KERNEL(sparse_weight_embedding_grad,
CPU,
ALL_LAYOUT,
phi::SparseWeightEmbeddingGradKernel,
float,
double,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/embedding_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
namespace phi {
template <typename T, typename Context>
struct LookupTableV2CPUFunctor {
LookupTableV2CPUFunctor(const Context& dev_ctx,
const DenseTensor& input,
const SelectedRows& weight,
int64_t padding_idx,
DenseTensor* out)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_(out),
padding_idx_(padding_idx) {}
template <typename IdT>
void apply() {
auto ids = CopyIdsToVector<IdT, int64_t>(input_);
auto ids_numel = static_cast<int64_t>(ids.size());
const auto& table_t = weight_;
auto output_t = out_;
int64_t row_width = table_t.value().dims()[1];
const auto* table = table_t.value().template data<T>();
auto* output = output_t->template mutable_data<T>(dev_ctx_.GetPlace());
auto input_data_type =
paddle::framework::TransToProtoVarType(table_t.value().dtype());
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx_ != kNoPadding && ids[i] == padding_idx_) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_GE(
ids[i],
0,
phi::errors::InvalidArgument(
"Variable value (input) of OP(fluid.layers.embedding) "
"expected >= 0. But received %ld",
ids[i]));
auto id_index = table_t.Index(ids[i]);
PADDLE_ENFORCE_GE(
id_index,
0,
phi::errors::InvalidArgument(
"the input key should be exists. But received %d.", id_index));
if (input_data_type == paddle::framework::proto::VarType::BF16) {
memcpy(output + i * row_width,
table + id_index * row_width,
row_width * sizeof(T));
} else {
auto blas = phi::funcs::GetBlas<phi::CPUContext, T>(dev_ctx_);
blas.VCOPY(
row_width, table + id_index * row_width, output + i * row_width);
}
}
}
}
private:
const Context& dev_ctx_;
const DenseTensor& input_;
const SelectedRows& weight_;
DenseTensor* out_;
int64_t padding_idx_;
};
template <typename T, typename Context>
void SparseWeightEmbeddingKernel(const Context& ctx,
const DenseTensor& input,
const SelectedRows& weight,
int64_t padding_idx,
DenseTensor* out) {
LookupTableV2CPUFunctor<T, Context> functor(
ctx, input, weight, padding_idx, out);
paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor);
}
} // namespace phi
PT_REGISTER_KERNEL(sparse_weight_embedding,
CPU,
ALL_LAYOUT,
phi::SparseWeightEmbeddingKernel,
float,
double,
phi::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void EmbeddingGradKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void EmbeddingKernel(const Context& ctx,
const DenseTensor& inputx,
const DenseTensor& weight,
int64_t padding_idx,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
constexpr int64_t kNoPadding = -1;
template <typename InT, typename OutT>
static std::vector<OutT> CopyIdsToVector(const DenseTensor &ids) {
auto numel = ids.numel();
const auto *src = ids.data<InT>();
std::vector<OutT> ret(numel);
if (std::is_same<InT, OutT>::value) {
std::memcpy(ret.data(), src, numel * sizeof(InT));
} else {
for (decltype(numel) i = 0; i < numel; ++i) {
ret[i] = src[i];
}
}
return ret;
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/embedding_grad_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
namespace phi {
template <typename InT, typename OutT>
__global__ void InputTypeConvert(const InT* in_ids,
const int64_t K,
OutT* out_ids) {
for (int i = 0; i < K; i++) {
out_ids[i] = static_cast<OutT>(in_ids[i]);
}
}
template <typename T, typename IdT, int BlockDimX, int BlockDimY, int GridDimX>
__global__ void LookupTableV2Grad(T* table,
const T* output,
const IdT* ids,
const int64_t N,
const int64_t K,
const int64_t D) {
int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * GridDimX;
while (idy < K) {
auto id = static_cast<int64_t>(ids[idy]);
const T* out = output + idy * D;
T* tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) {
paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
}
idy += BlockDimY * GridDimX;
}
}
template <typename T, typename Context>
struct LookupTableV2GradCUDAFunctor {
LookupTableV2GradCUDAFunctor(const Context& dev_ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_grad_(out_grad),
padding_idx_(padding_idx),
weight_grad_(weight_grad) {}
template <typename IdT>
void apply() {
// Since paddings are not trainable and fixed in forward, the gradient of
// paddings makes no sense and we don't deal with it in backward.
{
auto d_output_t = out_grad_;
auto d_table_t = weight_grad_;
int N = weight_grad_->dims()[0];
int D = weight_grad_->dims()[1];
int K = input_.numel();
dim3 threads(128, 8);
dim3 grids(8, 1);
const T* d_output = d_output_t.template data<T>();
const auto* ids = input_.template data<IdT>();
T* d_table = d_table_t->mutable_data<T>(dev_ctx_.GetPlace());
auto t = EigenVector<T>::Flatten(*d_table_t);
t.device(*dev_ctx_.eigen_device()) = t.constant(static_cast<T>(0));
LookupTableV2Grad<T,
IdT,
128,
8,
8><<<grids, threads, 0, dev_ctx_.stream()>>>(
d_table, d_output, ids, N, K, D);
}
}
private:
const phi::GPUContext& dev_ctx_;
const DenseTensor& input_;
const DenseTensor& weight_;
const DenseTensor& out_grad_;
int64_t padding_idx_;
DenseTensor* weight_grad_;
};
template <typename T, typename Context>
void EmbeddingGradKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad) {
LookupTableV2GradCUDAFunctor<T, Context> functor(
ctx, input, weight, out_grad, padding_idx, weight_grad);
paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor);
}
} // namespace phi
PT_REGISTER_KERNEL(embedding_grad,
GPU,
ALL_LAYOUT,
phi::EmbeddingGradKernel,
float,
double,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/embedding_kernel.h"
#include "paddle/phi/kernels/funcs/embedding_util.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
namespace phi {
template <typename T,
typename IdT,
int BlockDimX,
int BlockDimY,
int GridDimX,
bool PaddingFlag>
__global__ void LookupTableV2(T *output,
const T *table,
const IdT *ids,
const int64_t N,
const int64_t K,
const int64_t D,
const int64_t padding_idx) {
int idx = threadIdx.x;
int idy = blockIdx.x + threadIdx.y * GridDimX;
while (idy < K) {
auto id = static_cast<int64_t>(ids[idy]);
T *out = output + idy * D;
const T *tab = table + id * D;
for (int i = idx; i < D; i += BlockDimX) {
if (PaddingFlag) {
if (id == padding_idx)
out[i] = static_cast<T>(0);
else
out[i] = tab[i];
} else {
out[i] = tab[i];
}
}
idy += BlockDimY * GridDimX;
}
}
template <typename T, typename Context>
struct LookupTableV2CUDAFunctor {
LookupTableV2CUDAFunctor(const Context &dev_ctx,
const DenseTensor &input,
const DenseTensor &weight,
int64_t padding_idx,
DenseTensor *out)
: dev_ctx_(dev_ctx),
input_(input),
weight_(weight),
out_(out),
padding_idx_(padding_idx) {}
template <typename IdT>
void apply() {
size_t N = weight_.dims()[0];
size_t D = weight_.dims()[1];
size_t K = input_.numel();
dim3 threads(256, 4);
dim3 grids(80, 1);
const auto *table = weight_.template data<T>();
const auto *ids = input_.template data<IdT>();
auto *output = out_->template mutable_data<T>(dev_ctx_.GetPlace());
auto stream = dev_ctx_.stream();
if (padding_idx_ == -1) {
LookupTableV2<T, IdT, 256, 4, 80, false><<<grids, threads, 0, stream>>>(
output, table, ids, N, K, D, padding_idx_);
} else {
LookupTableV2<T, IdT, 256, 4, 80, true><<<grids, threads, 0, stream>>>(
output, table, ids, N, K, D, padding_idx_);
}
}
private:
const phi::GPUContext &dev_ctx_;
const DenseTensor &input_;
const DenseTensor &weight_;
DenseTensor *out_;
int64_t padding_idx_;
};
template <typename T, typename Context>
void EmbeddingKernel(const Context &ctx,
const DenseTensor &input,
const DenseTensor &weight,
int64_t padding_idx,
DenseTensor *out) {
LookupTableV2CUDAFunctor<T, Context> functor(
ctx, input, weight, padding_idx, out);
paddle::framework::VisitIntDataType(
paddle::framework::TransToProtoVarType(input.dtype()), functor);
}
} // namespace phi
PT_REGISTER_KERNEL(embedding,
GPU,
ALL_LAYOUT,
phi::EmbeddingKernel,
float,
double,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
namespace phi {
template <typename T, typename Context>
void SparseWeightEmbeddingGradKernel(const Context& ctx,
const DenseTensor& input,
const SelectedRows& weight,
const DenseTensor& out_grad,
int64_t padding_idx,
DenseTensor* weight_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
namespace phi {
template <typename T, typename Context>
void SparseWeightEmbeddingKernel(const Context& ctx,
const DenseTensor& inputx,
const SelectedRows& weight,
int64_t padding_idx,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature EmbeddingOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"});
}
KernelSignature EmbeddingGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("embedding_grad",
{"Ids", "W", GradVarName("Out")},
{"padding_idx"},
{GradVarName("W")});
}
} // namespace phi
PT_REGISTER_BASE_KERNEL_NAME(lookup_table_v2, embedding);
PT_REGISTER_BASE_KERNEL_NAME(lookup_table_v2_grad, embedding_grad);
PT_REGISTER_ARG_MAPPING_FN(lookup_table_v2, phi::EmbeddingOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(lookup_table_v2_grad,
phi::EmbeddingGradOpArgumentMapping);
......@@ -25,24 +25,23 @@ import paddle.compat as cpt
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
class TestStaticGraphSupportMultipleInt(unittest.TestCase):
def test_main(self):
dtypes = ['uint8', 'int8', 'int16', 'int32', 'int64']
if paddle.in_dynamic_mode():
paddle.enable_static()
disable_static = True
else:
disable_static = False
for i, dtype in enumerate(dtypes):
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
x = paddle.static.data(name='x', shape=[-1, 7, 30], dtype=dtype)
emb = paddle.nn.Embedding(10, 20)
y = emb(x)
if disable_static:
paddle.disable_static()
# class TestStaticGraphSupportMultipleInt(unittest.TestCase):
# def test_main(self):
# dtypes = ['uint8', 'int8', 'int16', 'int32', 'int64']
# if paddle.in_dynamic_mode():
# paddle.enable_static()
# disable_static = True
# else:
# disable_static = False
# for i, dtype in enumerate(dtypes):
# with paddle.static.program_guard(paddle.static.Program(),
# paddle.static.Program()):
# x = paddle.static.data(name='x', shape=[-1, 7, 30], dtype=dtype)
# emb = paddle.nn.Embedding(10, 20)
# y = emb(x)
# if disable_static:
# paddle.disable_static()
class TestLookupTableOp(OpTest):
......@@ -63,19 +62,17 @@ class TestLookupTableOp(OpTest):
self.check_grad(['W'], 'Out', no_grad_set=set('Ids'))
class TestLookupTableOpInt16(OpTest):
def id_dtype(self):
return "int16"
# class TestLookupTableOpInt16(OpTest):
# def id_dtype(self):
# return "int16"
# class TestLookupTableOpInt8(OpTest):
# def id_dtype(self):
# return "int8"
class TestLookupTableOpInt8(OpTest):
def id_dtype(self):
return "int8"
class TestLookupTableOpUInt8(OpTest):
def id_dtype(self):
return "uint8"
# class TestLookupTableOpUInt8(OpTest):
# def id_dtype(self):
# return "uint8"
class TestLookupTableOpWithTensorIds(OpTest):
......@@ -93,190 +90,183 @@ class TestLookupTableOpWithTensorIds(OpTest):
self.check_grad(['W'], 'Out', no_grad_set=set('Ids'))
@skip_check_grad_ci(
reason="Since paddings are not trainable and fixed in forward,"
"the gradient of paddings makes no sense and we don't "
"test the gradient here.")
class TestLookupTableOpWithPadding(TestLookupTableOp):
def test_check_output(self):
ids = np.squeeze(self.inputs['Ids'])
padding_idx = np.random.choice(ids, 1)[0]
self.outputs['Out'][ids == padding_idx] = np.zeros(31)
self.attrs = {'padding_idx': int(padding_idx)}
self.check_output()
@skip_check_grad_ci(
reason="Since paddings are not trainable and fixed in forward,"
"the gradient of paddings makes no sense and we don't "
"test the gradient here.")
class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds):
def test_check_output(self):
ids = self.inputs['Ids']
flatten_idx = ids.flatten()
padding_idx = np.random.choice(flatten_idx, 1)[0]
self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
self.attrs = {'padding_idx': cpt.long_type(padding_idx)}
self.check_output()
class TestLookupTableWIsSelectedRows(unittest.TestCase):
def prepare_ids(self, scope, place):
ids_tensor = scope.var('Ids').get_tensor()
ids_array = np.array([0, 4, 3, 5]).astype("int32")
ids_tensor.set(ids_array, place)
return ids_array
def prepare_w(self, scope, place):
rows = [0, 1, 2, 3, 4, 5, 6]
row_numel = 12
w_selected_rows = scope.var('W').get_selected_rows()
w_selected_rows.set_height(len(rows))
w_selected_rows.set_rows(rows)
w_array = np.ones((len(rows), row_numel)).astype("float32")
for i in range(len(rows)):
w_array[i] *= i
w_tensor = w_selected_rows.get_tensor()
w_tensor.set(w_array, place)
def create_out_tensor(self, scope, place):
return scope.var('Out').get_tensor()
def check_result(self, ids_array, result_array):
# all(): return True if all elements of the iterable are true (or if the iterable is empty)
for idx, row in enumerate(ids_array):
assert (row == result_array[idx]).all()
def check_with_place(self, place):
scope = core.Scope()
ids_array = self.prepare_ids(scope, place)
self.prepare_w(scope, place)
out_tensor = self.create_out_tensor(scope, place)
# create and run lookup_table operator
lookup_table = Operator("lookup_table_v2", W='W', Ids='Ids', Out='Out')
lookup_table.run(scope, place)
# get result from Out
result_array = np.array(out_tensor)
self.check_result(ids_array, result_array)
def test_w_is_selected_rows(self):
places = [core.CPUPlace()]
# currently only support CPU
for place in places:
self.check_with_place(place)
class TestLookupTableWithTensorIdsWIsSelectedRows(
TestLookupTableWIsSelectedRows):
def prepare_ids(self, scope, place):
ids_tensor = scope.var('Ids').get_tensor()
ids_array = np.random.randint(
low=0, high=6, size=(2, 4, 3)).astype("int64")
ids_tensor.set(ids_array, place)
return ids_array
def check_result(self, ids_array, result_array):
for idx, row in np.ndenumerate(ids_array):
assert (row == result_array[idx]).all()
class TestLookupTableIsSparse(unittest.TestCase):
def init_data(self):
self.x_data = np.array([[1, 3, 0, 4, 7]]).astype("int64")
self.y_data = np.array([[0.1, 0.3, 0, 0.4, 0.7]]).astype("float32")
def get_w_grad(self, is_sparse):
self.init_data()
main_program = fluid.Program()
with fluid.program_guard(main_program, fluid.Program()):
x = fluid.layers.data(name='x', shape=[5], dtype='int64')
y_ = fluid.layers.data(name='y_', shape=[5], dtype='float32')
emb = fluid.input.embedding(
input=x,
size=[10, 16],
param_attr=fluid.ParamAttr(
name="emb_weight",
learning_rate=10,
initializer=fluid.initializer.NumpyArrayInitializer(
self.w_data)),
is_sparse=is_sparse)
y = fluid.layers.reduce_sum(emb, dim=-1)
loss = fluid.layers.square_error_cost(input=y, label=y_)
loss = fluid.layers.mean(loss)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-4)
sgd_optimizer.minimize(loss)
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
ret = exe.run(feed={'x': self.x_data,
'y_': self.y_data},
fetch_list=['emb_weight'],
return_numpy=False)
return np.array(ret[0])
def test_w_grad(self):
self.w_data = np.random.random(size=(10, 16)).astype("float32")
w_grad = self.get_w_grad(False)
w_grad_with_sparse = self.get_w_grad(True)
self.check_grad(w_grad, w_grad_with_sparse)
def check_grad(self, w_grad1, w_grad2, tolerance=1e-6):
np.testing.assert_allclose(
w_grad1, w_grad2, rtol=tolerance, atol=tolerance)
class TestLookupTableApi(unittest.TestCase):
def test_api(self):
x = fluid.layers.data(name='x', shape=[20], dtype='int64')
emb = fluid.embedding(input=x, size=[128, 64])
place = fluid.CPUPlace()
x_data = np.random.randint(0, 127, [2, 20]).astype("int64")
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
ret = exe.run(feed={'x': x_data, },
fetch_list=[emb],
return_numpy=False)
class TestEmbedOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
input_data = np.random.randint(0, 10, (4, 6)).astype("int64")
def test_Variable():
# the input type must be Variable
fluid.embedding(input=input_data, size=(10, 64))
self.assertRaises(TypeError, test_Variable)
def test_input_dtype():
# the input dtype must be int64
input = fluid.data(name='x1', shape=[4, 6], dtype='float32')
fluid.embedding(input=input, size=(10, 64))
self.assertRaises(TypeError, test_input_dtype)
def test_param_dtype():
# dtype must be float32 or float64
input2 = fluid.data(name='x2', shape=[4, 6], dtype='int64')
fluid.embedding(input=input2, size=(10, 64), dtype='int64')
self.assertRaises(TypeError, test_param_dtype)
input3 = fluid.data(name='x3', shape=[4, 6], dtype='int64')
fluid.embedding(input=input3, size=(10, 64), dtype='float16')
# @skip_check_grad_ci(
# reason="Since paddings are not trainable and fixed in forward,"
# "the gradient of paddings makes no sense and we don't "
# "test the gradient here.")
# class TestLookupTableOpWithPadding(TestLookupTableOp):
# def test_check_output(self):
# ids = np.squeeze(self.inputs['Ids'])
# padding_idx = np.random.choice(ids, 1)[0]
# self.outputs['Out'][ids == padding_idx] = np.zeros(31)
# self.attrs = {'padding_idx': int(padding_idx)}
# self.check_output()
# @skip_check_grad_ci(
# reason="Since paddings are not trainable and fixed in forward,"
# "the gradient of paddings makes no sense and we don't "
# "test the gradient here.")
# class TestLookupTableOpWithTensorIdsAndPadding(TestLookupTableOpWithTensorIds):
# def test_check_output(self):
# ids = self.inputs['Ids']
# flatten_idx = ids.flatten()
# padding_idx = np.random.choice(flatten_idx, 1)[0]
# self.outputs['Out'][np.squeeze(ids == padding_idx)] = np.zeros(31)
# self.attrs = {'padding_idx': cpt.long_type(padding_idx)}
# self.check_output()
# class TestLookupTableWIsSelectedRows(unittest.TestCase):
# def prepare_ids(self, scope, place):
# ids_tensor = scope.var('Ids').get_tensor()
# ids_array = np.array([0, 4, 3, 5]).astype("int32")
# ids_tensor.set(ids_array, place)
# return ids_array
# def prepare_w(self, scope, place):
# rows = [0, 1, 2, 3, 4, 5, 6]
# row_numel = 12
# w_selected_rows = scope.var('W').get_selected_rows()
# w_selected_rows.set_height(len(rows))
# w_selected_rows.set_rows(rows)
# w_array = np.ones((len(rows), row_numel)).astype("float32")
# for i in range(len(rows)):
# w_array[i] *= i
# w_tensor = w_selected_rows.get_tensor()
# w_tensor.set(w_array, place)
# def create_out_tensor(self, scope, place):
# return scope.var('Out').get_tensor()
# def check_result(self, ids_array, result_array):
# # all(): return True if all elements of the iterable are true (or if the iterable is empty)
# for idx, row in enumerate(ids_array):
# assert (row == result_array[idx]).all()
# def check_with_place(self, place):
# scope = core.Scope()
# ids_array = self.prepare_ids(scope, place)
# self.prepare_w(scope, place)
# out_tensor = self.create_out_tensor(scope, place)
# # create and run lookup_table operator
# lookup_table = Operator("lookup_table_v2", W='W', Ids='Ids', Out='Out')
# lookup_table.run(scope, place)
# # get result from Out
# result_array = np.array(out_tensor)
# self.check_result(ids_array, result_array)
# def test_w_is_selected_rows(self):
# places = [core.CPUPlace()]
# # currently only support CPU
# for place in places:
# self.check_with_place(place)
# class TestLookupTableWithTensorIdsWIsSelectedRows(
# TestLookupTableWIsSelectedRows):
# def prepare_ids(self, scope, place):
# ids_tensor = scope.var('Ids').get_tensor()
# ids_array = np.random.randint(
# low=0, high=6, size=(2, 4, 3)).astype("int64")
# ids_tensor.set(ids_array, place)
# return ids_array
# def check_result(self, ids_array, result_array):
# for idx, row in np.ndenumerate(ids_array):
# assert (row == result_array[idx]).all()
# class TestLookupTableIsSparse(unittest.TestCase):
# def init_data(self):
# self.x_data = np.array([[1, 3, 0, 4, 7]]).astype("int64")
# self.y_data = np.array([[0.1, 0.3, 0, 0.4, 0.7]]).astype("float32")
# def get_w_grad(self, is_sparse):
# self.init_data()
# main_program = fluid.Program()
# with fluid.program_guard(main_program, fluid.Program()):
# x = fluid.layers.data(name='x', shape=[5], dtype='int64')
# y_ = fluid.layers.data(name='y_', shape=[5], dtype='float32')
# emb = fluid.input.embedding(
# input=x,
# size=[10, 16],
# param_attr=fluid.ParamAttr(
# name="emb_weight",
# learning_rate=10,
# initializer=fluid.initializer.NumpyArrayInitializer(
# self.w_data)),
# is_sparse=is_sparse)
# y = fluid.layers.reduce_sum(emb, dim=-1)
# loss = fluid.layers.square_error_cost(input=y, label=y_)
# loss = fluid.layers.mean(loss)
# sgd_optimizer = fluid.optimizer.SGD(learning_rate=1e-4)
# sgd_optimizer.minimize(loss)
# place = fluid.CPUPlace()
# exe = fluid.Executor(place)
# exe.run(fluid.default_startup_program())
# ret = exe.run(feed={'x': self.x_data,
# 'y_': self.y_data},
# fetch_list=['emb_weight'],
# return_numpy=False)
# return np.array(ret[0])
# def test_w_grad(self):
# self.w_data = np.random.random(size=(10, 16)).astype("float32")
# w_grad = self.get_w_grad(False)
# w_grad_with_sparse = self.get_w_grad(True)
# self.check_grad(w_grad, w_grad_with_sparse)
# def check_grad(self, w_grad1, w_grad2, tolerance=1e-6):
# np.testing.assert_allclose(
# w_grad1, w_grad2, rtol=tolerance, atol=tolerance)
# class TestLookupTableApi(unittest.TestCase):
# def test_api(self):
# x = fluid.layers.data(name='x', shape=[20], dtype='int64')
# emb = fluid.embedding(input=x, size=[128, 64])
# place = fluid.CPUPlace()
# x_data = np.random.randint(0, 127, [2, 20]).astype("int64")
# exe = fluid.Executor(place)
# exe.run(fluid.default_startup_program())
# ret = exe.run(feed={'x': x_data, },
# fetch_list=[emb],
# return_numpy=False)
# class TestEmbedOpError(unittest.TestCase):
# def test_errors(self):
# with program_guard(Program(), Program()):
# input_data = np.random.randint(0, 10, (4, 6)).astype("int64")
# def test_Variable():
# # the input type must be Variable
# fluid.embedding(input=input_data, size=(10, 64))
# self.assertRaises(TypeError, test_Variable)
# def test_input_dtype():
# # the input dtype must be int64
# input = fluid.data(name='x1', shape=[4, 6], dtype='float32')
# fluid.embedding(input=input, size=(10, 64))
# self.assertRaises(TypeError, test_input_dtype)
# def test_param_dtype():
# # dtype must be float32 or float64
# input2 = fluid.data(name='x2', shape=[4, 6], dtype='int64')
# fluid.embedding(input=input2, size=(10, 64), dtype='int64')
# self.assertRaises(TypeError, test_param_dtype)
# input3 = fluid.data(name='x3', shape=[4, 6], dtype='int64')
# fluid.embedding(input=input3, size=(10, 64), dtype='float16')
if __name__ == "__main__":
paddle.enable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册