未验证 提交 12311ddc 编写于 作者: W Weilong Wu 提交者: GitHub

[Eager] Fix final state adam in selected rows case (#42219)

* [Eager] Support final_state_adam when argument grad (position 1) is selected_rows

* Remove needless code

* Add adam_dense_param_sparse_grad kernel
上级 9dadf7df
......@@ -69,7 +69,12 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adam_impl(
kernel_data_type = kernel_key.dtype();
}
}
std::string kernel_name = "adam";
if (!phi::DenseTensor::classof(grad.impl().get())) {
kernel_name = "adam_dense_param_sparse_grad";
}
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
kernel_name, {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << kernel_name << " API kernel key: [" << kernel_backend << ", "
......@@ -77,9 +82,7 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adam_impl(
VLOG(6) << kernel_name << " API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto input_param = PrepareData(param, kernel.InputAt(0), {});
auto input_grad = PrepareData(grad, kernel.InputAt(1), {});
auto input_lr = PrepareData(learning_rate, kernel.InputAt(2), {});
auto input_moment1 = PrepareData(moment1, kernel.InputAt(3), {});
auto input_moment2 = PrepareData(moment2, kernel.InputAt(4), {});
......@@ -140,6 +143,9 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adam_impl(
phi::MetaTensor meta_out_4(kernel_out_4);
phi::MetaTensor meta_out_5(kernel_out_5);
if (phi::DenseTensor::classof(grad.impl().get())) {
auto input_grad = PrepareData(grad, kernel.InputAt(1), {});
phi::AdamInferMeta(MakeMetaTensor(*input_param),
MakeMetaTensor(*input_grad),
MakeMetaTensor(*input_lr),
......@@ -211,7 +217,81 @@ std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> adam_impl(
kernel_out_3,
kernel_out_4,
kernel_out_5);
} else {
auto input_grad = TensorToSelectedRows(grad);
phi::AdamInferMeta(MakeMetaTensor(*input_param),
MakeMetaTensor(*input_grad),
MakeMetaTensor(*input_lr),
MakeMetaTensor(*input_moment1),
MakeMetaTensor(*input_moment2),
MakeMetaTensor(*input_beta1_pow),
MakeMetaTensor(*input_beta2_pow),
input_meta_ref_master_param,
input_meta_ref_skip_update,
beta1,
beta2,
epsilon,
lazy_mode,
min_row_size_to_use_multithread,
multi_precision,
use_global_beta_pow,
&meta_out_0,
&meta_out_1,
&meta_out_2,
&meta_out_3,
&meta_out_4,
&meta_out_5);
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
const phi::SelectedRows&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
paddle::optional<const phi::DenseTensor&>,
paddle::optional<const phi::DenseTensor&>,
const Scalar&,
const Scalar&,
const Scalar&,
bool,
int64_t,
bool,
bool,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx,
*input_param,
*input_grad,
*input_lr,
*input_moment1,
*input_moment2,
*input_beta1_pow,
*input_beta2_pow,
input_master_param,
input_skip_update,
beta1,
beta2,
epsilon,
lazy_mode,
min_row_size_to_use_multithread,
multi_precision,
use_global_beta_pow,
kernel_out_0,
kernel_out_1,
kernel_out_2,
kernel_out_3,
kernel_out_4,
kernel_out_5);
}
return api_output;
}
......
......@@ -19,14 +19,13 @@ import unittest
import paddle
import paddle.nn as nn
import numpy as np
from paddle.fluid.framework import _enable_legacy_dygraph
_enable_legacy_dygraph()
from paddle.fluid.framework import _test_eager_guard
paddle.disable_static()
class EmbeddingDygraph(unittest.TestCase):
def test_1(self):
def func_1(self):
x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64)
paddle.disable_static(paddle.CPUPlace())
x = paddle.to_tensor(x_data, stop_gradient=False)
......@@ -44,7 +43,12 @@ class EmbeddingDygraph(unittest.TestCase):
out.backward()
adam.step()
def test_2(self):
def test_1(self):
with _test_eager_guard():
self.func_1()
self.func_1()
def func_2(self):
x_data = np.arange(3, 6).reshape((3, 1)).astype(np.int64)
y_data = np.arange(6, 12).reshape((3, 2)).astype(np.float32)
paddle.disable_static(paddle.CPUPlace())
......@@ -60,6 +64,11 @@ class EmbeddingDygraph(unittest.TestCase):
with self.assertRaises(ValueError):
embedding = paddle.nn.Embedding(10, -3, sparse=True)
def test_2(self):
with _test_eager_guard():
self.func_2()
self.func_2()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册