diff --git a/paddle/fluid/operators/lookup_table_v2_op.cc b/paddle/fluid/operators/lookup_table_v2_op.cc
index 47a00a93a647253305080bde2d8c98eb735513d6..48ae080783d112c7e11daebe984de70925f5bbe2 100644
--- a/paddle/fluid/operators/lookup_table_v2_op.cc
+++ b/paddle/fluid/operators/lookup_table_v2_op.cc
@@ -203,14 +203,6 @@ REGISTER_OPERATOR(lookup_table_v2_grad, ops::LookupTableV2OpGrad,
                   ops::LookupTableV2GradOpNoBufferVarsInferer,
                   ops::LookupTableV2OpGradVarTypeInference);
 
-REGISTER_OP_CPU_KERNEL(lookup_table_v2, ops::LookupTableV2Kernel<float>,
-                       ops::LookupTableV2Kernel<double>,
-                       ops::LookupTableV2Kernel<paddle::platform::bfloat16>);
-REGISTER_OP_CPU_KERNEL(
-    lookup_table_v2_grad, ops::LookupTableV2GradKernel<float>,
-    ops::LookupTableV2GradKernel<double>,
-    ops::LookupTableV2GradKernel<paddle::platform::bfloat16>);
-
 /* ========================== register checkpoint ===========================*/
 REGISTER_OP_VERSION(lookup_table_v2)
     .AddCheckpoint(
diff --git a/paddle/fluid/operators/lookup_table_v2_op.cu b/paddle/fluid/operators/lookup_table_v2_op.cu
index d40b2643785706e843dbd9812e74ca0aa134f7b5..74d089e23a82c6ea15988d9c3c0c3e5b42da8b2e 100644
--- a/paddle/fluid/operators/lookup_table_v2_op.cu
+++ b/paddle/fluid/operators/lookup_table_v2_op.cu
@@ -235,13 +235,3 @@ class LookupTableV2GradCUDAKernel : public framework::OpKernel<T> {
 
 }  // namespace operators
 }  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(lookup_table_v2, ops::LookupTableV2CUDAKernel<float>,
-                        ops::LookupTableV2CUDAKernel<double>,
-                        ops::LookupTableV2CUDAKernel<plat::float16>);
-REGISTER_OP_CUDA_KERNEL(lookup_table_v2_grad,
-                        ops::LookupTableV2GradCUDAKernel<float>,
-                        ops::LookupTableV2GradCUDAKernel<double>,
-                        ops::LookupTableV2GradCUDAKernel<plat::float16>);
diff --git a/paddle/phi/kernels/cpu/embedding_grad_kernel.cc b/paddle/phi/kernels/cpu/embedding_grad_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..21b3e6da8d9efdac1e5866ef3ac1aac580d5a0b8
--- /dev/null
+++ b/paddle/phi/kernels/cpu/embedding_grad_kernel.cc
@@ -0,0 +1,220 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/phi/kernels/embedding_grad_kernel.h" +#include "paddle/phi/kernels/funcs/embedding_util.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +struct EmbeddingGradCPUFunctor { + EmbeddingGradCPUFunctor(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + DenseTensor* weight_grad) + : dev_ctx_(dev_ctx), + input_(input), + weight_(weight), + out_grad_(out_grad), + weight_grad_(weight_grad), + padding_idx_(padding_idx) {} + + template + void apply() { + DDim table_dim = weight_.dims(); + + auto ids = CopyIdsToVector(input_); + auto ids_num = static_cast(ids.size()); + + // Since paddings are not trainable and fixed in forward, the gradient of + // paddings makes no sense and we don't deal with it in backward. + { + auto* d_output = &out_grad_; + auto* ids_data = ids.data(); + + int64_t N = table_dim[0]; + int64_t D = table_dim[1]; + + auto* d_output_data = d_output->template data(); + + dev_ctx_.template Alloc(weight_grad_); + auto* d_table_data = weight_grad_->data(); + + memset(d_table_data, 0, weight_grad_->numel() * sizeof(T)); + + for (int64_t i = 0; i < ids_num; ++i) { + if (padding_idx_ != kNoPadding && ids_data[i] == padding_idx_) { + // the gradient of padding_idx should be 0, already done by memset, so + // do nothing. + } else { + PADDLE_ENFORCE_LT( + ids_data[i], + N, + phi::errors::InvalidArgument( + "Variable value (input) of " + "OP(paddle.nn.functional.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + N, + ids_data[i])); + PADDLE_ENFORCE_GE( + ids_data[i], + 0, + phi::errors::InvalidArgument( + "Variable value (input) of " + "OP(paddle.nn.functional.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + N, + ids_data[i])); + for (int j = 0; j < D; ++j) { + d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j]; + } + } + } + } + } + + private: + const Context& dev_ctx_; + const DenseTensor& input_; + const DenseTensor& weight_; + const DenseTensor& out_grad_; + DenseTensor* weight_grad_; + int64_t padding_idx_; +}; + +template +void EmbeddingGradKernel(const Context& ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + DenseTensor* weight_grad) { + EmbeddingGradCPUFunctor functor( + ctx, input, weight, out_grad, padding_idx, weight_grad); + if (input.dtype() == phi::DataType::INT32) { + functor.template apply(); + } else if (input.dtype() == phi::DataType::INT64) { + functor.template apply(); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "emebdding input only support int32 and int64")); + } +} + +template +struct EmbeddingSparseGradCPUFunctor { + EmbeddingSparseGradCPUFunctor(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + SelectedRows* weight_grad) + : dev_ctx_(dev_ctx), + input_(input), + weight_(weight), + out_grad_(out_grad), + weight_grad_(weight_grad), + padding_idx_(padding_idx) {} + + template + void apply() { + DDim table_dim = weight_.dims(); + + auto ids = CopyIdsToVector(input_); + auto ids_num = static_cast(ids.size()); + + // Since paddings are not trainable and fixed in forward, the gradient of + // paddings makes no sense and we don't deal with it in backward. 
+    auto* d_table = weight_grad_;
+    auto* d_output = &out_grad_;
+    d_table->set_rows(ids);
+
+    auto* d_table_value = d_table->mutable_value();
+    d_table_value->Resize({ids_num, table_dim[1]});
+
+    dev_ctx_.template Alloc<T>(d_table_value);
+
+    d_table->set_height(table_dim[0]);
+
+    auto* d_output_data = d_output->template data<T>();
+    auto* d_table_data = d_table_value->template data<T>();
+
+    auto d_output_dims = d_output->dims();
+    auto d_output_dims_2d =
+        flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
+    PADDLE_ENFORCE_EQ(d_table_value->dims(),
+                      d_output_dims_2d,
+                      phi::errors::InvalidArgument(
+                          "ShapeError: The shape of lookup_table@Grad and "
+                          "output@Grad should be same. "
+                          "But received lookup_table@Grad's shape = [%s], "
+                          "output@Grad's shape = [%s].",
+                          d_table_value->dims(),
+                          d_output_dims_2d));
+    memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
+  }
+
+ private:
+  const Context& dev_ctx_;
+  const DenseTensor& input_;
+  const DenseTensor& weight_;
+  const DenseTensor& out_grad_;
+  SelectedRows* weight_grad_;
+  int64_t padding_idx_;
+};
+
+template <typename T, typename Context>
+void EmbeddingSparseGradKernel(const Context& ctx,
+                               const DenseTensor& input,
+                               const DenseTensor& weight,
+                               const DenseTensor& out_grad,
+                               int64_t padding_idx,
+                               SelectedRows* weight_grad) {
+  EmbeddingSparseGradCPUFunctor<T, Context> functor(
+      ctx, input, weight, out_grad, padding_idx, weight_grad);
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "embedding input only supports int32 and int64"));
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(embedding_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::EmbeddingGradKernel,
+                   float,
+                   double,
+                   phi::dtype::bfloat16) {}
+
+PD_REGISTER_KERNEL(embedding_sparse_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::EmbeddingSparseGradKernel,
+                   float,
+                   double,
+                   phi::dtype::bfloat16) {}
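Aside for readers tracing the dense backward above: it is a plain scatter-add. Row i of out_grad accumulates into row ids[i] of weight_grad, and padding rows keep the zero written by memset. A minimal standalone sketch of that contraction (toy C++ with illustrative names, not the phi types or API):

#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Toy scatter-add mirroring EmbeddingGradCPUFunctor::apply():
// d_table starts zeroed; each out_grad row is added into the row
// selected by ids[i]; padding ids contribute nothing.
std::vector<float> EmbeddingGradRef(const std::vector<int64_t>& ids,
                                    const std::vector<float>& out_grad,
                                    int64_t N, int64_t D,
                                    int64_t padding_idx) {
  std::vector<float> d_table(N * D, 0.0f);  // the memset(0) in the kernel
  for (size_t i = 0; i < ids.size(); ++i) {
    if (ids[i] == padding_idx) continue;    // gradient of padding is 0
    assert(ids[i] >= 0 && ids[i] < N);      // the PADDLE_ENFORCE checks
    for (int64_t j = 0; j < D; ++j) {
      d_table[ids[i] * D + j] += out_grad[i * D + j];
    }
  }
  return d_table;
}

int main() {
  // Two lookups of row 1 and one padded lookup; D = 2.
  std::vector<int64_t> ids = {1, 1, 3};
  std::vector<float> out_grad = {1, 2, 3, 4, 5, 6};
  auto g = EmbeddingGradRef(ids, out_grad, /*N=*/4, /*D=*/2, /*padding_idx=*/3);
  std::cout << g[2] << " " << g[3] << "\n";  // row 1 accumulates: 4 6
}

Repeated ids simply accumulate, which is why the kernel zero-fills first and uses += rather than assignment.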
+ +#include "paddle/phi/kernels/embedding_kernel.h" +#include "paddle/phi/kernels/funcs/embedding_util.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/utils/data_type.h" + +namespace phi { + +template +struct EmbeddingCPUFunctor { + EmbeddingCPUFunctor(const Context& dev_ctx, + const DenseTensor& input, + const DenseTensor& weight, + int64_t padding_idx, + DenseTensor* out) + : dev_ctx_(dev_ctx), + input_(input), + weight_(weight), + out_(out), + padding_idx_(padding_idx) {} + + template + void apply() { + auto ids = CopyIdsToVector(input_); + auto ids_numel = static_cast(ids.size()); + + int64_t row_number = weight_.dims()[0]; + int64_t row_width = weight_.dims()[1]; + + auto* table = weight_.data(); + + dev_ctx_.template Alloc(out_); + auto* output = out_->data(); + + for (int64_t i = 0; i < ids_numel; ++i) { + if (padding_idx_ != kNoPadding && ids[i] == padding_idx_) { + memset(output + i * row_width, 0, row_width * sizeof(T)); + } else { + PADDLE_ENFORCE_LT( + ids[i], + row_number, + phi::errors::InvalidArgument( + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + row_number, + ids[i])); + PADDLE_ENFORCE_GE( + ids[i], + 0, + phi::errors::InvalidArgument( + "Variable value (input) of OP(fluid.layers.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + row_number, + ids[i])); + memcpy(output + i * row_width, + table + ids[i] * row_width, + row_width * sizeof(T)); + } + } + } + + private: + const Context& dev_ctx_; + const DenseTensor& input_; + const DenseTensor& weight_; + DenseTensor* out_; + int64_t padding_idx_; +}; + +template +void EmbeddingKernel(const Context& ctx, + const DenseTensor& input, + const DenseTensor& weight, + int64_t padding_idx, + DenseTensor* out) { + EmbeddingCPUFunctor functor(ctx, input, weight, padding_idx, out); + + if (input.dtype() == phi::DataType::INT32) { + functor.template apply(); + } else if (input.dtype() == phi::DataType::INT64) { + functor.template apply(); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "emebdding input only support int32 and int64")); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(embedding, + CPU, + ALL_LAYOUT, + phi::EmbeddingKernel, + float, + double, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc b/paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..d78477073ad03b1b39aaae00c16aed81ea7fd056 --- /dev/null +++ b/paddle/phi/kernels/cpu/sparse_weight_embedding_grad_kernel.cc @@ -0,0 +1,224 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h" +#include "paddle/phi/kernels/funcs/embedding_util.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/utils/data_type.h" + +namespace phi { + +template +struct SparseWeightEmbeddingGradCPUFunctor { + SparseWeightEmbeddingGradCPUFunctor(const Context& dev_ctx, + const DenseTensor& input, + const SelectedRows& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + DenseTensor* weight_grad) + : dev_ctx_(dev_ctx), + input_(input), + weight_(weight), + out_grad_(out_grad), + weight_grad_(weight_grad), + padding_idx_(padding_idx) {} + + template + void apply() { + DDim table_dim = weight_.dims(); + + auto ids = CopyIdsToVector(input_); + auto ids_num = static_cast(ids.size()); + + // Since paddings are not trainable and fixed in forward, the gradient of + // paddings makes no sense and we don't deal with it in backward. + { + auto* d_output = &out_grad_; + // auto d_table = weight_grad_; + auto* ids_data = ids.data(); + + int64_t N = table_dim[0]; + int64_t D = table_dim[1]; + + auto* d_output_data = d_output->template data(); + + dev_ctx_.template Alloc(weight_grad_); + auto* d_table_data = weight_grad_->data(); + + memset(d_table_data, 0, weight_grad_->numel() * sizeof(T)); + + for (int64_t i = 0; i < ids_num; ++i) { + if (padding_idx_ != kNoPadding && ids_data[i] == padding_idx_) { + // the gradient of padding_idx should be 0, already done by memset, so + // do nothing. + } else { + PADDLE_ENFORCE_LT( + ids_data[i], + N, + phi::errors::InvalidArgument( + "Variable value (input) of " + "OP(paddle.nn.functional.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + N, + ids_data[i])); + PADDLE_ENFORCE_GE( + ids_data[i], + 0, + phi::errors::InvalidArgument( + "Variable value (input) of " + "OP(paddle.nn.functional.embedding) " + "expected >= 0 and < %ld, but got %ld. Please check input " + "value.", + N, + ids_data[i])); + for (int j = 0; j < D; ++j) { + d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j]; + } + } + } + } + } + + private: + const Context& dev_ctx_; + const DenseTensor& input_; + const SelectedRows& weight_; + const DenseTensor& out_grad_; + DenseTensor* weight_grad_; + int64_t padding_idx_; +}; + +template +struct SparseWeightEmbeddingSparseGradCPUFunctor { + SparseWeightEmbeddingSparseGradCPUFunctor(const Context& dev_ctx, + const DenseTensor& input, + const SelectedRows& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + SelectedRows* weight_grad) + : dev_ctx_(dev_ctx), + input_(input), + weight_(weight), + out_grad_(out_grad), + weight_grad_(weight_grad), + padding_idx_(padding_idx) {} + + template + void apply() { + DDim table_dim = weight_.dims(); + + auto ids = CopyIdsToVector(input_); + auto ids_num = static_cast(ids.size()); + + // Since paddings are not trainable and fixed in forward, the gradient of + // paddings makes no sense and we don't deal with it in backward. 
+    auto* d_table = weight_grad_;
+    auto* d_output = &out_grad_;
+    d_table->set_rows(ids);
+
+    auto* d_table_value = d_table->mutable_value();
+    d_table_value->Resize({ids_num, table_dim[1]});
+
+    dev_ctx_.template Alloc<T>(d_table_value);
+
+    d_table->set_height(table_dim[0]);
+
+    auto* d_output_data = d_output->template data<T>();
+    auto* d_table_data = d_table_value->template data<T>();
+
+    auto d_output_dims = d_output->dims();
+    auto d_output_dims_2d =
+        phi::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
+    PADDLE_ENFORCE_EQ(d_table_value->dims(),
+                      d_output_dims_2d,
+                      phi::errors::InvalidArgument(
+                          "ShapeError: The shape of lookup_table@Grad and "
+                          "output@Grad should be same. "
+                          "But received lookup_table@Grad's shape = [%s], "
+                          "output@Grad's shape = [%s].",
+                          d_table_value->dims(),
+                          d_output_dims_2d));
+    memcpy(d_table_data, d_output_data, sizeof(T) * d_output->numel());
+  }
+
+ private:
+  const Context& dev_ctx_;
+  const DenseTensor& input_;
+  const SelectedRows& weight_;
+  const DenseTensor& out_grad_;
+  SelectedRows* weight_grad_;
+  int64_t padding_idx_;
+};
+
+template <typename T, typename Context>
+void SparseWeightEmbeddingGradKernel(const Context& ctx,
+                                     const DenseTensor& input,
+                                     const SelectedRows& weight,
+                                     const DenseTensor& out_grad,
+                                     int64_t padding_idx,
+                                     DenseTensor* weight_grad) {
+  SparseWeightEmbeddingGradCPUFunctor<T, Context> functor(
+      ctx, input, weight, out_grad, padding_idx, weight_grad);
+
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "embedding input only supports int32 and int64"));
+  }
+}
+
+template <typename T, typename Context>
+void SparseWeightEmbeddingSparseGradKernel(const Context& ctx,
+                                           const DenseTensor& input,
+                                           const SelectedRows& weight,
+                                           const DenseTensor& out_grad,
+                                           int64_t padding_idx,
+                                           SelectedRows* weight_grad) {
+  SparseWeightEmbeddingSparseGradCPUFunctor<T, Context> functor(
+      ctx, input, weight, out_grad, padding_idx, weight_grad);
+
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "embedding input only supports int32 and int64"));
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(sparse_weight_embedding_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SparseWeightEmbeddingGradKernel,
+                   float,
+                   double,
+                   phi::dtype::bfloat16) {}
+
+PD_REGISTER_KERNEL(sparse_weight_embedding_sparse_grad,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SparseWeightEmbeddingSparseGradKernel,
+                   float,
+                   double,
+                   phi::dtype::bfloat16) {}
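Worth spelling out what the "sparse grad" variants return: instead of scatter-adding into a dense table, they hand back a row-indexed pair — the looked-up ids become the rows and out_grad, reshaped to [ids_num, D], is copied verbatim as the values. A loose toy model of that container (ToySelectedRows is an illustrative stand-in, not phi::SelectedRows and not its API):

#include <cstdint>
#include <iostream>
#include <vector>

// Toy stand-in for the sparse gradient: ids become `rows`, out_grad
// becomes `value` unchanged -- no scatter-add is performed here.
struct ToySelectedRows {
  int64_t height = 0;         // rows of the dense table (table_dim[0])
  std::vector<int64_t> rows;  // one entry per lookup, duplicates allowed
  std::vector<float> value;   // ids_num x D block copied from out_grad
};

ToySelectedRows SparseEmbeddingGradRef(const std::vector<int64_t>& ids,
                                       const std::vector<float>& out_grad,
                                       int64_t N) {
  ToySelectedRows grad;
  grad.height = N;        // set_height(table_dim[0])
  grad.rows = ids;        // set_rows(ids)
  grad.value = out_grad;  // the memcpy of d_output into d_table_value
  return grad;
}

int main() {
  auto g = SparseEmbeddingGradRef({1, 1, 3}, {1, 2, 3, 4, 5, 6}, 4);
  // Row 1 appears twice; whoever applies the gradient must merge duplicates.
  std::cout << g.rows.size() << " lookups into a table of height "
            << g.height << "\n";
}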
diff --git a/paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc b/paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c0f95d03888b8df825341c282e08f80dafc988a8
--- /dev/null
+++ b/paddle/phi/kernels/cpu/sparse_weight_embedding_kernel.cc
@@ -0,0 +1,118 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/sparse_weight_embedding_kernel.h"
+#include "paddle/phi/kernels/funcs/embedding_util.h"
+
+#include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/data_type.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
+#include "paddle/phi/kernels/funcs/blas/blas.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+struct EmbeddingCPUSparseFunctor {
+  EmbeddingCPUSparseFunctor(const Context& dev_ctx,
+                            const DenseTensor& input,
+                            const SelectedRows& weight,
+                            int64_t padding_idx,
+                            DenseTensor* out)
+      : dev_ctx_(dev_ctx),
+        input_(input),
+        weight_(weight),
+        out_(out),
+        padding_idx_(padding_idx) {}
+
+  template <typename IdT>
+  void apply() {
+    auto ids = CopyIdsToVector<IdT>(input_);
+    auto ids_numel = static_cast<int64_t>(ids.size());
+
+    const auto& table_t = weight_;
+    auto output_t = out_;
+    int64_t row_width = table_t.value().dims()[1];
+    const auto* table = table_t.value().template data<T>();
+    auto* output = dev_ctx_.template Alloc<T>(output_t);
+    auto input_data_type =
+        paddle::framework::TransToProtoVarType(table_t.value().dtype());
+
+    for (int64_t i = 0; i < ids_numel; ++i) {
+      if (padding_idx_ != kNoPadding && ids[i] == padding_idx_) {
+        memset(output + i * row_width, 0, row_width * sizeof(T));
+      } else {
+        PADDLE_ENFORCE_GE(
+            ids[i],
+            0,
+            phi::errors::InvalidArgument(
+                "Variable value (input) of OP(fluid.layers.embedding) "
+                "expected >= 0. But received %ld",
+                ids[i]));
+        auto id_index = table_t.Index(ids[i]);
+        PADDLE_ENFORCE_GE(
+            id_index,
+            0,
+            phi::errors::InvalidArgument(
+                "the input key should exist. But received %d.", id_index));
+
+        if (input_data_type == paddle::framework::proto::VarType::BF16) {
+          memcpy(output + i * row_width,
+                 table + id_index * row_width,
+                 row_width * sizeof(T));
+        } else {
+          auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx_);
+          blas.VCOPY(
+              row_width, table + id_index * row_width, output + i * row_width);
+        }
+      }
+    }
+  }
+
+ private:
+  const Context& dev_ctx_;
+  const DenseTensor& input_;
+  const SelectedRows& weight_;
+  DenseTensor* out_;
+  int64_t padding_idx_;
+};
+
+template <typename T, typename Context>
+void SparseWeightEmbeddingKernel(const Context& ctx,
+                                 const DenseTensor& input,
+                                 const SelectedRows& weight,
+                                 int64_t padding_idx,
+                                 DenseTensor* out) {
+  EmbeddingCPUSparseFunctor<T, Context> functor(
+      ctx, input, weight, padding_idx, out);
+
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "embedding input only supports int32 and int64"));
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(sparse_weight_embedding,
+                   CPU,
+                   ALL_LAYOUT,
+                   phi::SparseWeightEmbeddingKernel,
+                   float,
+                   double,
+                   phi::dtype::bfloat16) {}
diff --git a/paddle/phi/kernels/embedding_grad_kernel.h b/paddle/phi/kernels/embedding_grad_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..40ffe6ec886c447a1d5f762cdbe01c95edb39764
--- /dev/null
+++ b/paddle/phi/kernels/embedding_grad_kernel.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/selected_rows.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void EmbeddingGradKernel(const Context& ctx,
+                         const DenseTensor& input,
+                         const DenseTensor& weight,
+                         const DenseTensor& out_grad,
+                         int64_t padding_idx,
+                         DenseTensor* weight_grad);
+
+template <typename T, typename Context>
+void EmbeddingSparseGradKernel(const Context& ctx,
+                               const DenseTensor& input,
+                               const DenseTensor& weight,
+                               const DenseTensor& out_grad,
+                               int64_t padding_idx,
+                               SelectedRows* weight_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/embedding_kernel.h b/paddle/phi/kernels/embedding_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..cd7d675d6dc6cd7d71486437d9c56c4e73431af1
--- /dev/null
+++ b/paddle/phi/kernels/embedding_kernel.h
@@ -0,0 +1,28 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void EmbeddingKernel(const Context& ctx,
+                     const DenseTensor& input,
+                     const DenseTensor& weight,
+                     int64_t padding_idx,
+                     DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/funcs/embedding_util.h b/paddle/phi/kernels/funcs/embedding_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..20c4ddca05460afbbd30491c9269ce935a3f611e
--- /dev/null
+++ b/paddle/phi/kernels/funcs/embedding_util.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cstring>
+#include <type_traits>
+#include <vector>
+
+#include "paddle/phi/core/dense_tensor.h"
+
+namespace phi {
+
+constexpr int64_t kNoPadding = -1;
+
+template <typename InT, typename OutT = int64_t>
+static std::vector<OutT> CopyIdsToVector(const DenseTensor &ids) {
+  auto numel = ids.numel();
+  const auto *src = ids.data<InT>();
+  std::vector<OutT> ret(numel);
+  if (std::is_same<InT, OutT>::value) {
+    std::memcpy(ret.data(), src, numel * sizeof(InT));
+  } else {
+    for (decltype(numel) i = 0; i < numel; ++i) {
+      ret[i] = src[i];
+    }
+  }
+  return ret;
+}
+
+}  // namespace phi
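A usage sketch of the utility above, with the DenseTensor replaced by a raw pointer plus length so it runs standalone (CopyIdsToVectorRef is an illustrative name, not the phi function):

#include <cstdint>
#include <cstring>
#include <iostream>
#include <type_traits>
#include <vector>

// Standalone copy of the CopyIdsToVector logic: same-type ids are memcpy'd,
// narrower ids (int32) are widened element by element, so downstream code
// can always consume int64_t rows.
template <typename InT, typename OutT = int64_t>
std::vector<OutT> CopyIdsToVectorRef(const InT* src, int64_t numel) {
  std::vector<OutT> ret(numel);
  if (std::is_same<InT, OutT>::value) {
    std::memcpy(ret.data(), src, numel * sizeof(InT));
  } else {
    for (int64_t i = 0; i < numel; ++i) ret[i] = src[i];
  }
  return ret;
}

int main() {
  int32_t ids32[] = {7, 8, 9};
  auto ids = CopyIdsToVectorRef(ids32, 3);  // int32 -> int64 widening path
  std::cout << ids[2] << "\n";              // 9
}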
diff --git a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a970348760c18e2c67e9c7b366cdc2f5e18e3abd
--- /dev/null
+++ b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu
@@ -0,0 +1,258 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/embedding_grad_kernel.h"
+#include "paddle/phi/kernels/funcs/embedding_util.h"
+
+#include "paddle/fluid/memory/memcpy.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/data_type.h"
+#include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
+
+#include "paddle/fluid/framework/mixed_vector.h"
+#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+
+namespace phi {
+
+template <typename InT, typename OutT>
+__global__ void InputTypeConvert(const InT* in_ids,
+                                 const int64_t K,
+                                 OutT* out_ids) {
+  for (int i = 0; i < K; i++) {
+    out_ids[i] = static_cast<OutT>(in_ids[i]);
+  }
+}
+
+template <typename T, typename IdT>
+__global__ void EmbeddingGrad(T* table,
+                              const T* output,
+                              const IdT* ids,
+                              const int64_t N,
+                              const int64_t K,
+                              const int64_t D) {
+  int idx = threadIdx.x;
+  int idy = blockIdx.x + threadIdx.y * gridDim.x;
+
+  while (idy < K) {
+    auto id = static_cast<int64_t>(ids[idy]);
+    const T* out = output + idy * D;
+    T* tab = table + id * D;
+#ifdef PADDLE_WITH_CUDA
+    paddle::platform::VectorizedAtomicAddPerBlock(D, idx, blockDim.x, out, tab);
+#else
+    for (int i = idx; i < D; i += blockDim.x) {
+      paddle::platform::CudaAtomicAdd(&tab[i], out[i]);
+    }
+#endif
+    idy += blockDim.y * gridDim.x;
+  }
+}
+
+template <typename T, typename Context>
+struct EmbeddingGradCUDAFunctor {
+  EmbeddingGradCUDAFunctor(const Context& dev_ctx,
+                           const DenseTensor& input,
+                           const DenseTensor& weight,
+                           const DenseTensor& out_grad,
+                           int64_t padding_idx,
+                           DenseTensor* weight_grad)
+      : dev_ctx_(dev_ctx),
+        input_(input),
+        weight_(weight),
+        out_grad_(out_grad),
+        padding_idx_(padding_idx),
+        weight_grad_(weight_grad) {}
+
+  template <typename IdT>
+  void apply() {
+    // Since paddings are not trainable and fixed in forward, the gradient of
+    // paddings makes no sense and we don't deal with it in backward.
+    {
+      auto d_output_t = out_grad_;
+      auto d_table_t = weight_grad_;
+
+      int N = weight_grad_->dims()[0];
+      int D = weight_grad_->dims()[1];
+      int K = input_.numel();
+
+      const T* d_output = d_output_t.template data<T>();
+      const auto* ids = input_.template data<IdT>();
+      T* d_table = dev_ctx_.template Alloc<T>(d_table_t);
+
+#ifdef PADDLE_WITH_HIP
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          hipMemsetAsync(d_table, 0, N * D * sizeof(T), dev_ctx_.stream()));
+#else
+      PADDLE_ENFORCE_GPU_SUCCESS(
+          cudaMemsetAsync(d_table, 0, N * D * sizeof(T), dev_ctx_.stream()));
+#endif
+
+      const int gridx = 2 * dev_ctx_.GetSMCount();
+      dim3 threads(128, 8);
+      dim3 grids(gridx, 1);
+      EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
+          d_table, d_output, ids, N, K, D);
+    }
+  }
+
+ private:
+  const phi::GPUContext& dev_ctx_;
+  const DenseTensor& input_;
+  const DenseTensor& weight_;
+  const DenseTensor& out_grad_;
+  int64_t padding_idx_;
+  DenseTensor* weight_grad_;
+};
+
+template <typename T, typename Context>
+void EmbeddingGradKernel(const Context& ctx,
+                         const DenseTensor& input,
+                         const DenseTensor& weight,
+                         const DenseTensor& out_grad,
+                         int64_t padding_idx,
+                         DenseTensor* weight_grad) {
+  EmbeddingGradCUDAFunctor<T, Context> functor(
+      ctx, input, weight, out_grad, padding_idx, weight_grad);
+
+  if (input.dtype() == phi::DataType::INT32) {
+    functor.template apply<int>();
+  } else if (input.dtype() == phi::DataType::INT64) {
+    functor.template apply<int64_t>();
+  } else {
+    PADDLE_THROW(phi::errors::Unimplemented(
+        "embedding input only supports int32 and int64"));
+  }
+}
+
+template <typename T, typename Context>
+struct EmbeddingSparseGradCUDAFunctor {
+  EmbeddingSparseGradCUDAFunctor(const Context& dev_ctx,
+                                 const DenseTensor& input,
+                                 const DenseTensor& weight,
+                                 const DenseTensor& out_grad,
+                                 int64_t padding_idx,
+                                 SelectedRows* weight_grad)
+      : dev_ctx_(dev_ctx),
+        input_(input),
+        weight_(weight),
+        out_grad_(out_grad),
+        padding_idx_(padding_idx),
+        weight_grad_(weight_grad) {}
+
+  template <typename IdT>
+  void apply() {
+    // Since paddings are not trainable and fixed in forward, the gradient of
+    // paddings makes no sense and we don't deal with it in backward.
+
+    const auto* ids_data = input_.template data<IdT>();
+    auto* d_table = weight_grad_;
+    auto* table = &weight_;
+    auto* d_output = &out_grad_;
+    int64_t ids_num = input_.numel();
+    dim3 threads(128, 8);
+    dim3 grids(8, 1);
+    auto stream = dev_ctx_.stream();
+    paddle::framework::Vector<int64_t> new_rows;
+    new_rows.resize(ids_num);
+    auto gpu_place = dev_ctx_.GetPlace();
+
+    paddle::framework::MixVector<int64_t> mixv_new_rows(&new_rows);
+    if (!std::is_same<IdT, int64_t>::value) {
+      InputTypeConvert<<<grids, threads, 0, stream>>>(
+          ids_data, ids_num, mixv_new_rows.MutableData(gpu_place));
+    } else {
+      paddle::memory::Copy(gpu_place,
+                           mixv_new_rows.CUDAMutableData(gpu_place),
+                           gpu_place,
+                           ids_data,
+                           ids_num * sizeof(int64_t),
+                           stream);
+    }
+
+    mixv_new_rows.CopyToCPU();
+    d_table->set_rows(new_rows);
+
+    auto* d_table_value = d_table->mutable_value();
+    d_table_value->Resize({ids_num, table->dims()[1]});
+    dev_ctx_.template Alloc<T>(d_table_value);
+
+    auto* d_table_data = d_table_value->template data<T>();
+    auto* d_output_data = d_output->template data<T>();
+    auto d_output_dims = d_output->dims();
+    auto d_output_dims_2d =
+        phi::flatten_to_2d(d_output_dims, d_output_dims.size() - 1);
+    PADDLE_ENFORCE_EQ(d_table_value->dims(),
+                      d_output_dims_2d,
+                      phi::errors::InvalidArgument(
+                          "ShapeError: The shape of lookup_table@Grad and "
+                          "output@Grad should be same. "
+                          "But received lookup_table@Grad's shape = [%s], "
+                          "output@Grad's shape = [%s].",
+                          d_table_value->dims(),
+                          d_output_dims_2d));
" + "But received lookup_table@Grad's shape = [%s], " + "output@Grad's shape = [%s].", + d_table_value->dims(), + d_output_dims_2d)); + paddle::memory::Copy(gpu_place, + d_table_data, + gpu_place, + d_output_data, + d_output->numel() * sizeof(T), + stream); + } + + private: + const phi::GPUContext& dev_ctx_; + const DenseTensor& input_; + const DenseTensor& weight_; + const DenseTensor& out_grad_; + int64_t padding_idx_; + SelectedRows* weight_grad_; +}; + +template +void EmbeddingSparseGradKernel(const Context& ctx, + const DenseTensor& input, + const DenseTensor& weight, + const DenseTensor& out_grad, + int64_t padding_idx, + SelectedRows* weight_grad) { + EmbeddingSparseGradCUDAFunctor functor( + ctx, input, weight, out_grad, padding_idx, weight_grad); + + if (input.dtype() == phi::DataType::INT32) { + functor.template apply(); + } else if (input.dtype() == phi::DataType::INT64) { + functor.template apply(); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "emebdding input only support int32 and int64")); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(embedding_grad, + GPU, + ALL_LAYOUT, + phi::EmbeddingGradKernel, + float, + double, + phi::dtype::float16) {} + +PD_REGISTER_KERNEL(embedding_sparse_grad, + GPU, + ALL_LAYOUT, + phi::EmbeddingSparseGradKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/embedding_kernel.cu b/paddle/phi/kernels/gpu/embedding_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..7f3a31ba544d88534d8a606fba53e017a155023c --- /dev/null +++ b/paddle/phi/kernels/gpu/embedding_kernel.cu @@ -0,0 +1,126 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/embedding_kernel.h" +#include "paddle/phi/kernels/funcs/embedding_util.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" + +namespace phi { + +template +__global__ void EmbeddingFW(T *output, + const T *table, + const IdT *ids, + const int64_t N, + const int64_t K, + const int64_t D, + const int64_t padding_idx) { + int idx = threadIdx.x; + int idy = blockIdx.x + threadIdx.y * gridDim.x; + + while (idy < K) { + auto id = static_cast(ids[idy]); + T *out = output + idy * D; + const T *tab = table + id * D; + for (int i = idx; i < D; i += blockDim.x) { + if (PaddingFlag) { + if (id == padding_idx) + out[i] = static_cast(0); + else + out[i] = tab[i]; + } else { + out[i] = tab[i]; + } + } + idy += blockDim.y * gridDim.x; + } +} + +template +struct EmbeddingCUDAFunctor { + EmbeddingCUDAFunctor(const Context &dev_ctx, + const DenseTensor &input, + const DenseTensor &weight, + int64_t padding_idx, + DenseTensor *out) + : dev_ctx_(dev_ctx), + input_(input), + weight_(weight), + out_(out), + padding_idx_(padding_idx) {} + + template + void apply() { + size_t N = weight_.dims()[0]; + size_t D = weight_.dims()[1]; + size_t K = input_.numel(); + + const int gridx = 2 * dev_ctx_.GetSMCount(); + dim3 threads(256, 4); + dim3 grids(gridx, 1); + + const T *table = weight_.template data(); + const IdT *ids = input_.template data(); + auto *output = dev_ctx_.template Alloc(out_); + auto stream = dev_ctx_.stream(); + + if (padding_idx_ == -1) { + EmbeddingFW<<>>( + output, table, ids, N, K, D, padding_idx_); + } else { + EmbeddingFW<<>>( + output, table, ids, N, K, D, padding_idx_); + } + } + + private: + const phi::GPUContext &dev_ctx_; + const DenseTensor &input_; + const DenseTensor &weight_; + DenseTensor *out_; + int64_t padding_idx_; +}; + +template +void EmbeddingKernel(const Context &ctx, + const DenseTensor &input, + const DenseTensor &weight, + int64_t padding_idx, + DenseTensor *out) { + EmbeddingCUDAFunctor functor( + ctx, input, weight, padding_idx, out); + + if (input.dtype() == phi::DataType::INT32) { + functor.template apply(); + } else if (input.dtype() == phi::DataType::INT64) { + functor.template apply(); + } else { + PADDLE_THROW(phi::errors::Unimplemented( + "emebdding input only support int32 and int64")); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(embedding, + GPU, + ALL_LAYOUT, + phi::EmbeddingKernel, + float, + double, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h b/paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..772268c2cc3889db6c328fa99425dc6996320050 --- /dev/null +++ b/paddle/phi/kernels/sparse_weight_embedding_grad_kernel.h @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/selected_rows.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SparseWeightEmbeddingGradKernel(const Context& ctx,
+                                     const DenseTensor& input,
+                                     const SelectedRows& weight,
+                                     const DenseTensor& out_grad,
+                                     int64_t padding_idx,
+                                     DenseTensor* weight_grad);
+
+template <typename T, typename Context>
+void SparseWeightEmbeddingSparseGradKernel(const Context& ctx,
+                                           const DenseTensor& input,
+                                           const SelectedRows& weight,
+                                           const DenseTensor& out_grad,
+                                           int64_t padding_idx,
+                                           SelectedRows* weight_grad);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/sparse_weight_embedding_kernel.h b/paddle/phi/kernels/sparse_weight_embedding_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..c7392b691aa0fa4f0bc28e35cc29bb6aa902c34f
--- /dev/null
+++ b/paddle/phi/kernels/sparse_weight_embedding_kernel.h
@@ -0,0 +1,29 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/selected_rows.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void SparseWeightEmbeddingKernel(const Context& ctx,
+                                 const DenseTensor& input,
+                                 const SelectedRows& weight,
+                                 int64_t padding_idx,
+                                 DenseTensor* out);
+
+}  // namespace phi
diff --git a/paddle/phi/ops/compat/embedding_sig.cc b/paddle/phi/ops/compat/embedding_sig.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b79a381dcecc7943d0e82dbf122ece783cc33791
--- /dev/null
+++ b/paddle/phi/ops/compat/embedding_sig.cc
@@ -0,0 +1,64 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature EmbeddingOpArgumentMapping(const ArgumentMappingContext& ctx) { + if (ctx.IsDenseTensorInput("W")) { + return KernelSignature("embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"}); + } else { + return KernelSignature( + "sparse_weight_embedding", {"Ids", "W"}, {"padding_idx"}, {"Out"}); + } +} + +KernelSignature EmbeddingGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + if (ctx.IsDenseTensorInput("W")) { + if ((paddle::any_cast(ctx.Attr("is_sparse"))) == true) { + return KernelSignature("embedding_sparse_grad", + {"Ids", "W", GradVarName("Out")}, + {"padding_idx"}, + {GradVarName("W")}); + } else { + return KernelSignature("embedding_grad", + {"Ids", "W", GradVarName("Out")}, + {"padding_idx"}, + {GradVarName("W")}); + } + } else { + if ((paddle::any_cast(ctx.Attr("is_sparse"))) == true) { + return KernelSignature("sparse_weight_embedding_sparse_grad", + {"Ids", "W", GradVarName("Out")}, + {"padding_idx"}, + {GradVarName("W")}); + } else { + return KernelSignature("sparse_weight_embedding_grad", + {"Ids", "W", GradVarName("Out")}, + {"padding_idx"}, + {GradVarName("W")}); + } + } +} + +} // namespace phi + +PD_REGISTER_BASE_KERNEL_NAME(lookup_table_v2, embedding); +PD_REGISTER_BASE_KERNEL_NAME(lookup_table_v2_grad, embedding_grad); + +PD_REGISTER_ARG_MAPPING_FN(lookup_table_v2, phi::EmbeddingOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(lookup_table_v2_grad, + phi::EmbeddingGradOpArgumentMapping);