diff --git a/benchmark/torch/a2c/a2c_config.py b/benchmark/torch/a2c/a2c_config.py new file mode 100644 index 0000000000000000000000000000000000000000..c8512f24408504aa83a0afd03277f7037e3dacee --- /dev/null +++ b/benchmark/torch/a2c/a2c_config.py @@ -0,0 +1,42 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +config = { + + #========== remote config ========== + 'master_address': 'localhost:8010', + + #========== env config ========== + 'env_name': 'BreakoutNoFrameskip-v4', + 'env_dim': 84, + + #========== actor config ========== + 'actor_num': 5, + 'env_num': 5, + 'sample_batch_steps': 20, + + #========== learner config ========== + 'max_sample_steps': int(1e7), + 'gamma': 0.99, + 'lambda': 1.0, + + # start learning rate + 'start_lr': 0.001, + 'entropy_coeff_scheduler': [(0, -0.01)], + 'vf_loss_coeff': 0.5, + 'get_remote_metrics_interval': 10, + 'log_metrics_interval_s': 10, + 'entropy_coeff': -0.05, + 'learning_rate': 3e-4 +} diff --git a/benchmark/torch/a2c/actor.py b/benchmark/torch/a2c/actor.py new file mode 100644 index 0000000000000000000000000000000000000000..178feedf68be7108b4602aa2fe4bd1a1183a8826 --- /dev/null +++ b/benchmark/torch/a2c/actor.py @@ -0,0 +1,142 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import gym +import parl +import torch +import numpy as np +from collections import defaultdict +from parl.env.atari_wrappers import wrap_deepmind, MonitorEnv, get_wrapper_by_cls +from parl.env.vector_env import VectorEnv +from parl.utils.rl_utils import calc_gae + +from atari_model import ActorCritic +from parl.algorithms import A2C +from atari_agent import Agent + + +@parl.remote_class +class Actor(object): + def __init__(self, config): + # the cluster may not have gpu + os.environ['CUDA_VISIBLE_DEVICES'] = '0' + self.actor_cuda = False + self.config = config + + self.envs = [] + for _ in range(config['env_num']): + env = gym.make(config['env_name']) + env = wrap_deepmind(env, dim=config['env_dim'], obs_format='NCHW') + self.envs.append(env) + self.vector_env = VectorEnv(self.envs) + + self.obs_batch = self.vector_env.reset() + + obs_shape = env.observation_space.shape + act_dim = env.action_space.n + + self.config['obs_shape'] = obs_shape + self.config['act_dim'] = act_dim + + model = ActorCritic(act_dim) + if self.actor_cuda: + model = model.cuda() + + algorithm = A2C(model, config) + self.agent = Agent(algorithm, config) + + def sample(self): + ''' Interact with the environments lambda times + ''' + sample_data = defaultdict(list) + + env_sample_data = {} + for env_id in range(self.config['env_num']): + env_sample_data[env_id] = defaultdict(list) + for i in range(self.config['sample_batch_steps']): + self.obs_batch = np.stack(self.obs_batch) + self.obs_batch = torch.from_numpy(self.obs_batch).float() + if self.actor_cuda: + self.obs_batch = self.obs_batch.cuda() + + action_batch, value_batch = self.agent.sample(self.obs_batch) + next_obs_batch, reward_batch, done_batch, info_batch = self.vector_env.step( + action_batch.cpu().numpy()) + + for env_id in range(self.config['env_num']): + env_sample_data[env_id]['obs'].append( + self.obs_batch[env_id].cpu().numpy()) + env_sample_data[env_id]['actions'].append( + action_batch[env_id].item()) + env_sample_data[env_id]['rewards'].append(reward_batch[env_id]) + env_sample_data[env_id]['dones'].append(done_batch[env_id]) + env_sample_data[env_id]['values'].append( + value_batch[env_id].item()) + + if done_batch[ + env_id] or i == self.config['sample_batch_steps'] - 1: + next_value = 0 + if not done_batch[env_id]: + next_obs = np.expand_dims(next_obs_batch[env_id], 0) + next_obs = torch.from_numpy(next_obs).float() + if self.actor_cuda: + next_obs = next_obs.cuda() + next_value = self.agent.value(next_obs).item() + + values = env_sample_data[env_id]['values'] + rewards = env_sample_data[env_id]['rewards'] + advantages = calc_gae(rewards, values, next_value, + self.config['gamma'], + self.config['lambda']) + target_values = advantages + values + + sample_data['obs'].extend(env_sample_data[env_id]['obs']) + sample_data['actions'].extend( + env_sample_data[env_id]['actions']) + sample_data['advantages'].extend(advantages) + sample_data['target_values'].extend(target_values) + + env_sample_data[env_id] = defaultdict(list) + + self.obs_batch = next_obs_batch + + for key in sample_data: + sample_data[key] = np.stack(sample_data[key]) + + return sample_data + + def compute_target(self, v_final, r_lst, mask_lst): + G = v_final.reshape(-1) + td_target = list() + + for r, mask in zip(r_lst[::-1], mask_lst[::-1]): + G = r + self.config['gamma'] * G * mask + td_target.append(G) + + return torch.tensor(td_target[::-1]).float() + + def get_metrics(self): + metrics = defaultdict(list) + for env in self.envs: + monitor = get_wrapper_by_cls(env, 
MonitorEnv) + if monitor is not None: + for episode_rewards, episode_steps in monitor.next_episode_results( + ): + metrics['episode_rewards'].append(episode_rewards) + metrics['episode_steps'].append(episode_steps) + return metrics + + def set_weights(self, params): + self.agent.set_weights(params) diff --git a/benchmark/torch/a2c/atari_agent.py b/benchmark/torch/a2c/atari_agent.py new file mode 100644 index 0000000000000000000000000000000000000000..123ad7370704e36bf2271f18cf7013a3e29c646f --- /dev/null +++ b/benchmark/torch/a2c/atari_agent.py @@ -0,0 +1,43 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import parl + +# By default, torch uses all CPU cores, which degrades performance when many actors share one machine. Use one thread per actor here. +torch.set_num_threads(1) + + +class Agent(parl.Agent): + def __init__(self, algorithm, config): + super(Agent, self).__init__(algorithm) + self.obs_shape = config['obs_shape'] + + def sample(self, obs): + sample_actions, values = self.algorithm.sample(obs) + return sample_actions, values + + def predict(self, obs): + predict_actions = self.algorithm.predict(obs) + return predict_actions + + def value(self, obs): + values = self.algorithm.value(obs) + return values + + def learn(self, obs, actions, advantages, target_values): + total_loss, pi_loss, vf_loss, entropy, lr, entropy_coeff = self.algorithm.learn( + obs, actions, advantages, target_values) + + return total_loss, pi_loss, vf_loss, entropy, lr, entropy_coeff diff --git a/benchmark/torch/a2c/atari_model.py b/benchmark/torch/a2c/atari_model.py new file mode 100644 index 0000000000000000000000000000000000000000..0929dc9cf52a9c8d74dc84d750e254d7b317d7a4 --- /dev/null +++ b/benchmark/torch/a2c/atari_model.py @@ -0,0 +1,84 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +import parl + + +class ActorCritic(parl.Model): + def __init__(self, act_dim): + super(ActorCritic, self).__init__() + self.conv1 = nn.Conv2d( + in_channels=4, out_channels=32, kernel_size=8, stride=4, padding=2) + self.conv2 = nn.Conv2d( + in_channels=32, + out_channels=64, + kernel_size=4, + stride=2, + padding=2) + self.conv3 = nn.Conv2d( + in_channels=64, + out_channels=64, + kernel_size=3, + stride=1, + padding=1) + self.fc = nn.Linear(7744, 512) + + self.fc_pi = nn.Linear(512, act_dim) + self.fc_v = nn.Linear(512, 1) + + def policy(self, x, softmax_dim=1): + x = x / 255.0 + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + x = F.relu(self.conv3(x)) + + x = torch.flatten(x, start_dim=1) + x = F.relu(self.fc(x)) + + logits = self.fc_pi(x) + prob = F.softmax(logits, dim=softmax_dim) + + return prob + + def value(self, x): + x = x / 255.0 + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + x = F.relu(self.conv3(x)) + + x = torch.flatten(x, start_dim=1) + x = F.relu(self.fc(x)) + values = self.fc_v(x) + + return values + + def policy_and_value(self, x, softmax_dim=1): + x = x / 255.0 + x = F.relu(self.conv1(x)) + x = F.relu(self.conv2(x)) + x = F.relu(self.conv3(x)) + + x = torch.flatten(x, start_dim=1) + x = F.relu(self.fc(x)) + + values = self.fc_v(x) + logits = self.fc_pi(x) + prob = F.softmax(logits, dim=softmax_dim) + + return prob, values diff --git a/benchmark/torch/a2c/train.py b/benchmark/torch/a2c/train.py new file mode 100644 index 0000000000000000000000000000000000000000..f2985367f8304edb6bccc93f894a7d04f5f305c8 --- /dev/null +++ b/benchmark/torch/a2c/train.py @@ -0,0 +1,239 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import os +import gym +import six +import queue +import parl +import time +import threading +import numpy as np + +from collections import defaultdict +from parl.env.atari_wrappers import wrap_deepmind +from parl.utils.window_stat import WindowStat +from parl.utils.time_stat import TimeStat +from parl.utils import machine_info +from parl.utils import logger, get_gpu_count, tensorboard +from parl.algorithms import A2C + +from atari_model import ActorCritic +from atari_agent import Agent +from actor import Actor + +import time +from statistics import mean + + +class Learner(object): + def __init__(self, config, cuda): + self.cuda = cuda + + self.config = config + env = gym.make(config['env_name']) + env = wrap_deepmind(env, dim=config['env_dim'], obs_format='NCHW') + obs_shape = env.observation_space.shape + act_dim = env.action_space.n + self.config['obs_shape'] = obs_shape + self.config['act_dim'] = act_dim + + model = ActorCritic(act_dim) + if self.cuda: + model = model.cuda() + + algorithm = A2C(model, config) + self.agent = Agent(algorithm, config) + + if machine_info.is_gpu_available(): + assert get_gpu_count() == 1, 'Only support training in single GPU,\ + Please set environment variable: `export CUDA_VISIBLE_DEVICES=[GPU_ID_YOU_WANT_TO_USE]` .' + + else: + os.environ['CPU_NUM'] = str(1) + + #========== Learner ========== + self.total_loss_stat = WindowStat(100) + self.pi_loss_stat = WindowStat(100) + self.vf_loss_stat = WindowStat(100) + self.entropy_stat = WindowStat(100) + self.lr = None + self.entropy_coeff = None + + self.learn_time_stat = TimeStat(100) + self.start_time = None + + #========== Remote Actor =========== + self.remote_count = 0 + self.sample_total_steps = 0 + self.sample_data_queue = queue.Queue() + self.remote_metrics_queue = queue.Queue() + self.params_queues = [] + + self.create_actors() + + def create_actors(self): + parl.connect(self.config['master_address']) + + logger.info('Waiting for {} remote actors to connect.'.format( + self.config['actor_num'])) + + for i in six.moves.range(self.config['actor_num']): + params_queue = queue.Queue() + self.params_queues.append(params_queue) + + self.remote_count += 1 + logger.info('Remote actor count: {}'.format(self.remote_count)) + + remote_thread = threading.Thread( + target=self.run_remote_sample, args=(params_queue, )) + remote_thread.setDaemon(True) + remote_thread.start() + + logger.info('All remote actors are ready, begin to learn.') + self.start_time = time.time() + + def run_remote_sample(self, params_queue): + remote_actor = Actor(self.config) + + cnt = 0 + while True: + latest_params = params_queue.get() + + remote_actor.set_weights(latest_params) + batch = remote_actor.sample() + self.sample_data_queue.put(batch) + + cnt += 1 + if cnt % self.config['get_remote_metrics_interval'] == 0: + metrics = remote_actor.get_metrics() + if metrics: + self.remote_metrics_queue.put(metrics) + + def step(self): + latest_params = self.agent.get_weights() + + for params_queue in self.params_queues: + params_queue.put(latest_params) + + train_batch = defaultdict(list) + for i in range(self.config['actor_num']): + sample_data = self.sample_data_queue.get() + for key, value in sample_data.items(): + train_batch[key].append(value) + self.sample_total_steps += len(sample_data['obs']) + + for key, value in train_batch.items(): + train_batch[key] = np.concatenate(value) + train_batch[key] = torch.tensor(train_batch[key]).float() + if self.cuda: + train_batch[key] = train_batch[key].cuda() + + with 
self.learn_time_stat: + total_loss, pi_loss, vf_loss, entropy, lr, entropy_coeff = self.agent.learn( + obs=train_batch['obs'], + actions=train_batch['actions'], + advantages=train_batch['advantages'], + target_values=train_batch['target_values'], + ) + + self.total_loss_stat.add(total_loss.item()) + self.pi_loss_stat.add(pi_loss.item()) + self.vf_loss_stat.add(vf_loss.item()) + self.entropy_stat.add(entropy.item()) + self.lr = lr + self.entropy_coeff = entropy_coeff + + def log_metrics(self): + """ Log metrics of learner and actors + """ + if self.start_time is None: + return + + metrics = [] + while True: + try: + metric = self.remote_metrics_queue.get_nowait() + metrics.append(metric) + except queue.Empty: + break + + episode_rewards, episode_steps = [], [] + for x in metrics: + episode_rewards.extend(x['episode_rewards']) + episode_steps.extend(x['episode_steps']) + max_episode_rewards, mean_episode_rewards, min_episode_rewards, \ + max_episode_steps, mean_episode_steps, min_episode_steps =\ + None, None, None, None, None, None + if episode_rewards: + mean_episode_rewards = np.mean(np.array(episode_rewards).flatten()) + max_episode_rewards = np.max(np.array(episode_rewards).flatten()) + min_episode_rewards = np.min(np.array(episode_rewards).flatten()) + + mean_episode_steps = np.mean(np.array(episode_steps).flatten()) + max_episode_steps = np.max(np.array(episode_steps).flatten()) + min_episode_steps = np.min(np.array(episode_steps).flatten()) + + metric = { + 'Sample steps': self.sample_total_steps, + 'max_episode_rewards': max_episode_rewards, + 'mean_episode_rewards': mean_episode_rewards, + 'min_episode_rewards': min_episode_rewards, + 'max_episode_steps': max_episode_steps, + 'mean_episode_steps': mean_episode_steps, + 'min_episode_steps': min_episode_steps, + 'total_loss': self.total_loss_stat.mean, + 'pi_loss': self.pi_loss_stat.mean, + 'vf_loss': self.vf_loss_stat.mean, + 'entropy': self.entropy_stat.mean, + 'learn_time_s': self.learn_time_stat.mean, + 'elapsed_time_s': int(time.time() - self.start_time), + 'lr': self.lr, + 'entropy_coeff': self.entropy_coeff, + } + + if metric['mean_episode_rewards'] is not None: + tensorboard.add_scalar('train/mean_reward', + metric['mean_episode_rewards'], + self.sample_total_steps) + tensorboard.add_scalar('train/total_loss', metric['total_loss'], + self.sample_total_steps) + tensorboard.add_scalar('train/pi_loss', metric['pi_loss'], + self.sample_total_steps) + tensorboard.add_scalar('train/vf_loss', metric['vf_loss'], + self.sample_total_steps) + tensorboard.add_scalar('train/entropy', metric['entropy'], + self.sample_total_steps) + tensorboard.add_scalar('train/learn_rate', metric['lr'], + self.sample_total_steps) + + logger.info(metric) + + def should_stop(self): + return self.sample_total_steps >= self.config['max_sample_steps'] + + +if __name__ == '__main__': + from a2c_config import config + + cuda = torch.cuda.is_available() + learner = Learner(config, cuda) + assert config['log_metrics_interval_s'] > 0 + + while not learner.should_stop(): + start = time.time() + while time.time() - start < config['log_metrics_interval_s']: + learner.step() + learner.log_metrics() diff --git a/benchmark/torch/dqn/agent.py b/benchmark/torch/dqn/agent.py index 95f383aa632db7bd4c7ddd5ceb11348913cb7ca0..5a145e2750f243e0d8049d9a3c67bbd06401c338 100644 --- a/benchmark/torch/dqn/agent.py +++ b/benchmark/torch/dqn/agent.py @@ -22,10 +22,10 @@ import torch.nn as nn import torch.optim as optim import torch.nn.functional as F -from parl.core.torch.agent import 
Agent +import parl -class AtariAgent(Agent): +class AtariAgent(parl.Agent): """Base class of the Agent. Args: diff --git a/benchmark/torch/dqn/model.py b/benchmark/torch/dqn/model.py index 8ba80d5dc2cfdb2d758f3f7ac895ea641e9fa62e..fa3ea2e600b94900a6d589c56f8da1866967cc2b 100644 --- a/benchmark/torch/dqn/model.py +++ b/benchmark/torch/dqn/model.py @@ -16,10 +16,10 @@ import torch import torch.nn as nn import torch.nn.functional as F -from parl.core.torch.model import Model +import parl -class AtariModel(Model): +class AtariModel(parl.Model): """CNN network used in TensorPack examples. Args: diff --git a/parl/algorithms/torch/__init__.py b/parl/algorithms/torch/__init__.py index abc70cdc60c04e5ccdb1f161443d1f855a7109f9..7f026bf089db0b8b8dcb1d73a3fb83657509d080 100644 --- a/parl/algorithms/torch/__init__.py +++ b/parl/algorithms/torch/__init__.py @@ -14,3 +14,4 @@ from parl.algorithms.torch.ddqn import * from parl.algorithms.torch.dqn import * +from parl.algorithms.torch.a2c import * diff --git a/parl/algorithms/torch/a2c.py b/parl/algorithms/torch/a2c.py new file mode 100644 index 0000000000000000000000000000000000000000..3d78ce75938c583e15e4f7321ad836d869ef25b1 --- /dev/null +++ b/parl/algorithms/torch/a2c.py @@ -0,0 +1,89 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +from torch.distributions import Categorical +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.optim import lr_scheduler +from random import random, randint + +import parl +from parl.utils.scheduler import PiecewiseScheduler, LinearDecayScheduler + +__all__ = ['A2C'] + + +class A2C(parl.Algorithm): + def __init__(self, model, config, hyperparas=None): + assert isinstance(config['vf_loss_coeff'], (int, float)) + self.model = model + self.vf_loss_coeff = config['vf_loss_coeff'] + self.optimizer = optim.Adam( + self.model.parameters(), lr=config['learning_rate']) + self.config = config + + self.lr_scheduler = LinearDecayScheduler(config['start_lr'], + config['max_sample_steps']) + + self.entropy_coeff_scheduler = PiecewiseScheduler( + config['entropy_coeff_scheduler']) + + def learn(self, obs, actions, advantages, target_values): + prob = self.model.policy(obs, softmax_dim=1) + policy_distri = Categorical(prob) + actions_log_probs = policy_distri.log_prob(actions) + + # The policy gradient loss + pi_loss = -((actions_log_probs * advantages).sum()) + + # The value function loss + values = self.model.value(obs).reshape(-1) + delta = values - target_values + vf_loss = 0.5 * torch.mul(delta, delta).sum() + + # The entropy loss (We want to maximize entropy, so entropy_ceoff < 0) + policy_entropy = policy_distri.entropy() + entropy = policy_entropy.sum() + + lr = self.lr_scheduler.step(step_num=obs.shape[0]) + entropy_coeff = self.entropy_coeff_scheduler.step() + + total_loss = pi_loss + vf_loss * self.vf_loss_coeff + entropy * entropy_coeff + + for param_group in self.optimizer.param_groups: + param_group['lr'] = lr + + total_loss.backward() + self.optimizer.step() + self.optimizer.zero_grad() + + return total_loss, pi_loss, vf_loss, entropy, lr, entropy_coeff + + def sample(self, obs): + prob, values = self.model.policy_and_value(obs) + sample_actions = Categorical(prob).sample() + + return sample_actions, values + + def predict(self, obs): + prob = self.model.policy(obs) + _, predict_actions = prob.max(-1) + + return predict_actions + + def value(self, obs): + values = self.model.value(obs) + return values diff --git a/parl/algorithms/torch/ddqn.py b/parl/algorithms/torch/ddqn.py index 9b6e271199e3d26b723b9bd71a285f77db415c5e..5a1417676584e6d6a723fd06e75625da40eda6c4 100644 --- a/parl/algorithms/torch/ddqn.py +++ b/parl/algorithms/torch/ddqn.py @@ -19,13 +19,13 @@ import copy import torch import torch.optim as optim import torch.nn.functional as F -from parl.core.torch.algorithm import Algorithm +import parl import numpy as np __all__ = ['DDQN'] -class DDQN(Algorithm): +class DDQN(parl.Algorithm): def __init__(self, model, gamma=None, lr=None): """ Double DQN algorithm diff --git a/parl/algorithms/torch/dqn.py b/parl/algorithms/torch/dqn.py index 9244f5d6f255d522126fbb5998ee20191fb930ae..040262f81c11ba787c02b8a5dafe803d052326b8 100644 --- a/parl/algorithms/torch/dqn.py +++ b/parl/algorithms/torch/dqn.py @@ -19,13 +19,13 @@ import copy import torch import torch.optim as optim import torch.nn.functional as F -from parl.core.torch.algorithm import Algorithm +import parl import numpy as np __all__ = ['DQN'] -class DQN(Algorithm): +class DQN(parl.Algorithm): def __init__(self, model, gamma=None, lr=None): """ DQN algorithm diff --git a/parl/core/algorithm_base.py b/parl/core/algorithm_base.py index e5e5d80e4bf4cd9e8bcb01df9e514255c9620b44..e2f974b40e31c7d492076cf5479bb44d71e5291a 100644 --- a/parl/core/algorithm_base.py +++ 
b/parl/core/algorithm_base.py @@ -39,7 +39,7 @@ class AlgorithmBase(object): Args: model_ids (List/Set): list/set of model_id, will only return weights of models - whiose model_id in the `model_ids`. + whose model_id in the `model_ids`. Returns: Dict of weights ({attribute name: numpy array/List/Dict}) diff --git a/parl/core/fluid/tests/model_base_test_.py b/parl/core/fluid/tests/model_base_test_.py index 82adbfbe3eb3fe5ce339c6ccef7bac697a795055..1656366a2fd97daf019b2cfb42f1ab7be640a65a 100644 --- a/parl/core/fluid/tests/model_base_test_.py +++ b/parl/core/fluid/tests/model_base_test_.py @@ -667,13 +667,8 @@ class ModelBaseTest(unittest.TestCase): params = self.model.get_weights() - try: + with self.assertRaises(AssertionError): self.model.set_weights(params[1:]) - except: - # expected - return - - assert False def test_set_weights_with_wrong_params_shape(self): pred_program = fluid.Program() @@ -691,14 +686,9 @@ class ModelBaseTest(unittest.TestCase): x = np.random.random(size=(1, 4)).astype('float32') - try: - outputs = self.executor.run( + with self.assertRaises(fluid.core_avx.EnforceNotMet): + self.executor.run( pred_program, feed={'obs': x}, fetch_list=[model_output]) - except: - # expected - return - - assert False if __name__ == '__main__': diff --git a/parl/core/torch/agent.py b/parl/core/torch/agent.py index 7e2ef38ceb9eb3e60b600b2e4c5b8dc3abbd4a17..5d8bb2195dc0fdff48c2e9a3d5f477b793af06eb 100644 --- a/parl/core/torch/agent.py +++ b/parl/core/torch/agent.py @@ -52,7 +52,7 @@ class Agent(AgentBase): Public Functions: - ``sample``: return a noisy action to perform exploration according to the policy. - ``predict``: return an estimate Q function given current observation. - - ``learn``: update the parameters of self.alg. + - ``learn``: update the parameters of self.algorithm. - ``save``: save parameters of the ``agent`` to a given path. - ``restore``: restore previous saved parameters from a given path. @@ -60,21 +60,17 @@ class Agent(AgentBase): - allow users to get parameters of a specified model by specifying the model's name in ``get_weights()``. """ - def __init__(self, algorithm, device): + def __init__(self, algorithm): """. Args: - algorithm (parl.Algorithm): an instance of `parl.Algorithm`. This algorithm is then passed to `self.alg`. + algorithm (parl.Algorithm): an instance of `parl.Algorithm`. This algorithm is then passed to `self.algorithm`. device (torch.device): specify which GPU/CPU to be used. """ assert isinstance(algorithm, Algorithm) super(Agent, self).__init__(algorithm) - self.alg = algorithm - self.device = torc.device('cuda' if torch.cuda. - is_available() else 'cpu') - def learn(self, *args, **kwargs): """The training interface for ``Agent``. @@ -102,10 +98,10 @@ class Agent(AgentBase): Args: save_path(str): where to save the parameters. - model(parl.Model): model that describes the neural network structure. If None, will use self.alg.model. + model(parl.Model): model that describes the neural network structure. If None, will use self.algorithm.model. Raises: - ValueError: if model is None and self.alg.model does not exist. + ValueError: if model is None and self.algorithm.model does not exist. Example: @@ -116,7 +112,7 @@ class Agent(AgentBase): """ if model is None: - model = self.alg.model + model = self.algorithm.model dirname = '/'.join(save_path.split('/')[:-1]) if not os.path.exists(dirname): os.makedirs(dirname) @@ -129,10 +125,10 @@ class Agent(AgentBase): Args: save_path(str): path where parameters were previously saved. 
- model(parl.Model): model that describes the neural network structure. If None, will use self.alg.model. + model(parl.Model): model that describes the neural network structure. If None, will use self.algorithm.model. Raises: - ValueError: if model is None and self.alg does not exist. + ValueError: if model is None and self.algorithm does not exist. Example: @@ -145,6 +141,6 @@ class Agent(AgentBase): """ if model is None: - model = self.alg.model + model = self.algorithm.model checkpoint = torch.load(save_path) model.load_state_dict(checkpoint) diff --git a/parl/core/torch/algorithm.py b/parl/core/torch/algorithm.py index d95368893b6394902ecdd4eebe0af8270da8ce18..e019296e8945e3e09e3faba52a8add812eb22ce6 100644 --- a/parl/core/torch/algorithm.py +++ b/parl/core/torch/algorithm.py @@ -62,7 +62,7 @@ class Algorithm(AlgorithmBase): assert isinstance(model, Model) self.model = model - def get_weights(self): + def get_weights(self, model_ids=None): """ Get weights of self.model. Returns: @@ -71,7 +71,7 @@ class Algorithm(AlgorithmBase): """ return self.model.get_weights() - def set_weights(self, params): + def set_weights(self, params, model_ids=None): """ Set weights from ``get_weights`` to the model. Args: diff --git a/parl/core/torch/model.py b/parl/core/torch/model.py index 4827cfad186aa8a15bc5530306fd777bdb44a0f4..86f6c1a12f0c3bdcc45d80edfee2139dd90ab995 100644 --- a/parl/core/torch/model.py +++ b/parl/core/torch/model.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import torch import torch.nn as nn from parl.core.model_base import ModelBase @@ -116,7 +117,10 @@ class Model(nn.Module, ModelBase): Returns: a Python list containing the parameters of current model. """ - return list(self.parameters()) + weights = self.state_dict() + for key in weights.keys(): + weights[key] = weights[key].cpu().numpy() + return weights def set_weights(self, weights): """Copy parameters from ``set_weights()`` to the model. @@ -124,8 +128,6 @@ class Model(nn.Module, ModelBase): Args: weights (list): a Python list containing the parameters. 
""" - assert len(weights) == len(list(self.parameters())), \ - 'size of input weights should be same as weights number of current model' - - for var, weight in zip(self.parameters(), weights): - var.data.copy_(weight.data) + for key in weights.keys(): + weights[key] = torch.from_numpy(weights[key]) + self.load_state_dict(weights) diff --git a/parl/core/torch/tests/agent_base_test_torch.py b/parl/core/torch/tests/agent_base_test_torch.py index c39c9e5e0c6dfb24ca13fc11b22faeb99bdbe013..96caf7532c38bafea6ba33d41ecb173361c525ac 100644 --- a/parl/core/torch/tests/agent_base_test_torch.py +++ b/parl/core/torch/tests/agent_base_test_torch.py @@ -37,7 +37,7 @@ class TestModel(parl.Model): class TestAlgorithm(parl.Algorithm): def __init__(self, model): - self.model = model + super(TestAlgorithm, self).__init__(model) self.optimizer = optim.Adam(self.model.parameters(), lr=0.001) def predict(self, obs): @@ -54,13 +54,13 @@ class TestAlgorithm(parl.Algorithm): class TestAgent(parl.Agent): def __init__(self, algorithm): - self.alg = algorithm + super(TestAgent, self).__init__(algorithm) def learn(self, obs, label): - cost = self.alg.learn(obs, label) + cost = self.algorithm.learn(obs, label) def predict(self, obs): - return self.alg.predict(obs) + return self.algorithm.predict(obs) class AgentBaseTest(unittest.TestCase): @@ -95,6 +95,11 @@ class AgentBaseTest(unittest.TestCase): current_output = agent.predict(obs).detach().cpu().numpy() np.testing.assert_equal(current_output, previous_output) + def test_weights(self): + agent = TestAgent(self.alg) + weight = agent.get_weights() + agent.set_weights(weight) + if __name__ == '__main__': unittest.main() diff --git a/parl/core/torch/tests/model_base_test_torch.py b/parl/core/torch/tests/model_base_test_torch.py index dae18e0b04612cd63ec8ceb1bec45b3c6bfbea0f..c8047a61e772321d480dc8ffa01bffa102a53265 100644 --- a/parl/core/torch/tests/model_base_test_torch.py +++ b/parl/core/torch/tests/model_base_test_torch.py @@ -16,6 +16,7 @@ import numpy as np import unittest import os from copy import deepcopy +from collections import OrderedDict import torch import torch.nn as nn @@ -44,6 +45,7 @@ class ModelBaseTest(unittest.TestCase): self.model = TestModel() self.target_model = TestModel() self.target_model2 = TestModel() + self.target_model3 = TestModel() gpu_count = get_gpu_count() device = torch.device('cuda' if gpu_count else 'cpu') @@ -282,22 +284,18 @@ class ModelBaseTest(unittest.TestCase): params = self.model.get_weights() expected_params = list(self.model.parameters()) self.assertEqual(len(params), len(expected_params)) - for param in params: - flag = False - for expected_param in expected_params: - if param.sum().item() - expected_param.sum().item() < 1e-5: - flag = True - break - self.assertTrue(flag) + for i, key in enumerate(params): + self.assertLess( + (params[key].sum().item() - expected_params[i].sum().item()), + 1e-5) def test_set_weights(self): params = self.model.get_weights() - new_params = [x + 1.0 for x in params] + self.target_model3.set_weights(params) - self.model.set_weights(new_params) - - for x, y in list(zip(new_params, self.model.get_weights())): - self.assertEqual(x.sum().item(), y.sum().item()) + for i, j in zip(params.values(), + self.target_model3.get_weights().values()): + self.assertLessEqual(abs(i.sum().item() - j.sum().item()), 1e-3) def test_set_weights_between_different_models(self): model1 = TestModel() @@ -323,20 +321,14 @@ class ModelBaseTest(unittest.TestCase): def test_set_weights_wrong_params_num(self): params = 
self.model.get_weights() - try: + with self.assertRaises(TypeError): self.model.set_weights(params[1:]) - except: - return - assert False def test_set_weights_wrong_params_shape(self): params = self.model.get_weights() - params.reverse() - try: + params['fc1.weight'] = params['fc2.bias'] + with self.assertRaises(RuntimeError): self.model.set_weights(params) - except: - return - assert False if __name__ == '__main__': diff --git a/parl/utils/tests/scheduler_test.py b/parl/utils/tests/scheduler_test.py index a95dfcbeeee6ba5e2a88786bacb5bcb48977b1ad..dc1d1e1bed1f2a5a51559517e670a9e8bcf662d1 100644 --- a/parl/utils/tests/scheduler_test.py +++ b/parl/utils/tests/scheduler_test.py @@ -53,20 +53,12 @@ class TestScheduler(unittest.TestCase): assert value == 0.3 def test_PiecewiseScheduler_with_empty(self): - try: + with self.assertRaises(AssertionError): scheduler = PiecewiseScheduler([]) - except AssertionError: - # expected - return - assert False def test_PiecewiseScheduler_with_incorrect_steps(self): - try: - scheduler = PiecewiseScheduler([(10, 0.1), (1, 0.2)]) - except AssertionError: - # expected - return - assert False + with self.assertRaises(AssertionError): + scheduler = PiecewiseScheduler([(10, 0.1), (1, 0.2)]) def test_LinearDecayScheduler(self): scheduler = LinearDecayScheduler(start_value=10, max_steps=10)
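The change to parl/core/torch/model.py is what makes the Learner-to-Actor weight sync in train.py work: get_weights() now returns a dict of numpy arrays, which is picklable and can therefore travel through the params_queue to a @parl.remote_class actor, and set_weights() rebuilds tensors and calls load_state_dict. A small round-trip sketch, using a hypothetical TwoLayerNet purely for illustration:

# Round-trip sketch for the numpy-based get_weights/set_weights (TwoLayerNet is
# a made-up toy model, not part of the repository).
import torch
import torch.nn as nn
import parl


class TwoLayerNet(parl.Model):
    def __init__(self):
        super(TwoLayerNet, self).__init__()
        self.fc1 = nn.Linear(4, 16)
        self.fc2 = nn.Linear(16, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


src, dst = TwoLayerNet(), TwoLayerNet()
weights = src.get_weights()   # {parameter name: numpy array}, safe to pickle and ship
dst.set_weights(weights)      # converted back to tensors and loaded via load_state_dict
assert all(
    torch.equal(a, b) for a, b in zip(src.parameters(), dst.parameters()))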
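A quick way to sanity-check the new A2C code path without a cluster is to run a single learn() step on random tensors. The sketch below is illustrative only and not part of the patch; it assumes it is executed from benchmark/torch/a2c/ with parl and torch installed, and it fills in obs_shape and act_dim the same way train.py does.

# Minimal local smoke test for the new A2C pieces (illustrative sketch).
# Run from benchmark/torch/a2c/ so the local benchmark modules import.
import torch

from parl.algorithms import A2C      # exported via parl/algorithms/torch/__init__.py
from atari_model import ActorCritic
from atari_agent import Agent
from a2c_config import config

config['obs_shape'] = (4, 84, 84)    # stacked Atari frames, NCHW
config['act_dim'] = 4                # BreakoutNoFrameskip-v4 has 4 discrete actions

agent = Agent(A2C(ActorCritic(config['act_dim']), config), config)

batch = 8
obs = torch.rand(batch, 4, 84, 84)
actions = torch.randint(0, config['act_dim'], (batch, )).float()
advantages = torch.randn(batch)
target_values = torch.randn(batch)

total_loss, pi_loss, vf_loss, entropy, lr, entropy_coeff = agent.learn(
    obs, actions, advantages, target_values)
print(total_loss.item(), pi_loss.item(), vf_loss.item(), entropy.item())

The full distributed run needs a PARL cluster that matches config['master_address'], typically something like "xparl start --port 8010 --cpu_num 5" followed by "python train.py"; the exact flags depend on the installed PARL version.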