Unverified · Commit a50793e4 · authored by rical730, committed by GitHub

fix self.alg (#325)

* fix self.alg

* torch agent initialization

* remove definition of self.alg in PPO

* replace self.algorithm with self.alg

* remove unnecessary definition of self.alg

* fix cn readme

* unittest

* yapf
Parent: 3b2da5d1
@@ -18,7 +18,7 @@
 # Framework Structure
 <img src=".github/abstractions.png" alt="abstractions" width="400"/>
 PARL's goal is to build an agent capable of completing complex tasks. The following are the abstractions a user needs to understand while building such an agent step by step:
 ### Model
 `Model` defines the forward network, which is typically a policy network (`Policy Network`) or a value function network (`Value Function`) whose input is the current environment state (`State`).
......
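For readers skimming the diff, the `Model` described above is just a small network module. Below is a minimal, illustrative sketch of such a policy `Model` for the CartPole example (the class name, layer sizes, and activation are assumptions, not part of this commit):

```python
import parl
import torch.nn as nn
import torch.nn.functional as F


class CartpolePolicyModel(parl.Model):
    """Illustrative policy network: maps the current state to action probabilities."""

    def __init__(self, obs_dim, act_dim):
        super(CartpolePolicyModel, self).__init__()
        self.fc1 = nn.Linear(obs_dim, 128)
        self.fc2 = nn.Linear(128, act_dim)

    def forward(self, obs):
        h = F.relu(self.fc1(obs))
        # A softmax over the last dimension yields one probability per discrete action.
        return F.softmax(self.fc2(h), dim=-1)
```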
@@ -90,7 +90,7 @@ class AlphaZeroAgent(parl.Agent):
         Args:
             examples: list of examples, each example is of form (board, pi, v)
         """
-        optimizer = optim.Adam(self.algorithm.model.parameters(), lr=args.lr)
+        optimizer = optim.Adam(self.alg.model.parameters(), lr=args.lr)
         for epoch in range(args.epochs):
             print('EPOCH ::: ' + str(epoch + 1))
@@ -111,7 +111,7 @@ class AlphaZeroAgent(parl.Agent):
                 ), target_pis.contiguous().cuda(), target_vs.contiguous(
                 ).cuda()
-                total_loss, pi_loss, v_loss = self.algorithm.learn(
+                total_loss, pi_loss, v_loss = self.alg.learn(
                     boards, target_pis, target_vs, optimizer)
                 # record loss with tqdm
@@ -132,7 +132,7 @@ class AlphaZeroAgent(parl.Agent):
         board = board.contiguous().cuda()
         board = board.view(1, self.board_x, self.board_y)
-        pi, v = self.algorithm.predict(board)
+        pi, v = self.alg.predict(board)
         return pi.data.cpu().numpy()[0], v.data.cpu().numpy()[0]
......
@@ -26,7 +26,7 @@ class CartpoleAgent(parl.Agent):
     """
     def __init__(self, algorithm):
-        self.algorithm = algorithm
+        super(CartpoleAgent, self).__init__(algorithm)
         self.device = torch.device("cuda" if torch.cuda.
                                    is_available() else "cpu")
@@ -40,7 +40,7 @@ class CartpoleAgent(parl.Agent):
             action(int)
         """
         obs = torch.tensor(obs, device=self.device, dtype=torch.float)
-        prob = self.algorithm.predict(obs)
+        prob = self.alg.predict(obs).cpu()
         prob = prob.data.numpy()
         action = np.random.choice(len(prob), 1, p=prob)[0]
         return action
@@ -55,7 +55,7 @@ class CartpoleAgent(parl.Agent):
             action(int)
         """
         obs = torch.tensor(obs, device=self.device, dtype=torch.float)
-        prob = self.algorithm.predict(obs)
+        prob = self.alg.predict(obs)
         _, action = prob.max(-1)
         return action.item()
@@ -75,5 +75,5 @@ class CartpoleAgent(parl.Agent):
         action = torch.tensor(action, device=self.device, dtype=torch.long)
         reward = torch.tensor(reward, device=self.device, dtype=torch.float)
-        loss = self.algorithm.learn(obs, action, reward)
+        loss = self.alg.learn(obs, action, reward)
         return loss.item()
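Taken together, the CartpoleAgent hunks above follow one pattern: the constructor hands the algorithm to the `parl.Agent` base class, which stores it as `self.alg`, and every other method then calls through `self.alg`. A condensed, illustrative recombination of the changed code (not a verbatim file from the repository):

```python
import numpy as np
import torch
import parl


class CartpoleAgent(parl.Agent):
    def __init__(self, algorithm):
        # The base class keeps the algorithm as self.alg; no duplicate attribute is needed.
        super(CartpoleAgent, self).__init__(algorithm)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def sample(self, obs):
        obs = torch.tensor(obs, device=self.device, dtype=torch.float)
        prob = self.alg.predict(obs).cpu().data.numpy()
        return np.random.choice(len(prob), 1, p=prob)[0]

    def predict(self, obs):
        obs = torch.tensor(obs, device=self.device, dtype=torch.float)
        _, action = self.alg.predict(obs).max(-1)
        return action.item()

    def learn(self, obs, action, reward):
        obs = torch.tensor(obs, device=self.device, dtype=torch.float)
        action = torch.tensor(action, device=self.device, dtype=torch.long)
        reward = torch.tensor(reward, device=self.device, dtype=torch.float)
        return self.alg.learn(obs, action, reward).item()
```

Keeping the attribute in one place (the base class) is what allows the later hunks to simply delete the redundant `self.alg = algorithm` assignments.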
@@ -25,19 +25,19 @@ class Agent(parl.Agent):
         self.obs_shape = config['obs_shape']

     def sample(self, obs):
-        sample_actions, values = self.algorithm.sample(obs)
+        sample_actions, values = self.alg.sample(obs)
         return sample_actions, values

     def predict(self, obs):
-        predict_actions = self.algorithm.predict(obs)
+        predict_actions = self.alg.predict(obs)
         return predict_actions

     def value(self, obs):
-        values = self.algorithm.value(obs)
+        values = self.alg.value(obs)
         return values

     def learn(self, obs, actions, advantages, target_values):
-        total_loss, pi_loss, vf_losss, entropy, lr, entropy_coeff = self.algorithm.learn(
+        total_loss, pi_loss, vf_losss, entropy, lr, entropy_coeff = self.alg.learn(
             obs, actions, advantages, target_values)
         return total_loss, pi_loss, vf_losss, entropy, lr, entropy_coeff
@@ -36,12 +36,12 @@ class AtariAgent(parl.Agent):
     def __init__(self, algorithm, act_dim):
         assert isinstance(act_dim, int)
+        super(AtariAgent, self).__init__(algorithm)
         self.act_dim = act_dim
         self.exploration = 1
         self.global_step = 0
         self.update_target_steps = 10000 // 4
-        self.alg = algorithm
         self.device = torch.device('cuda' if torch.cuda.
                                    is_available() else 'cpu')
......
@@ -18,7 +18,7 @@ import torch

 class MujocoAgent(parl.Agent):
     def __init__(self, algorithm, device):
-        self.alg = algorithm
+        super(MujocoAgent, self).__init__(algorithm)
         self.device = device

     def predict(self, obs):
......
@@ -31,7 +31,6 @@ class MujocoAgent(parl.Agent):
                  policy_learn_times=20,
                  value_learn_times=10,
                  value_batch_size=256):
-        self.alg = algorithm
         self.obs_dim = obs_dim
         self.act_dim = act_dim
         assert loss_type == 'CLIP' or loss_type == 'KLPEN'
......
@@ -33,7 +33,6 @@ class DQNModel(parl.Model):
 class DQNAgent(parl.Agent):
     def __init__(self, algorithm):
         super(DQNAgent, self).__init__(algorithm)
-        self.alg = algorithm

     def build_program(self):
         self.pred_program = fluid.Program()
@@ -115,7 +114,6 @@ class A3CModel(parl.Model):
 class A3CAgent(parl.Agent):
     def __init__(self, algorithm):
         super(A3CAgent, self).__init__(algorithm)
-        self.alg = algorithm

     def build_program(self):
         self.predict_program = fluid.Program()
@@ -213,7 +211,6 @@ class IMPALAModel(parl.Model):
 class IMPALAAgent(parl.Agent):
     def __init__(self, algorithm):
         super(IMPALAAgent, self).__init__(algorithm)
-        self.alg = algorithm

     def build_program(self):
         self.predict_program = fluid.Program()
......
@@ -25,7 +25,7 @@ class AgentBase(object):
         Args:
             algorithm (`AlgorithmBase`): an instance of `AlgorithmBase`
         """
-        self.algorithm = algorithm
+        self.alg = algorithm

     def get_weights(self, model_ids=None):
         """Get weights of the agent.
@@ -44,7 +44,7 @@ class AgentBase(object):
         Returns:
             (Dict): Dict of weights ({attribute name: numpy array/List/Dict})
         """
-        return self.algorithm.get_weights(model_ids=model_ids)
+        return self.alg.get_weights(model_ids=model_ids)

     def set_weights(self, weights, model_ids=None):
         """Set weights of the agent with given weights.
@@ -62,15 +62,15 @@ class AgentBase(object):
                 whose model_id in the `model_ids`.
         """
-        self.algorithm.set_weights(weights, model_ids=model_ids)
+        self.alg.set_weights(weights, model_ids=model_ids)

     def get_model_ids(self):
-        """Get all model ids of the self.algorithm in the agent.
+        """Get all model ids of the self.alg in the agent.

         Returns:
             List of model_id
         """
-        return self.algorithm.get_model_ids()
+        return self.alg.get_model_ids()

     @property
     def model_ids(self):
......
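Because `get_weights`, `set_weights`, and `get_model_ids` all delegate to `self.alg`, downstream code reaches nested models through the renamed attribute. A rough usage sketch, assuming two agents that wrap structurally identical algorithms (the agent names are illustrative, not from the commit):

```python
# Sync parameters from a learner agent to an actor agent.
weights = learner_agent.get_weights()   # delegates to learner_agent.alg.get_weights()
actor_agent.set_weights(weights)        # delegates to actor_agent.alg.set_weights()

# Restrict the transfer to specific models, addressed through agent.alg:
model_ids = [learner_agent.alg.model.model_id]
actor_agent.set_weights(
    learner_agent.get_weights(model_ids=model_ids), model_ids=model_ids)
```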
@@ -74,7 +74,6 @@ class Agent(AgentBase):
         assert isinstance(algorithm, Algorithm)
         super(Agent, self).__init__(algorithm)
-        self.alg = algorithm

         self.gpu_id = 0 if machine_info.is_gpu_available() else -1
         self.build_program()
......
@@ -54,13 +54,13 @@ class TestAgent(parl.Agent):
         self.learn_program = fluid.Program()
         with fluid.program_guard(self.predict_program):
             obs = layers.data(name='obs', shape=[10], dtype='float32')
-            output = self.algorithm.predict(obs)
+            output = self.alg.predict(obs)
             self.predict_output = [output]
         with fluid.program_guard(self.learn_program):
             obs = layers.data(name='obs', shape=[10], dtype='float32')
             label = layers.data(name='label', shape=[1], dtype='float32')
-            cost = self.algorithm.learn(obs, label)
+            cost = self.alg.learn(obs, label)

     def learn(self, obs, label):
         output_np = self.fluid_executor.run(
@@ -80,16 +80,16 @@ class TestAgent(parl.Agent):
 class AgentBaseTest(unittest.TestCase):
     def setUp(self):
         self.model = TestModel()
-        self.algorithm = TestAlgorithm(self.model)
+        self.alg = TestAlgorithm(self.model)

     def test_agent(self):
-        agent = TestAgent(self.algorithm)
+        agent = TestAgent(self.alg)
         obs = np.random.random([3, 10]).astype('float32')
         output_np = agent.predict(obs)
         self.assertIsNotNone(output_np)

     def test_save(self):
-        agent = TestAgent(self.algorithm)
+        agent = TestAgent(self.alg)
         obs = np.random.random([3, 10]).astype('float32')
         output_np = agent.predict(obs)
         save_path1 = 'model.ckpt'
@@ -100,7 +100,7 @@ class AgentBaseTest(unittest.TestCase):
         self.assertTrue(os.path.exists(save_path2))

     def test_restore(self):
-        agent = TestAgent(self.algorithm)
+        agent = TestAgent(self.alg)
         obs = np.random.random([3, 10]).astype('float32')
         output_np = agent.predict(obs)
         save_path1 = 'model.ckpt'
@@ -111,13 +111,13 @@ class AgentBaseTest(unittest.TestCase):
         np.testing.assert_equal(current_output, previous_output)

         # a new agent instance
-        another_agent = TestAgent(self.algorithm)
+        another_agent = TestAgent(self.alg)
         another_agent.restore(save_path1)
         current_output = another_agent.predict(obs)
         np.testing.assert_equal(current_output, previous_output)

     def test_compiled_restore(self):
-        agent = TestAgent(self.algorithm)
+        agent = TestAgent(self.alg)
         agent.learn_program = parl.compile(agent.learn_program)
         obs = np.random.random([3, 10]).astype('float32')
         previous_output = agent.predict(obs)
@@ -126,7 +126,7 @@ class AgentBaseTest(unittest.TestCase):
         agent.restore(save_path1)

         # a new agent instance
-        another_agent = TestAgent(self.algorithm)
+        another_agent = TestAgent(self.alg)
         another_agent.learn_program = parl.compile(another_agent.learn_program)
         another_agent.restore(save_path1)
         current_output = another_agent.predict(obs)
......
@@ -125,9 +125,8 @@ class AgentBaseTest(unittest.TestCase):
     def test_get_weights_with_model_ids(self):
         weights = self.agent1.get_weights(model_ids=[
-            self.agent1.algorithm.model1.model_id, self.agent1.algorithm.
-            model_list2[0].model_id, self.agent1.algorithm.model_dict2['k1'].
-            model_id
+            self.agent1.alg.model1.model_id, self.agent1.alg.model_list2[0].
+            model_id, self.agent1.alg.model_dict2['k1'].model_id
         ])
         expected_dict = {
             'model1': 1,
@@ -163,22 +162,22 @@ class AgentBaseTest(unittest.TestCase):
         self.agent1.set_weights(
             new_weights,
             model_ids=[
-                self.agent1.algorithm.model1.model_id,
-                self.agent1.algorithm.model_list2[0].model_id,
-                self.agent1.algorithm.model_dict2['k1'].model_id
+                self.agent1.alg.model1.model_id,
+                self.agent1.alg.model_list2[0].model_id,
+                self.agent1.alg.model_dict2['k1'].model_id
             ])
         self.assertDictEqual(self.agent1.get_weights(), expected_dict)

     def test_get_and_set_weights_between_agents_with_model_ids(self):
         agent1_model_ids = [
-            self.agent1.algorithm.model1.model_id,
-            self.agent1.algorithm.model_list2[0].model_id,
-            self.agent1.algorithm.model_dict2['k1'].model_id
+            self.agent1.alg.model1.model_id,
+            self.agent1.alg.model_list2[0].model_id,
+            self.agent1.alg.model_dict2['k1'].model_id
         ]
         agent2_model_ids = [
-            self.agent2.algorithm.model1.model_id,
-            self.agent2.algorithm.model_list2[0].model_id,
-            self.agent2.algorithm.model_dict2['k1'].model_id
+            self.agent2.alg.model1.model_id,
+            self.agent2.alg.model_list2[0].model_id,
+            self.agent2.alg.model_dict2['k1'].model_id
         ]
         new_weights = {
             'model1': -1,
......
@@ -52,7 +52,7 @@ class Agent(AgentBase):
     Public Functions:
         - ``sample``: return a noisy action to perform exploration according to the policy.
        - ``predict``: return an estimated Q function given current observation.
-        - ``learn``: update the parameters of self.algorithm.
+        - ``learn``: update the parameters of self.alg.
        - ``save``: save parameters of the ``agent`` to a given path.
        - ``restore``: restore previous saved parameters from a given path.
@@ -64,7 +64,7 @@ class Agent(AgentBase):
         """.

         Args:
-            algorithm (parl.Algorithm): an instance of `parl.Algorithm`. This algorithm is then passed to `self.algorithm`.
+            algorithm (parl.Algorithm): an instance of `parl.Algorithm`. This algorithm is then passed to `self.alg`.
             device (torch.device): specify which GPU/CPU to be used.
         """
@@ -98,10 +98,10 @@ class Agent(AgentBase):
         Args:
             save_path(str): where to save the parameters.
-            model(parl.Model): model that describes the neural network structure. If None, will use self.algorithm.model.
+            model(parl.Model): model that describes the neural network structure. If None, will use self.alg.model.

         Raises:
-            ValueError: if model is None and self.algorithm.model does not exist.
+            ValueError: if model is None and self.alg.model does not exist.

         Example:
@@ -112,7 +112,7 @@ class Agent(AgentBase):
         """
         if model is None:
-            model = self.algorithm.model
+            model = self.alg.model
         sep = os.sep
         dirname = sep.join(save_path.split(sep)[:-1])
         if dirname != '' and not os.path.exists(dirname):
@@ -126,10 +126,10 @@ class Agent(AgentBase):
         Args:
             save_path(str): path where parameters were previously saved.
-            model(parl.Model): model that describes the neural network structure. If None, will use self.algorithm.model.
+            model(parl.Model): model that describes the neural network structure. If None, will use self.alg.model.

         Raises:
-            ValueError: if model is None and self.algorithm does not exist.
+            ValueError: if model is None and self.alg does not exist.

         Example:
@@ -142,6 +142,6 @@ class Agent(AgentBase):
         """
         if model is None:
-            model = self.algorithm.model
+            model = self.alg.model
         checkpoint = torch.load(save_path)
         model.load_state_dict(checkpoint)
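The effect of these docstring and fallback changes is that `save`/`restore` now default to `self.alg.model`. A hedged usage sketch (the agent, paths, and construction are illustrative, not taken from the commit):

```python
# Assumes `agent` is a torch-backend parl.Agent subclass built as
# agent = SomeAgent(algorithm), so that agent.alg.model is the underlying parl.Model.
agent.save('./checkpoints/model.ckpt')     # persists agent.alg.model's parameters
agent.restore('./checkpoints/model.ckpt')  # loads them back into agent.alg.model

# A specific model can still be passed explicitly instead of the default agent.alg.model:
agent.save('./checkpoints/policy.ckpt', model=agent.alg.model)
```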
@@ -57,10 +57,10 @@ class TestAgent(parl.Agent):
         super(TestAgent, self).__init__(algorithm)

     def learn(self, obs, label):
-        cost = self.algorithm.learn(obs, label)
+        cost = self.alg.learn(obs, label)

     def predict(self, obs):
-        return self.algorithm.predict(obs)
+        return self.alg.predict(obs)


 class AgentBaseTest(unittest.TestCase):
......