未验证 提交 7dafee77 编写于 作者: B Bo Zhou 提交者: GitHub

Save params (#107)

* new feature: save params

* add unittest for save()/restore()

* add an example demonstrating the usage

* rename the variable

* yapf

* fix comment
上级 2f11d0c5
......@@ -26,14 +26,14 @@ class CartpoleAgent(parl.Agent):
def build_program(self):
self.pred_program = fluid.Program()
self.train_program = fluid.Program()
self.learn_program = fluid.Program()
with fluid.program_guard(self.pred_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
self.act_prob = self.alg.predict(obs)
with fluid.program_guard(self.train_program):
with fluid.program_guard(self.learn_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
act = layers.data(name='act', shape=[1], dtype='int64')
......@@ -68,5 +68,5 @@ class CartpoleAgent(parl.Agent):
'reward': reward.astype('float32')
}
cost = self.fluid_executor.run(
self.train_program, feed=feed, fetch_list=[self.cost])[0]
self.learn_program, feed=feed, fetch_list=[self.cost])[0]
return cost
......@@ -57,6 +57,8 @@ class Agent(AgentBase):
- ``sample``: return a noisy action to perform exploration according to the policy.
- ``predict``: return an action given current observation.
- ``learn``: update the parameters of self.alg using the `learn_program` defined in `build_program()`.
- ``save``: save parameters of the ``agent`` to a given path.
- ``restore``: restore previously saved parameters from a given path.
Todo:
- allow users to get parameters of a specified model by specifying the model's name in ``get_weights()``.
......@@ -157,3 +159,63 @@ class Agent(AgentBase):
"""
raise NotImplementedError
def save(self, save_path, program=None):
    """Save parameters.

    Args:
        save_path(str): where to save the parameters.
        program(fluid.Program): program that describes the neural network structure. If None, will use self.learn_program.

    Raises:
        ValueError: if program is None and self.learn_program does not exist.

    Example:

    .. code-block:: python

        agent = AtariAgent()
        agent.save('./model.ckpt')

    """
    # Imported locally so the method does not depend on a module-level import.
    import os

    if program is None:
        # Honour the documented contract: a missing learn_program is
        # reported as ValueError rather than leaking an AttributeError.
        if not hasattr(self, 'learn_program'):
            raise ValueError(
                'program is None and self.learn_program does not exist; '
                'pass a program explicitly or define self.learn_program '
                'in build_program().')
        program = self.learn_program

    # os.path.split keeps the root of paths such as '/model.ckpt' and
    # understands the platform's separators, unlike splitting on '/'.
    dirname, filename = os.path.split(save_path)
    fluid.io.save_params(
        executor=self.fluid_executor,
        dirname=dirname,
        main_program=program,
        filename=filename)
def restore(self, save_path, program=None):
    """Restore previously saved parameters.

    This method requires a program that describes the network structure.
    The save_path argument is typically a value previously passed to ``save()``.

    Args:
        save_path(str): path where parameters were previously saved.
        program(fluid.Program): program that describes the neural network structure. If None, will use self.learn_program.

    Raises:
        ValueError: if program is None and self.learn_program does not exist.

    Example:

    .. code-block:: python

        agent = AtariAgent()
        agent.save('./model.ckpt')
        agent.restore('./model.ckpt')

    """
    # Imported locally so the method does not depend on a module-level import.
    import os

    if program is None:
        # Honour the documented contract: a missing learn_program is
        # reported as ValueError rather than leaking an AttributeError.
        if not hasattr(self, 'learn_program'):
            raise ValueError(
                'program is None and self.learn_program does not exist; '
                'pass a program explicitly or define self.learn_program '
                'in build_program().')
        program = self.learn_program

    # Split the path the same way save() does so both agree on the
    # directory and file name, including for root-level and native paths.
    dirname, filename = os.path.split(save_path)
    fluid.io.load_params(
        executor=self.fluid_executor,
        dirname=dirname,
        main_program=program,
        filename=filename)
......@@ -16,16 +16,14 @@ import numpy as np
import unittest
from paddle import fluid
from parl import layers
from parl.core.fluid.agent import Agent
from parl.core.fluid.algorithm import Algorithm
from parl.core.fluid.model import Model
from parl.utils.machine_info import get_gpu_count
import parl
import os
class TestModel(Model):
class TestModel(parl.Model):
def __init__(self):
    """Create the fully connected layers of the test model."""
    self.fc1 = layers.fc(size=256)
    # Single output unit; fc2 is defined once so no dead layer is created.
    self.fc2 = layers.fc(size=1)
def policy(self, obs):
out = self.fc1(obs)
......@@ -33,25 +31,44 @@ class TestModel(Model):
return out
class TestAlgorithm(Algorithm):
class TestAlgorithm(parl.Algorithm):
    """Minimal algorithm wrapping a model, used by the agent unit tests."""

    def __init__(self, model):
        """Keep a reference to the model whose ``policy`` is evaluated."""
        self.model = model

    def predict(self, obs):
        """Return the model's policy output for ``obs``."""
        return self.model.policy(obs)

    def learn(self, obs, label):
        """Build the mean squared-error cost between the prediction and ``label``.

        Args:
            obs: observation input passed to the model's ``policy``.
            label: target value the prediction is compared against.

        Returns:
            The reduced (mean) squared-error cost.
        """
        pred_output = self.model.policy(obs)
        # Compare the prediction, not the raw observation, with the label.
        cost = layers.square_error_cost(pred_output, label)
        cost = fluid.layers.reduce_mean(cost)
        return cost
class TestAgent(Agent):
class TestAgent(parl.Agent):
def __init__(self, algorithm, gpu_id=None):
    """Forward ``algorithm`` and ``gpu_id`` unchanged to the base agent."""
    super(TestAgent, self).__init__(algorithm, gpu_id)
def build_program(self):
    """Create the prediction and learning programs of the test agent.

    ``predict_program`` holds the forward pass and exposes its result via
    ``self.predict_output``; ``learn_program`` holds the cost computation.
    """
    self.predict_program = fluid.Program()
    self.learn_program = fluid.Program()

    with fluid.program_guard(self.predict_program):
        obs_input = layers.data(name='obs', shape=[10], dtype='float32')
        self.predict_output = [self.algorithm.predict(obs_input)]

    with fluid.program_guard(self.learn_program):
        obs_input = layers.data(name='obs', shape=[10], dtype='float32')
        label_input = layers.data(name='label', shape=[1], dtype='float32')
        # The cost is built into learn_program but not stored on the agent.
        self.algorithm.learn(obs_input, label_input)
def learn(self, obs, label):
    """Run ``learn_program`` once on the given batch; returns nothing."""
    batch = {'obs': obs, 'label': label}
    self.fluid_executor.run(self.learn_program, feed=batch)
def predict(self, obs):
output_np = self.fluid_executor.run(
self.predict_program,
......@@ -66,12 +83,39 @@ class AgentBaseTest(unittest.TestCase):
self.algorithm = TestAlgorithm(self.model)
def test_agent(self):
# Builds a TestAgent from the shared algorithm, runs predict() on a random
# float32 batch of shape [3, 10] and checks that the result is not None.
# NOTE(review): indentation is lost in this view, so it is unclear whether
# the get_gpu_count() guard below wraps the whole body or is a line removed
# by this change -- confirm against the real file.
if get_gpu_count() > 0:
agent = TestAgent(self.algorithm)
obs = np.random.random([3, 10]).astype('float32')
output_np = agent.predict(obs)
self.assertIsNotNone(output_np)
def test_save(self):
    """Saving to a flat path and to a nested path creates both files."""
    agent = TestAgent(self.algorithm)
    obs = np.random.random([3, 10]).astype('float32')
    # Run one forward pass before saving, as the test always has.
    agent.predict(obs)

    checkpoint_paths = ('./model.ckpt', './my_model/model-2.ckpt')
    for path in checkpoint_paths:
        agent.save(path)
    for path in checkpoint_paths:
        self.assertTrue(os.path.exists(path))
def test_restore(self):
    """Predictions are unchanged after a save/restore round trip."""
    agent = TestAgent(self.algorithm)
    obs = np.random.random([3, 10]).astype('float32')
    agent.predict(obs)
    checkpoint_path = './model.ckpt'

    expected_output = agent.predict(obs)
    agent.save(checkpoint_path)
    agent.restore(checkpoint_path)
    np.testing.assert_equal(agent.predict(obs), expected_output)

    # A second agent instance restored from the same checkpoint must agree.
    fresh_agent = TestAgent(self.algorithm)
    fresh_agent.restore(checkpoint_path)
    np.testing.assert_equal(fresh_agent.predict(obs), expected_output)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册