Commit 4abc0534 authored by LI Yunxiang, committed by Bo Zhou

add pytorch a2c (#167)

* add pytorch a2c

* add set/get_weights test & copyright

* yapf....

* Update model_base_test_torch.py

* update

* Delete banma.py

* Update model_base_test_torch.py

* update

* Update model.py

* update torch tests

* Update model_base_test_torch.py
Parent 7c406386
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
config = {
#========== remote config ==========
'master_address': 'localhost:8010',
#========== env config ==========
'env_name': 'BreakoutNoFrameskip-v4',
'env_dim': 84,
#========== actor config ==========
'actor_num': 5,
'env_num': 5,
'sample_batch_steps': 20,
#========== learner config ==========
'max_sample_steps': int(1e7),
'gamma': 0.99,
'lambda': 1.0,
# start learning rate
'start_lr': 0.001,
'entropy_coeff_scheduler': [(0, -0.01)],
'vf_loss_coeff': 0.5,
'get_remote_metrics_interval': 10,
'log_metrics_interval_s': 10,
'entropy_coeff': -0.05,
'learning_rate': 3e-4
}
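The start_lr, max_sample_steps, and entropy_coeff_scheduler entries above are consumed by the schedulers from parl.utils.scheduler (used by the A2C algorithm further down). A minimal sketch of how they behave under this config, assuming one update consumes roughly 500 samples (5 actors x 5 envs x 20 steps):

from parl.utils.scheduler import LinearDecayScheduler, PiecewiseScheduler

# learning rate decays linearly from start_lr to 0 over max_sample_steps samples
lr_scheduler = LinearDecayScheduler(0.001, int(1e7))
# entropy coefficient is piecewise constant; a single (0, -0.01) entry keeps it fixed
entropy_coeff_scheduler = PiecewiseScheduler([(0, -0.01)])

lr = lr_scheduler.step(step_num=500)            # lr after consuming one 500-sample batch
entropy_coeff = entropy_coeff_scheduler.step()  # stays at -0.01 throughout training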
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import gym
import parl
import torch
import numpy as np
from collections import defaultdict
from parl.env.atari_wrappers import wrap_deepmind, MonitorEnv, get_wrapper_by_cls
from parl.env.vector_env import VectorEnv
from parl.utils.rl_utils import calc_gae
from atari_model import ActorCritic
from parl.algorithms import A2C
from atari_agent import Agent
@parl.remote_class
class Actor(object):
def __init__(self, config):
# the cluster may not have gpu
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
self.actor_cuda = False
self.config = config
self.envs = []
for _ in range(config['env_num']):
env = gym.make(config['env_name'])
env = wrap_deepmind(env, dim=config['env_dim'], obs_format='NCHW')
self.envs.append(env)
self.vector_env = VectorEnv(self.envs)
self.obs_batch = self.vector_env.reset()
obs_shape = env.observation_space.shape
act_dim = env.action_space.n
self.config['obs_shape'] = obs_shape
self.config['act_dim'] = act_dim
model = ActorCritic(act_dim)
if self.actor_cuda:
model = model.cuda()
algorithm = A2C(model, config)
self.agent = Agent(algorithm, config)
def sample(self):
''' Interact with the environments for sample_batch_steps steps and collect training data
'''
sample_data = defaultdict(list)
env_sample_data = {}
for env_id in range(self.config['env_num']):
env_sample_data[env_id] = defaultdict(list)
for i in range(self.config['sample_batch_steps']):
self.obs_batch = np.stack(self.obs_batch)
self.obs_batch = torch.from_numpy(self.obs_batch).float()
if self.actor_cuda:
self.obs_batch = self.obs_batch.cuda()
action_batch, value_batch = self.agent.sample(self.obs_batch)
next_obs_batch, reward_batch, done_batch, info_batch = self.vector_env.step(
action_batch.cpu().numpy())
for env_id in range(self.config['env_num']):
env_sample_data[env_id]['obs'].append(
self.obs_batch[env_id].cpu().numpy())
env_sample_data[env_id]['actions'].append(
action_batch[env_id].item())
env_sample_data[env_id]['rewards'].append(reward_batch[env_id])
env_sample_data[env_id]['dones'].append(done_batch[env_id])
env_sample_data[env_id]['values'].append(
value_batch[env_id].item())
if done_batch[
env_id] or i == self.config['sample_batch_steps'] - 1:
next_value = 0
if not done_batch[env_id]:
next_obs = np.expand_dims(next_obs_batch[env_id], 0)
next_obs = torch.from_numpy(next_obs).float()
if self.actor_cuda:
next_obs = next_obs.cuda()
next_value = self.agent.value(next_obs).item()
values = env_sample_data[env_id]['values']
rewards = env_sample_data[env_id]['rewards']
advantages = calc_gae(rewards, values, next_value,
self.config['gamma'],
self.config['lambda'])
target_values = advantages + values
sample_data['obs'].extend(env_sample_data[env_id]['obs'])
sample_data['actions'].extend(
env_sample_data[env_id]['actions'])
sample_data['advantages'].extend(advantages)
sample_data['target_values'].extend(target_values)
env_sample_data[env_id] = defaultdict(list)
self.obs_batch = next_obs_batch
for key in sample_data:
sample_data[key] = np.stack(sample_data[key])
return sample_data
def compute_target(self, v_final, r_lst, mask_lst):
G = v_final.reshape(-1)
td_target = list()
for r, mask in zip(r_lst[::-1], mask_lst[::-1]):
G = r + self.config['gamma'] * G * mask
td_target.append(G)
return torch.tensor(td_target[::-1]).float()
def get_metrics(self):
metrics = defaultdict(list)
for env in self.envs:
monitor = get_wrapper_by_cls(env, MonitorEnv)
if monitor is not None:
for episode_rewards, episode_steps in monitor.next_episode_results(
):
metrics['episode_rewards'].append(episode_rewards)
metrics['episode_steps'].append(episode_steps)
return metrics
def set_weights(self, params):
self.agent.set_weights(params)
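The advantage computation in sample() above delegates to parl.utils.rl_utils.calc_gae. The recursion it implements is the standard generalized advantage estimator; the sketch below restates it in plain numpy with the same call signature, as an illustration rather than the library code itself:

import numpy as np

def gae_sketch(rewards, values, next_value, gamma, lam):
    # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
    # A_t     = delta_t + gamma * lam * A_{t+1}
    values_ext = np.append(values, next_value)
    advantages = np.zeros(len(rewards), dtype=np.float32)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values_ext[t + 1] - values_ext[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages

# With lambda = 1.0 (as in this config) the advantage reduces to the discounted
# return minus the value baseline, and target_values = advantages + values.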
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import parl
# torch uses all CPU cores by default, which hurts performance when many actors share a machine; use one thread per actor here.
torch.set_num_threads(1)
class Agent(parl.Agent):
def __init__(self, algorithm, config):
super(Agent, self).__init__(algorithm)
self.obs_shape = config['obs_shape']
def sample(self, obs):
sample_actions, values = self.algorithm.sample(obs)
return sample_actions, values
def predict(self, obs):
predict_actions = self.algorithm.predict(obs)
return predict_actions
def value(self, obs):
values = self.algorithm.value(obs)
return values
def learn(self, obs, actions, advantages, target_values):
total_loss, pi_loss, vf_loss, entropy, lr, entropy_coeff = self.algorithm.learn(
obs, actions, advantages, target_values)
return total_loss, pi_loss, vf_loss, entropy, lr, entropy_coeff
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import parl
class ActorCritic(parl.Model):
def __init__(self, act_dim):
super(ActorCritic, self).__init__()
self.conv1 = nn.Conv2d(
in_channels=4, out_channels=32, kernel_size=8, stride=4, padding=2)
self.conv2 = nn.Conv2d(
in_channels=32,
out_channels=64,
kernel_size=4,
stride=2,
padding=2)
self.conv3 = nn.Conv2d(
in_channels=64,
out_channels=64,
kernel_size=3,
stride=1,
padding=1)
self.fc = nn.Linear(7744, 512)
self.fc_pi = nn.Linear(512, act_dim)
self.fc_v = nn.Linear(512, 1)
def policy(self, x, softmax_dim=1):
x = x / 255.0
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
x = torch.flatten(x, start_dim=1)
x = F.relu(self.fc(x))
logits = self.fc_pi(x)
prob = F.softmax(logits, dim=softmax_dim)
return prob
def value(self, x):
x = x / 255.0
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
x = torch.flatten(x, start_dim=1)
x = F.relu(self.fc(x))
values = self.fc_v(x)
return values
def policy_and_value(self, x, softmax_dim=1):
x = x / 255.0
x = F.relu(self.conv1(x))
x = F.relu(self.conv2(x))
x = F.relu(self.conv3(x))
x = torch.flatten(x, start_dim=1)
x = F.relu(self.fc(x))
values = self.fc_v(x)
logits = self.fc_pi(x)
prob = F.softmax(logits, dim=softmax_dim)
return prob, values
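The 7744 input size of the fully connected layer follows from applying the three convolutions to the 84x84 NCHW observations. A quick shape check (illustrative only; Breakout's discrete action space has 4 actions):

import torch

model = ActorCritic(act_dim=4)
dummy = torch.zeros(1, 4, 84, 84)       # (batch, stacked frames, height, width)
x = torch.relu(model.conv1(dummy))      # -> (1, 32, 21, 21)
x = torch.relu(model.conv2(x))          # -> (1, 64, 11, 11)
x = torch.relu(model.conv3(x))          # -> (1, 64, 11, 11)
print(torch.flatten(x, start_dim=1).shape)   # torch.Size([1, 7744]) = 64 * 11 * 11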
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import os
import gym
import six
import queue
import parl
import time
import threading
import numpy as np
from collections import defaultdict
from parl.env.atari_wrappers import wrap_deepmind
from parl.utils.window_stat import WindowStat
from parl.utils.time_stat import TimeStat
from parl.utils import machine_info
from parl.utils import logger, get_gpu_count, tensorboard
from parl.algorithms import A2C
from atari_model import ActorCritic
from atari_agent import Agent
from actor import Actor
import time
from statistics import mean
class Learner(object):
def __init__(self, config, cuda):
self.cuda = cuda
self.config = config
env = gym.make(config['env_name'])
env = wrap_deepmind(env, dim=config['env_dim'], obs_format='NCHW')
obs_shape = env.observation_space.shape
act_dim = env.action_space.n
self.config['obs_shape'] = obs_shape
self.config['act_dim'] = act_dim
model = ActorCritic(act_dim)
if self.cuda:
model = model.cuda()
algorithm = A2C(model, config)
self.agent = Agent(algorithm, config)
if machine_info.is_gpu_available():
assert get_gpu_count() == 1, 'Only support training in single GPU,\
Please set environment variable: `export CUDA_VISIBLE_DEVICES=[GPU_ID_YOU_WANT_TO_USE]` .'
else:
os.environ['CPU_NUM'] = str(1)
#========== Learner ==========
self.total_loss_stat = WindowStat(100)
self.pi_loss_stat = WindowStat(100)
self.vf_loss_stat = WindowStat(100)
self.entropy_stat = WindowStat(100)
self.lr = None
self.entropy_coeff = None
self.learn_time_stat = TimeStat(100)
self.start_time = None
#========== Remote Actor ===========
self.remote_count = 0
self.sample_total_steps = 0
self.sample_data_queue = queue.Queue()
self.remote_metrics_queue = queue.Queue()
self.params_queues = []
self.create_actors()
def create_actors(self):
parl.connect(self.config['master_address'])
logger.info('Waiting for {} remote actors to connect.'.format(
self.config['actor_num']))
for i in six.moves.range(self.config['actor_num']):
params_queue = queue.Queue()
self.params_queues.append(params_queue)
self.remote_count += 1
logger.info('Remote actor count: {}'.format(self.remote_count))
remote_thread = threading.Thread(
target=self.run_remote_sample, args=(params_queue, ))
remote_thread.setDaemon(True)
remote_thread.start()
logger.info('All remote actors are ready, begin to learn.')
self.start_time = time.time()
def run_remote_sample(self, params_queue):
remote_actor = Actor(self.config)
cnt = 0
while True:
latest_params = params_queue.get()
remote_actor.set_weights(latest_params)
batch = remote_actor.sample()
self.sample_data_queue.put(batch)
cnt += 1
if cnt % self.config['get_remote_metrics_interval'] == 0:
metrics = remote_actor.get_metrics()
if metrics:
self.remote_metrics_queue.put(metrics)
def step(self):
latest_params = self.agent.get_weights()
for params_queue in self.params_queues:
params_queue.put(latest_params)
train_batch = defaultdict(list)
for i in range(self.config['actor_num']):
sample_data = self.sample_data_queue.get()
for key, value in sample_data.items():
train_batch[key].append(value)
self.sample_total_steps += len(sample_data['obs'])
for key, value in train_batch.items():
train_batch[key] = np.concatenate(value)
train_batch[key] = torch.tensor(train_batch[key]).float()
if self.cuda:
train_batch[key] = train_batch[key].cuda()
with self.learn_time_stat:
total_loss, pi_loss, vf_loss, entropy, lr, entropy_coeff = self.agent.learn(
obs=train_batch['obs'],
actions=train_batch['actions'],
advantages=train_batch['advantages'],
target_values=train_batch['target_values'],
)
self.total_loss_stat.add(total_loss.item())
self.pi_loss_stat.add(pi_loss.item())
self.vf_loss_stat.add(vf_loss.item())
self.entropy_stat.add(entropy.item())
self.lr = lr
self.entropy_coeff = entropy_coeff
def log_metrics(self):
""" Log metrics of learner and actors
"""
if self.start_time is None:
return
metrics = []
while True:
try:
metric = self.remote_metrics_queue.get_nowait()
metrics.append(metric)
except queue.Empty:
break
episode_rewards, episode_steps = [], []
for x in metrics:
episode_rewards.extend(x['episode_rewards'])
episode_steps.extend(x['episode_steps'])
max_episode_rewards, mean_episode_rewards, min_episode_rewards, \
max_episode_steps, mean_episode_steps, min_episode_steps =\
None, None, None, None, None, None
if episode_rewards:
mean_episode_rewards = np.mean(np.array(episode_rewards).flatten())
max_episode_rewards = np.max(np.array(episode_rewards).flatten())
min_episode_rewards = np.min(np.array(episode_rewards).flatten())
mean_episode_steps = np.mean(np.array(episode_steps).flatten())
max_episode_steps = np.max(np.array(episode_steps).flatten())
min_episode_steps = np.min(np.array(episode_steps).flatten())
metric = {
'Sample steps': self.sample_total_steps,
'max_episode_rewards': max_episode_rewards,
'mean_episode_rewards': mean_episode_rewards,
'min_episode_rewards': min_episode_rewards,
'max_episode_steps': max_episode_steps,
'mean_episode_steps': mean_episode_steps,
'min_episode_steps': min_episode_steps,
'total_loss': self.total_loss_stat.mean,
'pi_loss': self.pi_loss_stat.mean,
'vf_loss': self.vf_loss_stat.mean,
'entropy': self.entropy_stat.mean,
'learn_time_s': self.learn_time_stat.mean,
'elapsed_time_s': int(time.time() - self.start_time),
'lr': self.lr,
'entropy_coeff': self.entropy_coeff,
}
if metric['mean_episode_rewards'] is not None:
tensorboard.add_scalar('train/mean_reward',
metric['mean_episode_rewards'],
self.sample_total_steps)
tensorboard.add_scalar('train/total_loss', metric['total_loss'],
self.sample_total_steps)
tensorboard.add_scalar('train/pi_loss', metric['pi_loss'],
self.sample_total_steps)
tensorboard.add_scalar('train/vf_loss', metric['vf_loss'],
self.sample_total_steps)
tensorboard.add_scalar('train/entropy', metric['entropy'],
self.sample_total_steps)
tensorboard.add_scalar('train/learn_rate', metric['lr'],
self.sample_total_steps)
logger.info(metric)
def should_stop(self):
return self.sample_total_steps >= self.config['max_sample_steps']
if __name__ == '__main__':
from a2c_config import config
cuda = torch.cuda.is_available()
learner = Learner(config, cuda)
assert config['log_metrics_interval_s'] > 0
while not learner.should_stop():
start = time.time()
while time.time() - start < config['log_metrics_interval_s']:
learner.step()
learner.log_metrics()
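Under the default config, each Learner.step call blocks until every actor has returned one batch. An informal estimate, assuming the actors stay in lockstep as the blocking queue enforces, of the samples consumed per update and the number of updates needed to reach max_sample_steps:

actor_num = 5
env_num = 5
sample_batch_steps = 20

samples_per_update = actor_num * env_num * sample_batch_steps   # 500 transitions
updates_to_finish = int(1e7) // samples_per_update               # 20000 learner steps
print(samples_per_update, updates_to_finish)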
@@ -22,10 +22,10 @@
 import torch.nn as nn
 import torch.optim as optim
 import torch.nn.functional as F
-from parl.core.torch.agent import Agent
+import parl


-class AtariAgent(Agent):
+class AtariAgent(parl.Agent):
     """Base class of the Agent.

     Args:
...
@@ -16,10 +16,10 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from parl.core.torch.model import Model
+import parl


-class AtariModel(Model):
+class AtariModel(parl.Model):
     """CNN network used in TensorPack examples.

     Args:
...
@@ -14,3 +14,4 @@
 from parl.algorithms.torch.ddqn import *
 from parl.algorithms.torch.dqn import *
+from parl.algorithms.torch.a2c import *
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.distributions import Categorical
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
from random import random, randint
import parl
from parl.utils.scheduler import PiecewiseScheduler, LinearDecayScheduler
__all__ = ['A2C']
class A2C(parl.Algorithm):
def __init__(self, model, config, hyperparas=None):
assert isinstance(config['vf_loss_coeff'], (int, float))
self.model = model
self.vf_loss_coeff = config['vf_loss_coeff']
self.optimizer = optim.Adam(
self.model.parameters(), lr=config['learning_rate'])
self.config = config
self.lr_scheduler = LinearDecayScheduler(config['start_lr'],
config['max_sample_steps'])
self.entropy_coeff_scheduler = PiecewiseScheduler(
config['entropy_coeff_scheduler'])
def learn(self, obs, actions, advantages, target_values):
prob = self.model.policy(obs, softmax_dim=1)
policy_distri = Categorical(prob)
actions_log_probs = policy_distri.log_prob(actions)
# The policy gradient loss
pi_loss = -((actions_log_probs * advantages).sum())
# The value function loss
values = self.model.value(obs).reshape(-1)
delta = values - target_values
vf_loss = 0.5 * torch.mul(delta, delta).sum()
# The entropy loss (we want to maximize entropy, so entropy_coeff < 0)
policy_entropy = policy_distri.entropy()
entropy = policy_entropy.sum()
lr = self.lr_scheduler.step(step_num=obs.shape[0])
entropy_coeff = self.entropy_coeff_scheduler.step()
total_loss = pi_loss + vf_loss * self.vf_loss_coeff + entropy * entropy_coeff
for param_group in self.optimizer.param_groups:
param_group['lr'] = lr
total_loss.backward()
self.optimizer.step()
self.optimizer.zero_grad()
return total_loss, pi_loss, vf_loss, entropy, lr, entropy_coeff
def sample(self, obs):
prob, values = self.model.policy_and_value(obs)
sample_actions = Categorical(prob).sample()
return sample_actions, values
def predict(self, obs):
prob = self.model.policy(obs)
_, predict_actions = prob.max(-1)
return predict_actions
def value(self, obs):
values = self.model.value(obs)
return values
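A point that is easy to misread in learn() above: entropy_coeff comes from the scheduler as a negative number, so adding entropy * entropy_coeff to total_loss rewards high-entropy policies. A toy check of the sign convention, with illustrative values rather than numbers taken from training:

import torch
from torch.distributions import Categorical

prob = torch.tensor([[0.7, 0.2, 0.1]])
actions = torch.tensor([0])
advantages = torch.tensor([2.0])
dist = Categorical(prob)

pi_loss = -(dist.log_prob(actions) * advantages).sum()  # pushes up the sampled action's probability
entropy = dist.entropy().sum()                          # about 0.80 nats for this distribution
entropy_coeff = -0.01                                   # from entropy_coeff_scheduler
loss_with_bonus = pi_loss + entropy * entropy_coeff     # higher entropy -> lower loss
print(pi_loss.item(), entropy.item(), loss_with_bonus.item())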
@@ -19,13 +19,13 @@
 import copy
 import torch
 import torch.optim as optim
 import torch.nn.functional as F
-from parl.core.torch.algorithm import Algorithm
+import parl
 import numpy as np

 __all__ = ['DDQN']


-class DDQN(Algorithm):
+class DDQN(parl.Algorithm):
     def __init__(self, model, gamma=None, lr=None):
         """ Double DQN algorithm
...
@@ -19,13 +19,13 @@
 import copy
 import torch
 import torch.optim as optim
 import torch.nn.functional as F
-from parl.core.torch.algorithm import Algorithm
+import parl
 import numpy as np

 __all__ = ['DQN']


-class DQN(Algorithm):
+class DQN(parl.Algorithm):
     def __init__(self, model, gamma=None, lr=None):
         """ DQN algorithm
...
@@ -39,7 +39,7 @@ class AlgorithmBase(object):
         Args:
             model_ids (List/Set): list/set of model_id, will only return weights of models
-                whiose model_id in the `model_ids`.
+                whose model_id in the `model_ids`.

         Returns:
             Dict of weights ({attribute name: numpy array/List/Dict})
...
@@ -667,13 +667,8 @@ class ModelBaseTest(unittest.TestCase):
         params = self.model.get_weights()

-        try:
-            self.model.set_weights(params[1:])
-        except:
-            # expected
-            return
-        assert False
+        with self.assertRaises(AssertionError):
+            self.model.set_weights(params[1:])

     def test_set_weights_with_wrong_params_shape(self):
         pred_program = fluid.Program()
@@ -691,14 +686,9 @@ class ModelBaseTest(unittest.TestCase):
         x = np.random.random(size=(1, 4)).astype('float32')

-        try:
-            outputs = self.executor.run(
-                pred_program, feed={'obs': x}, fetch_list=[model_output])
-        except:
-            # expected
-            return
-        assert False
+        with self.assertRaises(fluid.core_avx.EnforceNotMet):
+            self.executor.run(
+                pred_program, feed={'obs': x}, fetch_list=[model_output])


 if __name__ == '__main__':
...
@@ -52,7 +52,7 @@ class Agent(AgentBase):
     Public Functions:
     - ``sample``: return a noisy action to perform exploration according to the policy.
     - ``predict``: return an estimate Q function given current observation.
-    - ``learn``: update the parameters of self.alg.
+    - ``learn``: update the parameters of self.algorithm.
     - ``save``: save parameters of the ``agent`` to a given path.
     - ``restore``: restore previous saved parameters from a given path.
@@ -60,21 +60,17 @@ class Agent(AgentBase):
     - allow users to get parameters of a specified model by specifying the model's name in ``get_weights()``.
     """

-    def __init__(self, algorithm, device):
+    def __init__(self, algorithm):
         """.

         Args:
-            algorithm (parl.Algorithm): an instance of `parl.Algorithm`. This algorithm is then passed to `self.alg`.
+            algorithm (parl.Algorithm): an instance of `parl.Algorithm`. This algorithm is then passed to `self.algorithm`.
             device (torch.device): specify which GPU/CPU to be used.
         """
         assert isinstance(algorithm, Algorithm)
         super(Agent, self).__init__(algorithm)
-        self.alg = algorithm
-        self.device = torc.device('cuda' if torch.cuda.
-                                  is_available() else 'cpu')

     def learn(self, *args, **kwargs):
         """The training interface for ``Agent``.
@@ -102,10 +98,10 @@ class Agent(AgentBase):
         Args:
             save_path(str): where to save the parameters.
-            model(parl.Model): model that describes the neural network structure. If None, will use self.alg.model.
+            model(parl.Model): model that describes the neural network structure. If None, will use self.algorithm.model.

         Raises:
-            ValueError: if model is None and self.alg.model does not exist.
+            ValueError: if model is None and self.algorithm.model does not exist.

         Example:
@@ -116,7 +112,7 @@ class Agent(AgentBase):
         """
         if model is None:
-            model = self.alg.model
+            model = self.algorithm.model
         dirname = '/'.join(save_path.split('/')[:-1])
         if not os.path.exists(dirname):
             os.makedirs(dirname)
@@ -129,10 +125,10 @@ class Agent(AgentBase):
         Args:
             save_path(str): path where parameters were previously saved.
-            model(parl.Model): model that describes the neural network structure. If None, will use self.alg.model.
+            model(parl.Model): model that describes the neural network structure. If None, will use self.algorithm.model.

         Raises:
-            ValueError: if model is None and self.alg does not exist.
+            ValueError: if model is None and self.algorithm does not exist.

         Example:
@@ -145,6 +141,6 @@ class Agent(AgentBase):
         """
         if model is None:
-            model = self.alg.model
+            model = self.algorithm.model
         checkpoint = torch.load(save_path)
         model.load_state_dict(checkpoint)
@@ -62,7 +62,7 @@ class Algorithm(AlgorithmBase):
         assert isinstance(model, Model)
         self.model = model

-    def get_weights(self):
+    def get_weights(self, model_ids=None):
         """ Get weights of self.model.

         Returns:
@@ -71,7 +71,7 @@ class Algorithm(AlgorithmBase):
         """
         return self.model.get_weights()

-    def set_weights(self, params):
+    def set_weights(self, params, model_ids=None):
         """ Set weights from ``get_weights`` to the model.

         Args:
...
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import torch
 import torch.nn as nn

 from parl.core.model_base import ModelBase
@@ -116,7 +117,10 @@ class Model(nn.Module, ModelBase):
         Returns: a Python list containing the parameters of current model.
         """
-        return list(self.parameters())
+        weights = self.state_dict()
+        for key in weights.keys():
+            weights[key] = weights[key].cpu().numpy()
+        return weights

     def set_weights(self, weights):
         """Copy parameters from ``set_weights()`` to the model.
@@ -124,8 +128,6 @@ class Model(nn.Module, ModelBase):
         Args:
             weights (list): a Python list containing the parameters.
         """
-        assert len(weights) == len(list(self.parameters())), \
-            'size of input weights should be same as weights number of current model'
-
-        for var, weight in zip(self.parameters(), weights):
-            var.data.copy_(weight.data)
+        for key in weights.keys():
+            weights[key] = torch.from_numpy(weights[key])
+        self.load_state_dict(weights)
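The numpy-based get_weights/set_weights above is what makes the learner-to-actor weight sync in this example picklable across the cluster. A minimal round trip under that assumption, using the ActorCritic model defined in atari_model.py and a GPU on the learner side only if one is available:

import torch
from atari_model import ActorCritic

learner_model = ActorCritic(act_dim=4)
if torch.cuda.is_available():
    learner_model = learner_model.cuda()   # learner may hold weights on GPU
actor_model = ActorCritic(act_dim=4)       # remote actor stays on CPU

params = learner_model.get_weights()       # dict of numpy arrays, safe to pickle
actor_model.set_weights(params)            # converted back with torch.from_numpy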
@@ -37,7 +37,7 @@ class TestModel(parl.Model):

 class TestAlgorithm(parl.Algorithm):
     def __init__(self, model):
-        self.model = model
+        super(TestAlgorithm, self).__init__(model)
         self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)

     def predict(self, obs):
@@ -54,13 +54,13 @@ class TestAlgorithm(parl.Algorithm):

 class TestAgent(parl.Agent):
     def __init__(self, algorithm):
-        self.alg = algorithm
+        super(TestAgent, self).__init__(algorithm)

     def learn(self, obs, label):
-        cost = self.alg.learn(obs, label)
+        cost = self.algorithm.learn(obs, label)

     def predict(self, obs):
-        return self.alg.predict(obs)
+        return self.algorithm.predict(obs)


 class AgentBaseTest(unittest.TestCase):
@@ -95,6 +95,11 @@ class AgentBaseTest(unittest.TestCase):
         current_output = agent.predict(obs).detach().cpu().numpy()
         np.testing.assert_equal(current_output, previous_output)

+    def test_weights(self):
+        agent = TestAgent(self.alg)
+        weight = agent.get_weights()
+        agent.set_weights(weight)
+

 if __name__ == '__main__':
     unittest.main()
@@ -16,6 +16,7 @@
 import numpy as np
 import unittest
 import os
 from copy import deepcopy
+from collections import OrderedDict

 import torch
 import torch.nn as nn
@@ -44,6 +45,7 @@ class ModelBaseTest(unittest.TestCase):
         self.model = TestModel()
         self.target_model = TestModel()
         self.target_model2 = TestModel()
+        self.target_model3 = TestModel()

         gpu_count = get_gpu_count()
         device = torch.device('cuda' if gpu_count else 'cpu')
@@ -282,22 +284,18 @@ class ModelBaseTest(unittest.TestCase):
         params = self.model.get_weights()
         expected_params = list(self.model.parameters())
         self.assertEqual(len(params), len(expected_params))
-        for param in params:
-            flag = False
-            for expected_param in expected_params:
-                if param.sum().item() - expected_param.sum().item() < 1e-5:
-                    flag = True
-                    break
-            self.assertTrue(flag)
+        for i, key in enumerate(params):
+            self.assertLess(
+                (params[key].sum().item() - expected_params[i].sum().item()),
+                1e-5)

     def test_set_weights(self):
         params = self.model.get_weights()
-        new_params = [x + 1.0 for x in params]
-        self.model.set_weights(new_params)
-
-        for x, y in list(zip(new_params, self.model.get_weights())):
-            self.assertEqual(x.sum().item(), y.sum().item())
+        self.target_model3.set_weights(params)
+
+        for i, j in zip(params.values(),
+                        self.target_model3.get_weights().values()):
+            self.assertLessEqual(abs(i.sum().item() - j.sum().item()), 1e-3)

     def test_set_weights_between_different_models(self):
         model1 = TestModel()
@@ -323,20 +321,14 @@ class ModelBaseTest(unittest.TestCase):
     def test_set_weights_wrong_params_num(self):
         params = self.model.get_weights()
-        try:
-            self.model.set_weights(params[1:])
-        except:
-            return
-        assert False
+        with self.assertRaises(TypeError):
+            self.model.set_weights(params[1:])

     def test_set_weights_wrong_params_shape(self):
         params = self.model.get_weights()
-        params.reverse()
-        try:
-            self.model.set_weights(params)
-        except:
-            return
-        assert False
+        params['fc1.weight'] = params['fc2.bias']
+        with self.assertRaises(RuntimeError):
+            self.model.set_weights(params)


 if __name__ == '__main__':
...
@@ -53,20 +53,12 @@ class TestScheduler(unittest.TestCase):
         assert value == 0.3

     def test_PiecewiseScheduler_with_empty(self):
-        try:
-            scheduler = PiecewiseScheduler([])
-        except AssertionError:
-            # expected
-            return
-        assert False
+        with self.assertRaises(AssertionError):
+            scheduler = PiecewiseScheduler([])

     def test_PiecewiseScheduler_with_incorrect_steps(self):
-        try:
-            scheduler = PiecewiseScheduler([(10, 0.1), (1, 0.2)])
-        except AssertionError:
-            # expected
-            return
-        assert False
+        with self.assertRaises(AssertionError):
+            scheduler = PiecewiseScheduler([(10, 0.1), (1, 0.2)])

     def test_LinearDecayScheduler(self):
         scheduler = LinearDecayScheduler(start_value=10, max_steps=10)
...