From 2785f8762ed24316b71e9ae0dab4a639b01b19fe Mon Sep 17 00:00:00 2001 From: zyfncg Date: Tue, 31 May 2022 11:40:02 +0800 Subject: [PATCH] add embedding yaml (#43029) * add embedding yaml * fix infermeta bug * fix bug of selected_rows infer_meta * fix selected_rows * add unittest --- paddle/phi/api/lib/api_custom_impl.cc | 193 ++++++++++++++++++ paddle/phi/api/lib/api_custom_impl.h | 12 ++ paddle/phi/infermeta/binary.cc | 26 +++ paddle/phi/infermeta/binary.h | 6 + paddle/phi/tests/api/CMakeLists.txt | 1 + paddle/phi/tests/api/test_embedding_api.cc | 119 +++++++++++ .../unittests/test_lookup_table_v2_op.py | 5 +- python/paddle/nn/functional/input.py | 4 +- python/paddle/utils/code_gen/api.yaml | 6 + python/paddle/utils/code_gen/backward.yaml | 6 + 10 files changed, 375 insertions(+), 3 deletions(-) create mode 100644 paddle/phi/tests/api/test_embedding_api.cc diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index b6431fcbe6..14746abf59 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -638,6 +638,80 @@ Tensor copy_to_impl(const Tensor& x, Place place, bool blocking) { return out; } +Tensor embedding_impl(const Tensor& x, + const Tensor& weight, + int64_t padding_idx, + bool sparse) { + DataType kernel_data_type = ParseDataType(weight); + auto kernel_key_set = ParseKernelKeyByInputArgs(weight); + auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); + VLOG(6) << "embedding API kernel key: [" << kernel_key.backend() << ", " + << kernel_key.layout() << ", " << kernel_data_type << "]"; + + auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend()); + + Tensor api_output; + + if (phi::DenseTensor::classof(weight.impl().get())) { + const auto& kernel = + phi::KernelFactory::Instance().SelectKernelOrThrowError( + "embedding", + {kernel_key.backend(), kernel_key.layout(), kernel_data_type}); + VLOG(6) << "embedding API kernel: " << kernel; + + auto input_x = PrepareData(x, kernel.InputAt(0), {}); + auto input_weight = PrepareData(weight, kernel.InputAt(1), {}); + + auto* kernel_out = SetKernelOutput(kernel_key.backend(), &api_output); + phi::MetaTensor meta_out(kernel_out); + + phi::EmbeddingInferMeta(MakeMetaTensor(*input_x), + MakeMetaTensor(*input_weight), + padding_idx, + sparse, + &meta_out); + + using kernel_signature = void (*)(const platform::DeviceContext&, + const phi::DenseTensor&, + const phi::DenseTensor&, + int64_t, + phi::DenseTensor*); + auto* kernel_fn = kernel.GetVariadicKernelFn(); + { + (*kernel_fn)(*dev_ctx, *input_x, *input_weight, padding_idx, kernel_out); + } + } else { + const auto& kernel = + phi::KernelFactory::Instance().SelectKernelOrThrowError( + "sparse_weight_embedding", + {kernel_key.backend(), kernel_key.layout(), kernel_data_type}); + VLOG(6) << "sparse_weight_embedding API kernel: " << kernel; + + auto input_x = PrepareData(x, kernel.InputAt(0), {}); + auto input_weight = TensorToSelectedRows(weight); + + auto* kernel_out = SetKernelOutput(kernel_key.backend(), &api_output); + phi::MetaTensor meta_out(kernel_out); + + phi::EmbeddingInferMeta(MakeMetaTensor(*input_x), + MakeMetaTensor(*input_weight), + padding_idx, + sparse, + &meta_out); + + using kernel_signature = void (*)(const platform::DeviceContext&, + const phi::DenseTensor&, + const phi::SelectedRows&, + int64_t, + phi::DenseTensor*); + auto* kernel_fn = kernel.GetVariadicKernelFn(); + { + (*kernel_fn)(*dev_ctx, *input_x, *input_weight, padding_idx, kernel_out); + } + } + return api_output; +} + std::vector split_impl(const Tensor& x, const IntArray& num_or_sections, const Scalar& axis) { @@ -1176,6 +1250,125 @@ void imag_grad_impl(const Tensor& out_grad, Tensor* x_grad) { (*kernel_fn)(*dev_ctx, *dense_out_grad, kernel_out); } +void embedding_grad_impl(const Tensor& x, + const Tensor& weight, + const Tensor& out_grad, + int64_t padding_idx, + bool sparse, + Tensor* weight_grad) { + DataType kernel_data_type = ParseDataType(weight); + auto kernel_key_set = ParseKernelKeyByInputArgs(weight); + auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey(); + VLOG(6) << "embedding_grad API kernel key: [" << kernel_key.backend() << ", " + << kernel_key.layout() << ", " << kernel_data_type << "]"; + + auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend()); + + if (phi::DenseTensor::classof(weight.impl().get())) { + std::string kernel_name = + sparse ? "embedding_sparse_grad" : "embedding_grad"; + const auto& kernel = + phi::KernelFactory::Instance().SelectKernelOrThrowError( + kernel_name, + {kernel_key.backend(), kernel_key.layout(), kernel_data_type}); + VLOG(6) << kernel_name << " API kernel: " << kernel; + + auto input_x = PrepareData(x, kernel.InputAt(0), {}); + auto input_weight = PrepareData(weight, kernel.InputAt(1), {}); + auto input_out_grad = PrepareData(out_grad, kernel.InputAt(2), {}); + + if (sparse) { + auto* kernel_out = + SetSelectedRowsKernelOutput(kernel_key.backend(), weight_grad); + phi::MetaTensor meta_out(kernel_out); + meta_out.set_dims(input_weight->dims()); + meta_out.set_dtype(input_weight->dtype()); + kernel_out->set_height(input_weight->dims()[0]); + + using kernel_signature = void (*)(const platform::DeviceContext&, + const phi::DenseTensor&, + const phi::DenseTensor&, + const phi::DenseTensor&, + int64_t, + phi::SelectedRows*); + auto* kernel_fn = kernel.GetVariadicKernelFn(); + (*kernel_fn)(*dev_ctx, + *input_x, + *input_weight, + *input_out_grad, + padding_idx, + kernel_out); + } else { + auto* kernel_out = SetKernelOutput(kernel_key.backend(), weight_grad); + phi::MetaTensor meta_out(kernel_out); + phi::UnchangedInferMeta(MakeMetaTensor(*input_weight), &meta_out); + using kernel_signature = void (*)(const platform::DeviceContext&, + const phi::DenseTensor&, + const phi::DenseTensor&, + const phi::DenseTensor&, + int64_t, + phi::DenseTensor*); + auto* kernel_fn = kernel.GetVariadicKernelFn(); + (*kernel_fn)(*dev_ctx, + *input_x, + *input_weight, + *input_out_grad, + padding_idx, + kernel_out); + } + } else { + std::string kernel_name = sparse ? "sparse_weight_embedding_sparse_grad" + : "sparse_weight_embedding_grad"; + const auto& kernel = + phi::KernelFactory::Instance().SelectKernelOrThrowError( + kernel_name, + {kernel_key.backend(), kernel_key.layout(), kernel_data_type}); + VLOG(6) << kernel_name << " API kernel: " << kernel; + + auto input_x = PrepareData(x, kernel.InputAt(0), {}); + auto input_weight = TensorToSelectedRows(weight); + auto input_out_grad = PrepareData(out_grad, kernel.InputAt(2), {}); + + if (sparse) { + auto* kernel_out = + SetSelectedRowsKernelOutput(kernel_key.backend(), weight_grad); + phi::MetaTensor meta_out(kernel_out); + phi::UnchangedInferMeta(MakeMetaTensor(*input_weight), &meta_out); + using kernel_signature = void (*)(const platform::DeviceContext&, + const phi::DenseTensor&, + const phi::SelectedRows&, + const phi::DenseTensor&, + int64_t, + phi::SelectedRows*); + auto* kernel_fn = kernel.GetVariadicKernelFn(); + (*kernel_fn)(*dev_ctx, + *input_x, + *input_weight, + *input_out_grad, + padding_idx, + kernel_out); + } else { + auto* kernel_out = SetKernelOutput(kernel_key.backend(), weight_grad); + phi::MetaTensor meta_out(kernel_out); + meta_out.set_dims(input_weight->GetCompleteDims()); + meta_out.set_dtype(input_weight->dtype()); + using kernel_signature = void (*)(const platform::DeviceContext&, + const phi::DenseTensor&, + const phi::SelectedRows&, + const phi::DenseTensor&, + int64_t, + phi::DenseTensor*); + auto* kernel_fn = kernel.GetVariadicKernelFn(); + (*kernel_fn)(*dev_ctx, + *input_x, + *input_weight, + *input_out_grad, + padding_idx, + kernel_out); + } + } +} + void real_grad_impl(const Tensor& out_grad, Tensor* x_grad) { phi::KernelKey kernel_key{ParseBackend(out_grad), out_grad.layout(), diff --git a/paddle/phi/api/lib/api_custom_impl.h b/paddle/phi/api/lib/api_custom_impl.h index f8ccbb36c5..f700345f46 100644 --- a/paddle/phi/api/lib/api_custom_impl.h +++ b/paddle/phi/api/lib/api_custom_impl.h @@ -98,6 +98,11 @@ Tensor conv2d_impl(const Tensor& input, Tensor copy_to_impl(const Tensor& x, Place place, bool blocking); +Tensor embedding_impl(const Tensor& x, + const Tensor& weight, + int64_t padding_idx, + bool sparse); + std::vector split_impl(const Tensor& x, const IntArray& num_or_sections, const Scalar& axis); @@ -145,6 +150,13 @@ void conv2d_grad_impl(const Tensor& input, void imag_grad_impl(const Tensor& out_grad, Tensor* x_grad); +void embedding_grad_impl(const Tensor& x, + const Tensor& weight, + const Tensor& out_grad, + int64_t padding_idx, + bool sparse, + Tensor* weight_grad); + void real_grad_impl(const Tensor& out_grad, Tensor* x_grad); } // namespace experimental diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index 76b6fcdd52..a8d5ad564f 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -983,6 +983,32 @@ void ElementwiseRawInferMeta(const MetaTensor& x, out->share_lod(x); } +void EmbeddingInferMeta(const MetaTensor& x, + const MetaTensor& weight, + int64_t padding_idx, + bool sparse, + MetaTensor* out) { + const auto& table_dims = weight.dims(); + const auto& ids_dims = x.dims(); + int ids_rank = ids_dims.size(); + VLOG(5) << "ids rank is " << ids_rank << std::endl; + PADDLE_ENFORCE_EQ( + table_dims.size(), + 2, + phi::errors::InvalidArgument( + "ShapeError: The dimensions of the 'lookup table' must be 2. " + "But received lookup table's dimensions = %d, " + "lookup table's shape = [%s].", + table_dims.size(), + table_dims)); + + auto output_dims = phi::vectorize(ids_dims); + output_dims.push_back(table_dims[1]); + out->set_dims(phi::make_ddim(output_dims)); + out->set_dtype(weight.dtype()); + out->share_lod(x); +} + void ExpandAsInferMeta(const MetaTensor& x, const MetaTensor& y, const std::vector& target_shape, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 0c86e5389c..2cd34406fc 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -154,6 +154,12 @@ void ElementwiseRawInferMeta(const MetaTensor& x_meta, int axis, MetaTensor* out); +void EmbeddingInferMeta(const MetaTensor& x, + const MetaTensor& weight, + int64_t padding_idx, + bool sparse, + MetaTensor* out); + void ExpandAsInferMeta(const MetaTensor& x, const MetaTensor& y, const std::vector& target_shape, diff --git a/paddle/phi/tests/api/CMakeLists.txt b/paddle/phi/tests/api/CMakeLists.txt index 5c1d098962..2333f82d62 100644 --- a/paddle/phi/tests/api/CMakeLists.txt +++ b/paddle/phi/tests/api/CMakeLists.txt @@ -15,6 +15,7 @@ cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_empty_api SRCS test_empty_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_fill_api SRCS test_fill_api.cc DEPS ${COMMON_API_TEST_DEPS} api_scalar) cc_test(test_elementwise_api SRCS test_elementwise_api.cc DEPS ${COMMON_API_TEST_DEPS}) +cc_test(test_embedding_api SRCS test_embedding_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_cast_api SRCS test_cast_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_reshape_api SRCS test_reshape_api.cc DEPS ${COMMON_API_TEST_DEPS}) cc_test(test_to_api SRCS test_to_api.cc DEPS ${COMMON_API_TEST_DEPS}) diff --git a/paddle/phi/tests/api/test_embedding_api.cc b/paddle/phi/tests/api/test_embedding_api.cc new file mode 100644 index 0000000000..6ccd382786 --- /dev/null +++ b/paddle/phi/tests/api/test_embedding_api.cc @@ -0,0 +1,119 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/phi/api/backward/backward_api.h" +#include "paddle/phi/api/include/api.h" + +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/kernel_registry.h" + +PD_DECLARE_KERNEL(sparse_weight_embedding, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(sparse_weight_embedding_grad, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(sparse_weight_embedding_sparse_grad, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(empty, CPU, ALL_LAYOUT); +PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT); + +namespace paddle { +namespace tests { + +TEST(API, sparse_weight_embedding) { + auto x = paddle::experimental::empty({4}, DataType::INT32); + auto* x_data = x.data(); + x_data[0] = 0; + x_data[1] = 4; + x_data[2] = 3; + x_data[3] = 1; + + auto weight_sr = std::make_shared( + std::vector{0, 1, 2, 3, 4, 5, 6}, 16); + *weight_sr->mutable_value() = *static_cast( + paddle::experimental::full({7, 3}, 2, DataType::FLOAT32).impl().get()); + paddle::experimental::Tensor weight; + weight.set_impl(weight_sr); + + auto out = paddle::experimental::embedding(x, weight); + + // 3. check result + ASSERT_EQ(out.dims().size(), 2); + ASSERT_EQ(out.dims()[0], 4); + ASSERT_EQ(out.numel(), 12); + ASSERT_EQ(out.type(), phi::DataType::FLOAT32); + ASSERT_EQ(out.layout(), phi::DataLayout::NCHW); +} + +TEST(API, sparse_weight_embedding_grad) { + auto x = paddle::experimental::empty({4}, DataType::INT32); + auto* x_data = x.data(); + x_data[0] = 0; + x_data[1] = 4; + x_data[2] = 3; + x_data[3] = 1; + + auto weight_sr = std::make_shared( + std::vector{0, 1, 2, 3, 4, 5, 6}, 16); + *weight_sr->mutable_value() = *static_cast( + paddle::experimental::full({7, 3}, 2, DataType::FLOAT32).impl().get()); + paddle::experimental::Tensor weight; + weight.set_impl(weight_sr); + + auto out_grad = paddle::experimental::full({4, 3}, 1, DataType::FLOAT32); + + paddle::experimental::Tensor weight_grad; + + paddle::experimental::embedding_grad( + x, weight, out_grad, -1, false, &weight_grad); + + // 3. check result + ASSERT_EQ(weight_grad.dims().size(), 2); + ASSERT_EQ(weight_grad.dims()[0], 16); + ASSERT_EQ(weight_grad.numel(), 48); + ASSERT_EQ(weight_grad.type(), phi::DataType::FLOAT32); + ASSERT_EQ(weight_grad.layout(), phi::DataLayout::NCHW); +} + +TEST(API, sparse_weight_embedding_sparse_grad) { + auto x = paddle::experimental::empty({4}, DataType::INT32); + auto* x_data = x.data(); + x_data[0] = 0; + x_data[1] = 4; + x_data[2] = 3; + x_data[3] = 1; + + auto weight_sr = std::make_shared( + std::vector{0, 1, 2, 3, 4, 5, 6}, 16); + *weight_sr->mutable_value() = *static_cast( + paddle::experimental::full({7, 3}, 2, DataType::FLOAT32).impl().get()); + paddle::experimental::Tensor weight; + weight.set_impl(weight_sr); + + auto out_grad = paddle::experimental::full({4, 3}, 1, DataType::FLOAT32); + + paddle::experimental::Tensor weight_grad; + + paddle::experimental::embedding_grad( + x, weight, out_grad, -1, true, &weight_grad); + + // 3. check result + ASSERT_EQ(weight_grad.dims().size(), 2); + ASSERT_EQ(weight_grad.dims()[0], 4); + ASSERT_EQ(weight_grad.numel(), 12); + ASSERT_EQ(weight_grad.type(), phi::DataType::FLOAT32); + ASSERT_EQ(weight_grad.layout(), phi::DataLayout::NCHW); +} + +} // namespace tests +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py index cad6437d1d..21844c9e40 100644 --- a/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_lookup_table_v2_op.py @@ -48,6 +48,7 @@ class TestStaticGraphSupportMultipleInt(unittest.TestCase): class TestLookupTableOp(OpTest): def setUp(self): self.op_type = "lookup_table_v2" + self.python_api = paddle.nn.functional.embedding table = np.random.random((17, 31)).astype("float64") ids = np.random.randint(0, 17, 4).astype(self.id_dtype()) self.inputs = {'W': table, 'Ids': ids} @@ -57,10 +58,10 @@ class TestLookupTableOp(OpTest): return "int64" def test_check_output(self): - self.check_output() + self.check_output(check_eager=True) def test_check_grad(self): - self.check_grad(['W'], 'Out', no_grad_set=set('Ids')) + self.check_grad(['W'], 'Out', no_grad_set=set('Ids'), check_eager=True) class TestLookupTableOpInt16(OpTest): diff --git a/python/paddle/nn/functional/input.py b/python/paddle/nn/functional/input.py index cfbf015ffa..92b3a7054d 100644 --- a/python/paddle/nn/functional/input.py +++ b/python/paddle/nn/functional/input.py @@ -200,7 +200,9 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None): raise ValueError("padding_idx must be within [-{}, {})".format( weight.shape[0], weight.shape[0])) - if in_dynamic_mode(): + if in_dygraph_mode(): + return _C_ops.final_state_embedding(x, weight, padding_idx, sparse) + elif _in_legacy_dygraph(): return _C_ops.lookup_table_v2( weight, x, 'is_sparse', sparse, 'is_distributed', False, 'remote_prefetch', False, 'padding_idx', padding_idx) diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index c541891662..c3a8e68ca7 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -613,6 +613,12 @@ func : elu backward : elu_grad +- api : embedding + args : (Tensor x, Tensor weight, int64_t padding_idx=-1, bool sparse=false) + output : Tensor + invoke : embedding_impl(x, weight, padding_idx, sparse) + backward : embedding_grad + - api : empty args : (IntArray shape, DataType dtype=DataType::FLOAT32, Place place=CPUPlace()) output: Tensor diff --git a/python/paddle/utils/code_gen/backward.yaml b/python/paddle/utils/code_gen/backward.yaml index b27c3aab6b..7183d822e1 100644 --- a/python/paddle/utils/code_gen/backward.yaml +++ b/python/paddle/utils/code_gen/backward.yaml @@ -609,6 +609,12 @@ backward : elu_double_grad inplace : (out_grad -> x_grad) +- backward_api : embedding_grad + forward : embedding (Tensor x, Tensor weight, int64_t padding_idx=-1, bool sparse=false) -> Tensor(out) + args : (Tensor x, Tensor weight, Tensor out_grad, int64_t padding_idx=-1, bool sparse=false) + output : Tensor(weight_grad) + invoke : embedding_grad_impl(x, weight, out_grad, padding_idx, sparse, weight_grad) + - backward_api : erf_grad forward : erf (Tensor x) -> Tensor(out) args : (Tensor x, Tensor out_grad) -- GitLab