From 12311ddc0dc2e0db2a7e7d5e8f086b345d501c5d Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Tue, 26 Apr 2022 19:49:51 +0800 Subject: [PATCH] [Eager] Fix final state adam in selected rows case (#42219) * [Eager] Support final_state_adam when argument grad (position 1) is selected_rows * Remove needless code * Add adam_dense_param_sparse_grad kernel --- paddle/phi/api/lib/api_custom_impl.cc | 224 ++++++++++++------ .../test_nn_functional_embedding_dygraph.py | 17 +- 2 files changed, 165 insertions(+), 76 deletions(-) diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc index ae248a7bf12..38a60ab9789 100644 --- a/paddle/phi/api/lib/api_custom_impl.cc +++ b/paddle/phi/api/lib/api_custom_impl.cc @@ -69,7 +69,12 @@ std::tuple adam_impl( kernel_data_type = kernel_key.dtype(); } } + std::string kernel_name = "adam"; + if (!phi::DenseTensor::classof(grad.impl().get())) { + kernel_name = "adam_dense_param_sparse_grad"; + } + const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError( kernel_name, {kernel_backend, kernel_layout, kernel_data_type}); VLOG(6) << kernel_name << " API kernel key: [" << kernel_backend << ", " @@ -77,9 +82,7 @@ std::tuple adam_impl( VLOG(6) << kernel_name << " API kernel: " << kernel; auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); - auto input_param = PrepareData(param, kernel.InputAt(0), {}); - auto input_grad = PrepareData(grad, kernel.InputAt(1), {}); auto input_lr = PrepareData(learning_rate, kernel.InputAt(2), {}); auto input_moment1 = PrepareData(moment1, kernel.InputAt(3), {}); auto input_moment2 = PrepareData(moment2, kernel.InputAt(4), {}); @@ -140,78 +143,155 @@ std::tuple adam_impl( phi::MetaTensor meta_out_4(kernel_out_4); phi::MetaTensor meta_out_5(kernel_out_5); - phi::AdamInferMeta(MakeMetaTensor(*input_param), - MakeMetaTensor(*input_grad), - MakeMetaTensor(*input_lr), - MakeMetaTensor(*input_moment1), - MakeMetaTensor(*input_moment2), - MakeMetaTensor(*input_beta1_pow), - MakeMetaTensor(*input_beta2_pow), - input_meta_ref_master_param, - input_meta_ref_skip_update, - beta1, - beta2, - epsilon, - lazy_mode, - min_row_size_to_use_multithread, - multi_precision, - use_global_beta_pow, - &meta_out_0, - &meta_out_1, - &meta_out_2, - &meta_out_3, - &meta_out_4, - &meta_out_5); - - using kernel_signature = void (*)(const platform::DeviceContext&, - const phi::DenseTensor&, - const phi::DenseTensor&, - const phi::DenseTensor&, - const phi::DenseTensor&, - const phi::DenseTensor&, - const phi::DenseTensor&, - const phi::DenseTensor&, - paddle::optional, - paddle::optional, - const Scalar&, - const Scalar&, - const Scalar&, - bool, - int64_t, - bool, - bool, - phi::DenseTensor*, - phi::DenseTensor*, - phi::DenseTensor*, - phi::DenseTensor*, - phi::DenseTensor*, - phi::DenseTensor*); - auto* kernel_fn = kernel.GetVariadicKernelFn(); + if (phi::DenseTensor::classof(grad.impl().get())) { + auto input_grad = PrepareData(grad, kernel.InputAt(1), {}); + + phi::AdamInferMeta(MakeMetaTensor(*input_param), + MakeMetaTensor(*input_grad), + MakeMetaTensor(*input_lr), + MakeMetaTensor(*input_moment1), + MakeMetaTensor(*input_moment2), + MakeMetaTensor(*input_beta1_pow), + MakeMetaTensor(*input_beta2_pow), + input_meta_ref_master_param, + input_meta_ref_skip_update, + beta1, + beta2, + epsilon, + lazy_mode, + min_row_size_to_use_multithread, + multi_precision, + use_global_beta_pow, + &meta_out_0, + &meta_out_1, + &meta_out_2, + &meta_out_3, + &meta_out_4, + &meta_out_5); + + using kernel_signature = void (*)(const platform::DeviceContext&, + const phi::DenseTensor&, + const phi::DenseTensor&, + const phi::DenseTensor&, + const phi::DenseTensor&, + const phi::DenseTensor&, + const phi::DenseTensor&, + const phi::DenseTensor&, + paddle::optional, + paddle::optional, + const Scalar&, + const Scalar&, + const Scalar&, + bool, + int64_t, + bool, + bool, + phi::DenseTensor*, + phi::DenseTensor*, + phi::DenseTensor*, + phi::DenseTensor*, + phi::DenseTensor*, + phi::DenseTensor*); + auto* kernel_fn = kernel.GetVariadicKernelFn(); - (*kernel_fn)(*dev_ctx, - *input_param, - *input_grad, - *input_lr, - *input_moment1, - *input_moment2, - *input_beta1_pow, - *input_beta2_pow, - input_master_param, - input_skip_update, - beta1, - beta2, - epsilon, - lazy_mode, - min_row_size_to_use_multithread, - multi_precision, - use_global_beta_pow, - kernel_out_0, - kernel_out_1, - kernel_out_2, - kernel_out_3, - kernel_out_4, - kernel_out_5); + (*kernel_fn)(*dev_ctx, + *input_param, + *input_grad, + *input_lr, + *input_moment1, + *input_moment2, + *input_beta1_pow, + *input_beta2_pow, + input_master_param, + input_skip_update, + beta1, + beta2, + epsilon, + lazy_mode, + min_row_size_to_use_multithread, + multi_precision, + use_global_beta_pow, + kernel_out_0, + kernel_out_1, + kernel_out_2, + kernel_out_3, + kernel_out_4, + kernel_out_5); + } else { + auto input_grad = TensorToSelectedRows(grad); + + phi::AdamInferMeta(MakeMetaTensor(*input_param), + MakeMetaTensor(*input_grad), + MakeMetaTensor(*input_lr), + MakeMetaTensor(*input_moment1), + MakeMetaTensor(*input_moment2), + MakeMetaTensor(*input_beta1_pow), + MakeMetaTensor(*input_beta2_pow), + input_meta_ref_master_param, + input_meta_ref_skip_update, + beta1, + beta2, + epsilon, + lazy_mode, + min_row_size_to_use_multithread, + multi_precision, + use_global_beta_pow, + &meta_out_0, + &meta_out_1, + &meta_out_2, + &meta_out_3, + &meta_out_4, + &meta_out_5); + + using kernel_signature = void (*)(const platform::DeviceContext&, + const phi::DenseTensor&, + const phi::SelectedRows&, + const phi::DenseTensor&, + const phi::DenseTensor&, + const phi::DenseTensor&, + const phi::DenseTensor&, + const phi::DenseTensor&, + paddle::optional, + paddle::optional, + const Scalar&, + const Scalar&, + const Scalar&, + bool, + int64_t, + bool, + bool, + phi::DenseTensor*, + phi::DenseTensor*, + phi::DenseTensor*, + phi::DenseTensor*, + phi::DenseTensor*, + phi::DenseTensor*); + auto* kernel_fn = kernel.GetVariadicKernelFn(); + (*kernel_fn)(*dev_ctx, + *input_param, + *input_grad, + *input_lr, + *input_moment1, + *input_moment2, + *input_beta1_pow, + *input_beta2_pow, + input_master_param, + input_skip_update, + beta1, + beta2, + epsilon, + lazy_mode, + min_row_size_to_use_multithread, + multi_precision, + use_global_beta_pow, + kernel_out_0, + kernel_out_1, + kernel_out_2, + kernel_out_3, + kernel_out_4, + kernel_out_5); + } return api_output; } diff --git a/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py b/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py index e50424126e5..0b5493e2170 100644 --- a/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py +++ b/python/paddle/fluid/tests/unittests/test_nn_functional_embedding_dygraph.py @@ -19,14 +19,13 @@ import unittest import paddle import paddle.nn as nn import numpy as np -from paddle.fluid.framework import _enable_legacy_dygraph -_enable_legacy_dygraph() +from paddle.fluid.framework import _test_eager_guard paddle.disable_static() class EmbeddingDygraph(unittest.TestCase): - def test_1(self): + def func_1(self): x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64) paddle.disable_static(paddle.CPUPlace()) x = paddle.to_tensor(x_data, stop_gradient=False) @@ -44,7 +43,12 @@ class EmbeddingDygraph(unittest.TestCase): out.backward() adam.step() - def test_2(self): + def test_1(self): + with _test_eager_guard(): + self.func_1() + self.func_1() + + def func_2(self): x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64) y_data = np.arange(6, 12).reshape((3, 2)).astype(np.float32) paddle.disable_static(paddle.CPUPlace()) @@ -60,6 +64,11 @@ class EmbeddingDygraph(unittest.TestCase): with self.assertRaises(ValueError): embedding = paddle.nn.Embedding(10, -3, sparse=True) + def test_2(self): + with _test_eager_guard(): + self.func_2() + self.func_2() + if __name__ == '__main__': unittest.main() -- GitLab