Unverified commit 7dafee77 authored by Bo Zhou and committed by GitHub

Save params (#107)

* new feature: save params

* add unittest for save()/restore()

* add an example demonstrating the usage

* rename the variable

* yapf

* fix comment
Parent 2f11d0c5
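In brief: this commit adds save() and restore() methods to parl.Agent, renames the example's train_program to learn_program (the default program the new methods operate on), and covers both methods with unit tests. A minimal usage sketch, with a hypothetical constructor call (the actual CartpoleAgent appears in the diff below):

    # hypothetical setup; see the CartpoleAgent example in this commit
    agent = CartpoleAgent(algorithm)
    agent.save('./model.ckpt')     # persists parameters via fluid.io.save_params
    agent.restore('./model.ckpt')  # reloads parameters via fluid.io.load_params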
......@@ -26,14 +26,14 @@ class CartpoleAgent(parl.Agent):
def build_program(self):
self.pred_program = fluid.Program()
- self.train_program = fluid.Program()
+ self.learn_program = fluid.Program()
with fluid.program_guard(self.pred_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
self.act_prob = self.alg.predict(obs)
- with fluid.program_guard(self.train_program):
+ with fluid.program_guard(self.learn_program):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
act = layers.data(name='act', shape=[1], dtype='int64')
......@@ -68,5 +68,5 @@ class CartpoleAgent(parl.Agent):
'reward': reward.astype('float32')
}
cost = self.fluid_executor.run(
- self.train_program, feed=feed, fetch_list=[self.cost])[0]
+ self.learn_program, feed=feed, fetch_list=[self.cost])[0]
return cost
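Note: the rename above is what makes the example compatible with the new API, since save() and restore() (added below) fall back to self.learn_program when no program is given. A hedged sketch of checkpointing this example agent; the learn() signature here is inferred from the feed dict above:

    cost = agent.learn(obs, act, reward)  # runs learn_program under the hood
    agent.save('./cartpole_model.ckpt')   # defaults to agent.learn_program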
......@@ -30,7 +30,7 @@ class Agent(AgentBase):
| `alias`: ``parl.Agent``
| `alias`: ``parl.core.fluid.agent.Agent``
| Agent is one of the three basic classes of PARL.
| It is responsible for interacting with the environment and collecting data for training the policy.
| To implement a customized ``Agent``, users can:
......@@ -57,10 +57,12 @@ class Agent(AgentBase):
- ``sample``: return a noisy action to perform exploration according to the policy.
- ``predict``: return an action given current observation.
- ``learn``: update the parameters of self.alg using the `learn_program` defined in `build_program()`.
- ``save``: save parameters of the ``agent`` to a given path.
- ``restore``: restore previous saved parameters from a given path.
Todo:
- allow users to get parameters of a specified model by specifying the model's name in ``get_weights()``.
"""
def __init__(self, algorithm, gpu_id=None):
......@@ -90,13 +92,13 @@ class Agent(AgentBase):
self.fluid_executor.run(fluid.default_startup_program())
def build_program(self):
"""Build various programs here with the
"""Build various programs here with the
learn, predict, sample functions of the algorithm.
Note:
| Users **must** implement this function in an ``Agent``.
| This function will be called automatically in the initialization function.
To build a program, you must do the following:
a. Create a fluid program with ``fluid.program_guard()``;
b. Define data layers for feeding the data;
......@@ -112,7 +114,7 @@ class Agent(AgentBase):
obs = layers.data(
name='obs', shape=[self.obs_dim], dtype='float32')
self.act_prob = self.alg.predict(obs)
"""
raise NotImplementedError
......@@ -152,8 +154,68 @@ class Agent(AgentBase):
def sample(self, *args, **kwargs):
"""Return an action with noise when given the observation of the environment.
In general, this function is used during training, as noise is added to the action to perform exploration.
"""
raise NotImplementedError
def save(self, save_path, program=None):
"""Save parameters.
Args:
save_path(str): where to save the parameters.
program(fluid.Program): program that describes the neural network structure. If None, will use self.learn_program.
Raises:
ValueError: if program is None and self.learn_program does not exist.
Example:
.. code-block:: python
agent = AtariAgent()
agent.save('./model.ckpt')
"""
if program is None:
program = self.learn_program
dirname = '/'.join(save_path.split('/')[:-1])
filename = save_path.split('/')[-1]
fluid.io.save_params(
executor=self.fluid_executor,
dirname=dirname,
main_program=program,
filename=filename)
def restore(self, save_path, program=None):
"""Restore previously saved parameters.
This method requires a program that describes the network structure.
The save_path argument is typically a value previously passed to ``save()``.
Args:
save_path(str): path where parameters were previously saved.
program(fluid.Program): program that describes the neural network structure. If None, will use self.learn_program.
Raises:
ValueError: if program is None and self.learn_program does not exist.
Example:
.. code-block:: python
agent = AtariAgent()
agent.save('./model.ckpt')
agent.restore('./model.ckpt')
"""
if program is None:
program = self.learn_program
dirname = '/'.join(save_path.split('/')[:-1])
filename = save_path.split('/')[-1]
fluid.io.load_params(
executor=self.fluid_executor,
dirname=dirname,
main_program=program,
filename=filename)
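Both methods split save_path into a directory and a file name before delegating to fluid.io. A quick, runnable sketch of what that split yields for one of the paths used in the tests below:

    save_path = './my_model/model-2.ckpt'
    dirname = '/'.join(save_path.split('/')[:-1])  # './my_model'
    filename = save_path.split('/')[-1]            # 'model-2.ckpt'
    print(dirname, filename)                       # parameters land in a single file at save_path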
......@@ -16,16 +16,14 @@ import numpy as np
import unittest
from paddle import fluid
from parl import layers
- from parl.core.fluid.agent import Agent
- from parl.core.fluid.algorithm import Algorithm
- from parl.core.fluid.model import Model
- from parl.utils.machine_info import get_gpu_count
+ import parl
+ import os
- class TestModel(Model):
+ class TestModel(parl.Model):
def __init__(self):
self.fc1 = layers.fc(size=256)
- self.fc2 = layers.fc(size=128)
+ self.fc2 = layers.fc(size=1)
def policy(self, obs):
out = self.fc1(obs)
......@@ -33,25 +31,44 @@ class TestModel(Model):
return out
- class TestAlgorithm(Algorithm):
+ class TestAlgorithm(parl.Algorithm):
def __init__(self, model):
self.model = model
def predict(self, obs):
return self.model.policy(obs)
def learn(self, obs, label):
pred_output = self.model.policy(obs)
cost = layers.square_error_cost(pred_output, label)
cost = fluid.layers.reduce_mean(cost)
return cost
- class TestAgent(Agent):
+ class TestAgent(parl.Agent):
def __init__(self, algorithm, gpu_id=None):
super(TestAgent, self).__init__(algorithm, gpu_id)
def build_program(self):
self.predict_program = fluid.Program()
self.learn_program = fluid.Program()
with fluid.program_guard(self.predict_program):
obs = layers.data(name='obs', shape=[10], dtype='float32')
output = self.algorithm.predict(obs)
self.predict_output = [output]
with fluid.program_guard(self.learn_program):
obs = layers.data(name='obs', shape=[10], dtype='float32')
label = layers.data(name='label', shape=[1], dtype='float32')
cost = self.algorithm.learn(obs, label)
def learn(self, obs, label):
output_np = self.fluid_executor.run(
self.learn_program, feed={
'obs': obs,
'label': label
})
def predict(self, obs):
output_np = self.fluid_executor.run(
self.predict_program,
......@@ -66,11 +83,38 @@ class AgentBaseTest(unittest.TestCase):
self.algorithm = TestAlgorithm(self.model)
def test_agent(self):
- if get_gpu_count() > 0:
-     agent = TestAgent(self.algorithm)
-     obs = np.random.random([3, 10]).astype('float32')
-     output_np = agent.predict(obs)
-     self.assertIsNotNone(output_np)
+ agent = TestAgent(self.algorithm)
+ obs = np.random.random([3, 10]).astype('float32')
+ output_np = agent.predict(obs)
+ self.assertIsNotNone(output_np)
def test_save(self):
agent = TestAgent(self.algorithm)
obs = np.random.random([3, 10]).astype('float32')
output_np = agent.predict(obs)
save_path1 = './model.ckpt'
save_path2 = './my_model/model-2.ckpt'
agent.save(save_path1)
agent.save(save_path2)
self.assertTrue(os.path.exists(save_path1))
self.assertTrue(os.path.exists(save_path2))
def test_restore(self):
agent = TestAgent(self.algorithm)
obs = np.random.random([3, 10]).astype('float32')
output_np = agent.predict(obs)
save_path1 = './model.ckpt'
previous_output = agent.predict(obs)
agent.save(save_path1)
agent.restore(save_path1)
current_output = agent.predict(obs)
np.testing.assert_equal(current_output, previous_output)
# a new agent instance
another_agent = TestAgent(self.algorithm)
another_agent.restore(save_path1)
current_output = another_agent.predict(obs)
np.testing.assert_equal(current_output, previous_output)
if __name__ == '__main__':
......
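After test_save runs, the two checkpoints should exist on disk. A small optional sketch for inspecting them by hand, with paths taken from the test above:

    import os
    for path in ['./model.ckpt', './my_model/model-2.ckpt']:
        if os.path.exists(path):
            print(path, os.path.getsize(path), 'bytes')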