Unverified commit a50793e4, authored by rical730, committed by GitHub

fix self.alg (#325)

* fix self.alg

* torch agent initialization

* remove definition of self.alg in PPO

* replace self.algorithm with self.alg

* remove unnecessary definition of self.alg

* fix cn readme

* unittest

* yapf
Parent 3b2da5d1
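The convention this commit settles on: an `Agent` subclass hands its algorithm to the base-class constructor, which stores it as `self.alg`, and every later call goes through `self.alg` instead of a separately assigned `self.algorithm`. A minimal sketch of the resulting pattern (hypothetical `MyAgent`, torch backend assumed):

```python
import parl
import torch

class MyAgent(parl.Agent):  # hypothetical agent, for illustration only
    def __init__(self, algorithm):
        # the base class keeps the algorithm as self.alg;
        # no extra self.algorithm / self.alg assignment is needed here
        super(MyAgent, self).__init__(algorithm)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    def predict(self, obs):
        obs = torch.tensor(obs, device=self.device, dtype=torch.float)
        return self.alg.predict(obs)  # always go through self.alg

    def learn(self, obs, action, reward):
        obs = torch.tensor(obs, device=self.device, dtype=torch.float)
        action = torch.tensor(action, device=self.device, dtype=torch.long)
        reward = torch.tensor(reward, device=self.device, dtype=torch.float)
        return self.alg.learn(obs, action, reward)
```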
......@@ -18,7 +18,7 @@
# 框架结构
<img src=".github/abstractions.png" alt="abstractions" width="400"/>
PARL的目标是构建一个可以完复杂任务的智能体。以下是用户在逐步构建一个智能体的过程中需要了解到的结构:
PARL的目标是构建一个可以完成复杂任务的智能体。以下是用户在逐步构建一个智能体的过程中需要了解到的结构:
### Model
`Model` 用来定义前向(`Forward`)网络,这通常是一个策略网络(`Policy Network`)或者一个值函数网络(`Value Function`),输入是当前环境状态(`State`)。
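A minimal sketch of such a `Model`, assuming the torch backend and a hypothetical two-layer policy network:

```python
import parl
import torch.nn as nn
import torch.nn.functional as F

class PolicyModel(parl.Model):  # hypothetical model, for illustration only
    def __init__(self, obs_dim, act_dim):
        super(PolicyModel, self).__init__()
        self.fc1 = nn.Linear(obs_dim, 128)
        self.fc2 = nn.Linear(128, act_dim)

    def forward(self, obs):
        # map the current state to a probability distribution over actions
        h = F.relu(self.fc1(obs))
        return F.softmax(self.fc2(h), dim=-1)
```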
......
......@@ -90,7 +90,7 @@ class AlphaZeroAgent(parl.Agent):
Args:
examples: list of examples, each example is of form (board, pi, v)
"""
optimizer = optim.Adam(self.algorithm.model.parameters(), lr=args.lr)
optimizer = optim.Adam(self.alg.model.parameters(), lr=args.lr)
for epoch in range(args.epochs):
print('EPOCH ::: ' + str(epoch + 1))
......@@ -111,7 +111,7 @@ class AlphaZeroAgent(parl.Agent):
), target_pis.contiguous().cuda(), target_vs.contiguous(
).cuda()
total_loss, pi_loss, v_loss = self.algorithm.learn(
total_loss, pi_loss, v_loss = self.alg.learn(
boards, target_pis, target_vs, optimizer)
# record loss with tqdm
......@@ -132,7 +132,7 @@ class AlphaZeroAgent(parl.Agent):
board = board.contiguous().cuda()
board = board.view(1, self.board_x, self.board_y)
pi, v = self.algorithm.predict(board)
pi, v = self.alg.predict(board)
return pi.data.cpu().numpy()[0], v.data.cpu().numpy()[0]
......
......@@ -26,7 +26,7 @@ class CartpoleAgent(parl.Agent):
"""
def __init__(self, algorithm):
self.algorithm = algorithm
super(CartpoleAgent, self).__init__(algorithm)
self.device = torch.device("cuda" if torch.cuda.
is_available() else "cpu")
......@@ -40,7 +40,7 @@ class CartpoleAgent(parl.Agent):
action(int)
"""
obs = torch.tensor(obs, device=self.device, dtype=torch.float)
prob = self.algorithm.predict(obs)
prob = self.alg.predict(obs).cpu()
prob = prob.data.numpy()
action = np.random.choice(len(prob), 1, p=prob)[0]
return action
......@@ -55,7 +55,7 @@ class CartpoleAgent(parl.Agent):
action(int)
"""
obs = torch.tensor(obs, device=self.device, dtype=torch.float)
prob = self.algorithm.predict(obs)
prob = self.alg.predict(obs)
_, action = prob.max(-1)
return action.item()
......@@ -75,5 +75,5 @@ class CartpoleAgent(parl.Agent):
action = torch.tensor(action, device=self.device, dtype=torch.long)
reward = torch.tensor(reward, device=self.device, dtype=torch.float)
loss = self.algorithm.learn(obs, action, reward)
loss = self.alg.learn(obs, action, reward)
return loss.item()
......@@ -25,19 +25,19 @@ class Agent(parl.Agent):
self.obs_shape = config['obs_shape']
def sample(self, obs):
sample_actions, values = self.algorithm.sample(obs)
sample_actions, values = self.alg.sample(obs)
return sample_actions, values
def predict(self, obs):
predict_actions = self.algorithm.predict(obs)
predict_actions = self.alg.predict(obs)
return predict_actions
def value(self, obs):
values = self.algorithm.value(obs)
values = self.alg.value(obs)
return values
def learn(self, obs, actions, advantages, target_values):
total_loss, pi_loss, vf_losss, entropy, lr, entropy_coeff = self.algorithm.learn(
total_loss, pi_loss, vf_losss, entropy, lr, entropy_coeff = self.alg.learn(
obs, actions, advantages, target_values)
return total_loss, pi_loss, vf_losss, entropy, lr, entropy_coeff
......@@ -36,12 +36,12 @@ class AtariAgent(parl.Agent):
def __init__(self, algorithm, act_dim):
assert isinstance(act_dim, int)
super(AtariAgent, self).__init__(algorithm)
self.act_dim = act_dim
self.exploration = 1
self.global_step = 0
self.update_target_steps = 10000 // 4
self.alg = algorithm
self.device = torch.device('cuda' if torch.cuda.
is_available() else 'cpu')
......
......@@ -18,7 +18,7 @@ import torch
class MujocoAgent(parl.Agent):
def __init__(self, algorithm, device):
self.alg = algorithm
super(MujocoAgent, self).__init__(algorithm)
self.device = device
def predict(self, obs):
......
......@@ -31,7 +31,6 @@ class MujocoAgent(parl.Agent):
policy_learn_times=20,
value_learn_times=10,
value_batch_size=256):
self.alg = algorithm
self.obs_dim = obs_dim
self.act_dim = act_dim
assert loss_type == 'CLIP' or loss_type == 'KLPEN'
......
......@@ -33,7 +33,6 @@ class DQNModel(parl.Model):
class DQNAgent(parl.Agent):
def __init__(self, algorithm):
super(DQNAgent, self).__init__(algorithm)
self.alg = algorithm
def build_program(self):
self.pred_program = fluid.Program()
......@@ -115,7 +114,6 @@ class A3CModel(parl.Model):
class A3CAgent(parl.Agent):
def __init__(self, algorithm):
super(A3CAgent, self).__init__(algorithm)
self.alg = algorithm
def build_program(self):
self.predict_program = fluid.Program()
......@@ -213,7 +211,6 @@ class IMPALAModel(parl.Model):
class IMPALAAgent(parl.Agent):
def __init__(self, algorithm):
super(IMPALAAgent, self).__init__(algorithm)
self.alg = algorithm
def build_program(self):
self.predict_program = fluid.Program()
......
......@@ -25,7 +25,7 @@ class AgentBase(object):
Args:
algorithm (`AlgorithmBase`): an instance of `AlgorithmBase`
"""
self.algorithm = algorithm
self.alg = algorithm
def get_weights(self, model_ids=None):
"""Get weights of the agent.
......@@ -44,7 +44,7 @@ class AgentBase(object):
Returns:
(Dict): Dict of weights ({attribute name: numpy array/List/Dict})
"""
return self.algorithm.get_weights(model_ids=model_ids)
return self.alg.get_weights(model_ids=model_ids)
def set_weights(self, weights, model_ids=None):
"""Set weights of the agent with given weights.
......@@ -62,15 +62,15 @@ class AgentBase(object):
whose model_id is in the `model_ids`.
"""
self.algorithm.set_weights(weights, model_ids=model_ids)
self.alg.set_weights(weights, model_ids=model_ids)
def get_model_ids(self):
"""Get all model ids of the self.algorithm in the agent.
"""Get all model ids of the self.alg in the agent.
Returns:
List of model_id
"""
return self.algorithm.get_model_ids()
return self.alg.get_model_ids()
@property
def model_ids(self):
......
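A usage sketch of the `AgentBase` weight API documented above (hypothetical `agent1`/`agent2` instances built from the same model structure):

```python
# copy all parameters from one agent to another
weights = agent1.get_weights()   # {attribute name: numpy array / List / Dict}
agent2.set_weights(weights)

# inspect which models the wrapped algorithm holds
print(agent1.get_model_ids())    # list of model_id strings
print(agent1.model_ids)          # same result via the property
```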
......@@ -74,7 +74,6 @@ class Agent(AgentBase):
assert isinstance(algorithm, Algorithm)
super(Agent, self).__init__(algorithm)
self.alg = algorithm
self.gpu_id = 0 if machine_info.is_gpu_available() else -1
self.build_program()
......
......@@ -54,13 +54,13 @@ class TestAgent(parl.Agent):
self.learn_program = fluid.Program()
with fluid.program_guard(self.predict_program):
obs = layers.data(name='obs', shape=[10], dtype='float32')
output = self.algorithm.predict(obs)
output = self.alg.predict(obs)
self.predict_output = [output]
with fluid.program_guard(self.learn_program):
obs = layers.data(name='obs', shape=[10], dtype='float32')
label = layers.data(name='label', shape=[1], dtype='float32')
cost = self.algorithm.learn(obs, label)
cost = self.alg.learn(obs, label)
def learn(self, obs, label):
output_np = self.fluid_executor.run(
......@@ -80,16 +80,16 @@ class TestAgent(parl.Agent):
class AgentBaseTest(unittest.TestCase):
def setUp(self):
self.model = TestModel()
self.algorithm = TestAlgorithm(self.model)
self.alg = TestAlgorithm(self.model)
def test_agent(self):
agent = TestAgent(self.algorithm)
agent = TestAgent(self.alg)
obs = np.random.random([3, 10]).astype('float32')
output_np = agent.predict(obs)
self.assertIsNotNone(output_np)
def test_save(self):
agent = TestAgent(self.algorithm)
agent = TestAgent(self.alg)
obs = np.random.random([3, 10]).astype('float32')
output_np = agent.predict(obs)
save_path1 = 'model.ckpt'
......@@ -100,7 +100,7 @@ class AgentBaseTest(unittest.TestCase):
self.assertTrue(os.path.exists(save_path2))
def test_restore(self):
agent = TestAgent(self.algorithm)
agent = TestAgent(self.alg)
obs = np.random.random([3, 10]).astype('float32')
output_np = agent.predict(obs)
save_path1 = 'model.ckpt'
......@@ -111,13 +111,13 @@ class AgentBaseTest(unittest.TestCase):
np.testing.assert_equal(current_output, previous_output)
# a new agent instance
another_agent = TestAgent(self.algorithm)
another_agent = TestAgent(self.alg)
another_agent.restore(save_path1)
current_output = another_agent.predict(obs)
np.testing.assert_equal(current_output, previous_output)
def test_compiled_restore(self):
agent = TestAgent(self.algorithm)
agent = TestAgent(self.alg)
agent.learn_program = parl.compile(agent.learn_program)
obs = np.random.random([3, 10]).astype('float32')
previous_output = agent.predict(obs)
......@@ -126,7 +126,7 @@ class AgentBaseTest(unittest.TestCase):
agent.restore(save_path1)
# a new agent instance
another_agent = TestAgent(self.algorithm)
another_agent = TestAgent(self.alg)
another_agent.learn_program = parl.compile(another_agent.learn_program)
another_agent.restore(save_path1)
current_output = another_agent.predict(obs)
......
......@@ -125,9 +125,8 @@ class AgentBaseTest(unittest.TestCase):
def test_get_weights_with_model_ids(self):
weights = self.agent1.get_weights(model_ids=[
self.agent1.algorithm.model1.model_id, self.agent1.algorithm.
model_list2[0].model_id, self.agent1.algorithm.model_dict2['k1'].
model_id
self.agent1.alg.model1.model_id, self.agent1.alg.model_list2[0].
model_id, self.agent1.alg.model_dict2['k1'].model_id
])
expected_dict = {
'model1': 1,
......@@ -163,22 +162,22 @@ class AgentBaseTest(unittest.TestCase):
self.agent1.set_weights(
new_weights,
model_ids=[
self.agent1.algorithm.model1.model_id,
self.agent1.algorithm.model_list2[0].model_id,
self.agent1.algorithm.model_dict2['k1'].model_id
self.agent1.alg.model1.model_id,
self.agent1.alg.model_list2[0].model_id,
self.agent1.alg.model_dict2['k1'].model_id
])
self.assertDictEqual(self.agent1.get_weights(), expected_dict)
def test_get_and_set_weights_between_agents_with_model_ids(self):
agent1_model_ids = [
self.agent1.algorithm.model1.model_id,
self.agent1.algorithm.model_list2[0].model_id,
self.agent1.algorithm.model_dict2['k1'].model_id
self.agent1.alg.model1.model_id,
self.agent1.alg.model_list2[0].model_id,
self.agent1.alg.model_dict2['k1'].model_id
]
agent2_model_ids = [
self.agent2.algorithm.model1.model_id,
self.agent2.algorithm.model_list2[0].model_id,
self.agent2.algorithm.model_dict2['k1'].model_id
self.agent2.alg.model1.model_id,
self.agent2.alg.model_list2[0].model_id,
self.agent2.alg.model_dict2['k1'].model_id
]
new_weights = {
'model1': -1,
......
......@@ -52,7 +52,7 @@ class Agent(AgentBase):
Public Functions:
- ``sample``: return a noisy action to perform exploration according to the policy.
- ``predict``: return an estimate of the Q function given the current observation.
- ``learn``: update the parameters of self.algorithm.
- ``learn``: update the parameters of self.alg.
- ``save``: save parameters of the ``agent`` to a given path.
- ``restore``: restore previous saved parameters from a given path.
......@@ -64,7 +64,7 @@ class Agent(AgentBase):
""".
Args:
algorithm (parl.Algorithm): an instance of `parl.Algorithm`. This algorithm is then passed to `self.algorithm`.
algorithm (parl.Algorithm): an instance of `parl.Algorithm`. This algorithm is then passed to `self.alg`.
device (torch.device): specify which GPU/CPU to be used.
"""
......@@ -98,10 +98,10 @@ class Agent(AgentBase):
Args:
save_path(str): where to save the parameters.
model(parl.Model): model that describes the neural network structure. If None, will use self.algorithm.model.
model(parl.Model): model that describes the neural network structure. If None, will use self.alg.model.
Raises:
ValueError: if model is None and self.algorithm.model does not exist.
ValueError: if model is None and self.alg.model does not exist.
Example:
......@@ -112,7 +112,7 @@ class Agent(AgentBase):
"""
if model is None:
model = self.algorithm.model
model = self.alg.model
sep = os.sep
dirname = sep.join(save_path.split(sep)[:-1])
if dirname != '' and not os.path.exists(dirname):
......@@ -126,10 +126,10 @@ class Agent(AgentBase):
Args:
save_path(str): path where parameters were previously saved.
model(parl.Model): model that describes the neural network structure. If None, will use self.algorithm.model.
model(parl.Model): model that describes the neural network structure. If None, will use self.alg.model.
Raises:
ValueError: if model is None and self.algorithm does not exist.
ValueError: if model is None and self.alg does not exist.
Example:
......@@ -142,6 +142,6 @@ class Agent(AgentBase):
"""
if model is None:
model = self.algorithm.model
model = self.alg.model
checkpoint = torch.load(save_path)
model.load_state_dict(checkpoint)
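A usage sketch of the save/restore pair documented above, assuming a torch agent whose parameters live in `self.alg.model`:

```python
agent.save('./model_dir/agent.ckpt')     # serializes self.alg.model's state_dict
# ... later, possibly in a fresh process with the same network structure ...
agent.restore('./model_dir/agent.ckpt')  # loads the parameters back into self.alg.model
```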
......@@ -57,10 +57,10 @@ class TestAgent(parl.Agent):
super(TestAgent, self).__init__(algorithm)
def learn(self, obs, label):
cost = self.algorithm.learn(obs, label)
cost = self.alg.learn(obs, label)
def predict(self, obs):
return self.algorithm.predict(obs)
return self.alg.predict(obs)
class AgentBaseTest(unittest.TestCase):
......