fix_bug_in_check_lamb_warmup_step

上级 fb7e4eac
......@@ -111,7 +111,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name):
"""Check the type of inputs."""
_ = warmup_steps
validator.check_float_positive('start_learning_rate', start_learning_rate, prim_name)
validator.check_float_legal_value('start_learning_rate', start_learning_rate, prim_name)
validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
......@@ -119,7 +118,7 @@ def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
validator.check_float_positive('power', power, prim_name)
validator.check_float_legal_value('power', power, prim_name)
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
validator.check_integer('warmup_steps', decay_steps, 0, Rel.GT, prim_name)
validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GE, prim_name)
validator.check_value_type("beta1", beta1, [float], prim_name)
validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name)
......
......@@ -14,6 +14,7 @@
# ============================================================================
""" test lamb """
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, Parameter
......@@ -50,29 +51,27 @@ class NetWithoutWeight(nn.Cell):
return x
def test_lamb_1():
""" test_Lamb_1 """
def test_lamb_compile():
""" test_Lamb_compile """
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=5)
optimizer = Lamb(net.trainable_params(), decay_steps=10)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
def test_lamb_2():
""" test_Lamb_2 """
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
def test_lamb_error():
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=0)
with pytest.raises(TypeError):
Lamb(net.get_parameters(), decay_steps=6, warmup_steps=5.0)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
with pytest.raises(TypeError):
Lamb(net.get_parameters(), decay_steps=1.0)
with pytest.raises(ValueError):
Lamb(net.get_parameters(), decay_steps=0)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
新手
引导
客服 返回
顶部