未验证 提交 409eff7d 编写于 作者: M MRXLT 提交者: GitHub

Cherry adam errorinfo (#27169)

* add check for sparse parameters with weight_decay

* move sparse check to adam.py
上级 834face8
......@@ -448,7 +448,6 @@ class TestAdamOpV2(unittest.TestCase):
def test_adam_op_with_state_dict(self):
import paddle
paddle.disable_static()
emb = paddle.nn.Embedding(10, 10)
......@@ -517,6 +516,20 @@ class TestAdamOpV2(unittest.TestCase):
adam = paddle.optimizer.Adam(
0.1, epsilon=-1, parameters=linear.parameters())
def test_adam_op_with_sparse_input_and_weight_decay(self):
paddle.disable_static()
x_data = np.arange(0, 10).reshape((10, 1)).astype(np.int64)
x = paddle.to_tensor(x_data, stop_gradient=False)
emb = paddle.nn.Embedding(10, 10, sparse=True)
adam = paddle.optimizer.Adam(
0.001, parameters=emb.parameters(), weight_decay=0.01)
with self.assertRaises(RuntimeError):
out = emb(x)
out.backward()
adam.step()
if __name__ == "__main__":
unittest.main()
......@@ -250,3 +250,47 @@ class Adam(Optimizer):
stop_gradient=True)
return adam_op
@framework.dygraph_only
def step(self):
"""
Execute the optimizer and update parameters once.
Returns:
None
Examples:
.. code-block:: python
import paddle
import numpy as np
paddle.disable_static()
value = np.arange(26).reshape(2, 13).astype("float32")
a = paddle.to_tensor(value)
linear = paddle.nn.Linear(13, 5)
# This can be any optimizer supported by dygraph.
adam = paddle.optimizer.Adam(learning_rate = 0.01,
parameters = linear.parameters())
out = linear(a)
out.backward()
adam.step()
adam.clear_grad()
"""
parameter_list = self._parameter_list
self._dtype = None
params_grads = []
for param in self._parameter_list:
if not param.trainable:
continue
if hasattr(
param, "_is_sparse"
) and param._is_sparse and self.regularization is not None:
raise RuntimeError(
"Adam don't support weight_decay with sparse parameters, please set it to None."
)
if param._grad_ivar() is not None:
grad_var = param._grad_ivar()
params_grads.append((param, grad_var))
optimize_ops = self._apply_optimize(
loss=None, startup_program=None, params_grads=params_grads)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册