Commit 2a6a3e01 authored by mindspore-ci-bot, committed by Gitee

!1555 fix bug in lamb warmup step check

Merge pull request !1555 from wangnan39/fix_bug_in_check_lamb_warmup_step
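This change fixes two problems in the warmup_steps check of `_check_param_value` for the Lamb optimizer. First, the old call passed `decay_steps` where `warmup_steps` was intended, so `decay_steps` was validated twice and `warmup_steps` was never validated at all (hence the `_ = warmup_steps` placeholder, presumably there to silence the unused-argument warning). Second, the relation was `Rel.GT` (strictly greater than 0), which would have rejected the valid default `warmup_steps=0`; the corrected check uses `Rel.GE`. The tests are updated accordingly: a compile test that relies on the default warmup, and error tests asserting that a float `warmup_steps` or `decay_steps` raises TypeError and that a non-positive `decay_steps` raises ValueError. A minimal standalone sketch of the corrected semantics follows the first diff below.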
@@ -111,7 +111,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
 def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
                        end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name):
     """Check the type of inputs."""
-    _ = warmup_steps
     validator.check_float_positive('start_learning_rate', start_learning_rate, prim_name)
     validator.check_float_legal_value('start_learning_rate', start_learning_rate, prim_name)
     validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
@@ -119,7 +118,7 @@ def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
     validator.check_float_positive('power', power, prim_name)
     validator.check_float_legal_value('power', power, prim_name)
     validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
-    validator.check_integer('warmup_steps', decay_steps, 0, Rel.GT, prim_name)
+    validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GE, prim_name)
     validator.check_value_type("beta1", beta1, [float], prim_name)
     validator.check_value_type("beta2", beta2, [float], prim_name)
     validator.check_value_type("eps", eps, [float], prim_name)
...
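As a rough illustration (not MindSpore's actual validator implementation), the corrected checks amount to the following: `decay_steps` must be an int strictly greater than 0, and `warmup_steps` must be an int greater than or equal to 0, with type violations reported as TypeError and range violations as ValueError, matching the updated tests. The names in this sketch are illustrative only:

    # Minimal standalone sketch of the corrected check semantics;
    # _check_steps is a hypothetical helper, not MindSpore's validator API.
    def _check_steps(decay_steps, warmup_steps=0):
        # check_value_type-style type check: a float such as 1.0 is rejected.
        if not isinstance(decay_steps, int) or isinstance(decay_steps, bool):
            raise TypeError("decay_steps must be an int")
        # Rel.GT: decay_steps must be strictly positive.
        if decay_steps <= 0:
            raise ValueError("decay_steps must be > 0")
        if not isinstance(warmup_steps, int) or isinstance(warmup_steps, bool):
            raise TypeError("warmup_steps must be an int")
        # Rel.GE: warmup_steps == 0 (the default) is now accepted.
        if warmup_steps < 0:
            raise ValueError("warmup_steps must be >= 0")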
@@ -14,6 +14,7 @@
 # ============================================================================
 """ test lamb """
 import numpy as np
+import pytest
 import mindspore.nn as nn
 from mindspore import Tensor, Parameter
@@ -50,29 +51,27 @@ class NetWithoutWeight(nn.Cell):
         return x

-def test_lamb_1():
-    """ test_Lamb_1 """
+def test_lamb_compile():
+    """ test_Lamb_compile """
     inputs = Tensor(np.ones([1, 64]).astype(np.float32))
     label = Tensor(np.zeros([1, 10]).astype(np.float32))
     net = Net()
     net.set_train()
     loss = nn.SoftmaxCrossEntropyWithLogits()
-    optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=5)
+    optimizer = Lamb(net.trainable_params(), decay_steps=10)
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, optimizer)
     _executor.compile(train_network, inputs, label)

-def test_lamb_2():
-    """ test_Lamb_2 """
-    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
-    label = Tensor(np.zeros([1, 10]).astype(np.float32))
+def test_lamb_error():
     net = Net()
-    net.set_train()
-    loss = nn.SoftmaxCrossEntropyWithLogits()
-    optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=0)
-    net_with_loss = WithLossCell(net, loss)
-    train_network = TrainOneStepCell(net_with_loss, optimizer)
-    _executor.compile(train_network, inputs, label)
+    with pytest.raises(TypeError):
+        Lamb(net.get_parameters(), decay_steps=6, warmup_steps=5.0)
+
+    with pytest.raises(TypeError):
+        Lamb(net.get_parameters(), decay_steps=1.0)
+
+    with pytest.raises(ValueError):
+        Lamb(net.get_parameters(), decay_steps=0)
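With the fix in place, constructing the optimizer with a zero warmup (the value used by the removed test_lamb_2) remains valid, for example (usage sketch, same API as in the tests above):

    optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=0)  # accepted: 0 >= 0

while a non-integer value such as `warmup_steps=5.0` now fails the integer check with a TypeError instead of slipping through unvalidated.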