未验证 提交 9ee4e3dc 编写于 作者: L LiuChiachi 提交者: GitHub

Correct 2.0 API usage in hapi.model.load (#26829)

* replace fluid.optimizer.set_dict with optimizer.set_state_dict

* replace fluid.optimizer.set_dict with optimizer.set_state_dict

* add coverage rate

* Increase coverage rate, fix code style

* Increase coverage rate, fix code style

* add fit to generate optimizer.state_dict() to save .pdopt to increase coverage rate

* delete http.log
上级 4e1c8f3e
...@@ -731,8 +731,8 @@ class DynamicGraphAdapter(object): ...@@ -731,8 +731,8 @@ class DynamicGraphAdapter(object):
if not self.model._optimizer or not optim_state: if not self.model._optimizer or not optim_state:
return return
# If optimizer performs set_dict when state vars haven't been created, # If optimizer performs set_state_dict when state vars haven't been created,
# which would happen when set_dict before minimize, the state would be # which would happen when set_state_dict before minimize, the state would be
# stored in optimizer._accumulators_holder and loaded lazily. # stored in optimizer._accumulators_holder and loaded lazily.
# To contrive this when loading from static-graph saved states, extend # To contrive this when loading from static-graph saved states, extend
# state dict to include keys named accoring to dygraph naming rules. # state dict to include keys named accoring to dygraph naming rules.
...@@ -776,7 +776,13 @@ class DynamicGraphAdapter(object): ...@@ -776,7 +776,13 @@ class DynamicGraphAdapter(object):
accum_name + "_0") accum_name + "_0")
converted_state[dy_state_name] = state_var converted_state[dy_state_name] = state_var
self.model._optimizer.set_dict(converted_state) if not hasattr(self.model._optimizer, 'set_state_dict'):
warnings.warn(
"paddle.fluid.optimizer is deprecated in API 2.0, please use paddle.optimizer instead"
)
self.model._optimizer.set_dict(converted_state)
else:
self.model._optimizer.set_state_dict(converted_state)
class Model(object): class Model(object):
......
...@@ -416,6 +416,29 @@ class TestModelFunction(unittest.TestCase): ...@@ -416,6 +416,29 @@ class TestModelFunction(unittest.TestCase):
shutil.rmtree(path) shutil.rmtree(path)
fluid.disable_dygraph() if dynamic else None fluid.disable_dygraph() if dynamic else None
def test_dynamic_load(self):
mnist_data = MnistDataset(mode='train')
for new_optimizer in [True, False]:
path = tempfile.mkdtemp()
paddle.disable_static()
net = LeNet()
inputs = [InputSpec([None, 1, 28, 28], 'float32', 'x')]
labels = [InputSpec([None, 1], 'int64', 'label')]
if new_optimizer:
optim = paddle.optimizer.Adam(
learning_rate=0.001, parameters=net.parameters())
else:
optim = fluid.optimizer.Adam(
learning_rate=0.001, parameter_list=net.parameters())
model = Model(net, inputs, labels)
model.prepare(
optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
model.fit(mnist_data, batch_size=64, verbose=0)
model.save(path + '/test')
model.load(path + '/test')
shutil.rmtree(path)
paddle.enable_static()
def test_dynamic_save_static_load(self): def test_dynamic_save_static_load(self):
path = tempfile.mkdtemp() path = tempfile.mkdtemp()
# dynamic saving # dynamic saving
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册