Unverified commit 92081e1d, authored by Leo Chen, committed by GitHub

fix undefined variable in optimizer (#33416)

Parent 2af23549
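
Background: the assertion messages in the optimizer's state-loading code referenced `item`, a name that is not defined in that scope. Because Python evaluates an assert's message expression only when the condition is false, the mistake stayed hidden until a loaded tensor actually mismatched in shape or dtype, at which point a NameError was raised instead of the intended AssertionError. The sketch below illustrates this failure mode; `check_loaded_tensor` and its arguments are hypothetical stand-ins for illustration, not code from this repository.

import numpy as np

def check_loaded_tensor(param_name, model_np, load_para_np):
    # Hypothetical stand-in for the check changed in the diff below.
    # The message after the comma is only evaluated when the condition is
    # False, so the undefined name `item` stays hidden until a real
    # shape mismatch occurs. The fix formats with the in-scope name
    # (here param_name; in the diff, param.name) instead of item.name.
    assert model_np.shape == load_para_np.shape, \
        "Parameter shape not match, Dygraph Parameter [ {} ] need tensor with shape {} but load tensor with shape {}".format(
            item.name, model_np.shape, load_para_np.shape)

# Matching shapes: the assertion passes and the broken message is never built.
check_loaded_tensor("w", np.zeros((2, 2)), np.ones((2, 2)))

# Mismatched shapes: instead of the intended AssertionError we get a
# NameError, which is what this commit fixes.
try:
    check_loaded_tensor("w", np.zeros((2, 2)), np.ones((3, 3)))
except NameError as e:
    print(e)  # name 'item' is not defined

With the fix, a mismatch produces the intended AssertionError with a readable message, which the new negative tests in the second hunk rely on.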
@@ -257,11 +257,11 @@ class Optimizer(object):
             assert model_np.shape == load_para_np.shape, \
                 "Parameter shape not match, Dygraph Parameter [ {} ] need tensor with shape {} but load tensor with shape {}".format(
-                    item.name, model_np.shape, load_para_np.shape)
+                    param.name, model_np.shape, load_para_np.shape)
             assert model_np.dtype == load_para_np.dtype, \
                 "Parameter dtype not match, Dygraph Parameter [ {} ] need tensor with dtype {} but load tensor with dtype {}".format(
-                    item.name, model_np.dtype, load_para_np.dtype)
+                    param.name, model_np.dtype, load_para_np.dtype)
             tensor.set(load_para_np, framework._current_expected_place())
@@ -804,11 +804,48 @@ class TestNetWithEpsilonTensor(unittest.TestCase):
         adam.minimize(b)
         state_dict = adam.state_dict()
         fluid.save_dygraph(state_dict, "paddle_dy")
-        para_state_dict, opti_state_dict = fluid.load_dygraph("paddle_dy")
-        adam.set_state_dict(opti_state_dict)
+        para_state_dict, opt_state_dict = fluid.load_dygraph("paddle_dy")
+        adam.set_state_dict(opt_state_dict)
         paddle.enable_static()
+
+    def test_adam_save_load_error(self):
+        paddle.disable_static()
+
+        def get_opt(dtype, shape):
+            with paddle.utils.unique_name.guard():
+                paddle.set_default_dtype(dtype)
+                a = paddle.rand([4, 10])
+                linear = paddle.nn.Linear(10, 10)
+                b = linear(a)
+                state_dict = linear.state_dict()
+                fluid.save_dygraph(state_dict, "paddle_dy")
+                scheduler = paddle.optimizer.lr.NoamDecay(
+                    d_model=0.01, warmup_steps=100, verbose=True)
+                adam = paddle.fluid.optimizer.Adam(
+                    learning_rate=scheduler,
+                    parameter_list=linear.parameters(),
+                    use_global_beta_pow=True)
+                adam.minimize(b)
+                return adam
+
+        adam = get_opt('float32', [10, 10])
+        state_dict = adam.state_dict()
+        fluid.save_dygraph(state_dict, "paddle_dy")
+        para_state_dict, opt_state_dict = fluid.load_dygraph("paddle_dy")
+        adam.set_state_dict(opt_state_dict)
+
+        adam2 = get_opt('float64', [10, 10])  # dtype not match
+        self.assertRaises(AssertionError, adam2.set_state_dict, opt_state_dict)
+
+        adam3 = get_opt('float32', [10, 10])  # shape not match
+        opt_state_dict['beta1_pow_acc_0'] = np.array(
+            [0.9, 0.9], dtype='float32')
+        self.assertRaises(AssertionError, adam3.set_state_dict, opt_state_dict)
+
+        paddle.enable_static()
+
 class TestAdamOpV2Group(TestAdamOpV2):
     def test_adam_op(self):