From 3f4caf56364be1cebbdd2435ae5cef18b6a55644 Mon Sep 17 00:00:00 2001
From: Varuna Jayasiri
Date: Sat, 27 Mar 2021 11:34:45 +0530
Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20ppo=20configs?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 labml_nn/rl/ppo/experiment.py | 95 ++++++++++++++++++++++++++---------
 1 file changed, 71 insertions(+), 24 deletions(-)

diff --git a/labml_nn/rl/ppo/experiment.py b/labml_nn/rl/ppo/experiment.py
index 4f846089..74288167 100644
--- a/labml_nn/rl/ppo/experiment.py
+++ b/labml_nn/rl/ppo/experiment.py
@@ -19,6 +19,7 @@ from torch import optim
 from torch.distributions import Categorical
 
 from labml import monit, tracker, logger, experiment
+from labml.internal.configs.dynamic_hyperparam import FloatDynamicHyperParam
 from labml_helpers.module import Module
 from labml_nn.rl.game import Worker
 from labml_nn.rl.ppo import ClippedPPOLoss, ClippedValueFunctionLoss
@@ -89,24 +90,40 @@ class Trainer:
     ## Trainer
     """
 
-    def __init__(self):
+    def __init__(self, *,
+                 updates: int, epochs: int, n_workers: int, worker_steps: int, batches: int,
+                 value_loss_coef: FloatDynamicHyperParam,
+                 entropy_bonus_coef: FloatDynamicHyperParam,
+                 clip_range: FloatDynamicHyperParam,
+                 learning_rate: FloatDynamicHyperParam,
+                 ):
         # #### Configurations
 
         # number of updates
-        self.updates = 10000
+        self.updates = updates
         # number of epochs to train the model with sampled data
-        self.epochs = 4
+        self.epochs = epochs
         # number of worker processes
-        self.n_workers = 8
+        self.n_workers = n_workers
         # number of steps to run on each process for a single update
-        self.worker_steps = 128
+        self.worker_steps = worker_steps
         # number of mini batches
-        self.n_mini_batch = 4
+        self.batches = batches
         # total number of samples for a single update
         self.batch_size = self.n_workers * self.worker_steps
         # size of a mini batch
-        self.mini_batch_size = self.batch_size // self.n_mini_batch
-        assert (self.batch_size % self.n_mini_batch == 0)
+        self.mini_batch_size = self.batch_size // self.batches
+        assert (self.batch_size % self.batches == 0)
+
+        # Value loss coefficient
+        self.value_loss_coef = value_loss_coef
+        # Entropy bonus coefficient
+        self.entropy_bonus_coef = entropy_bonus_coef
+
+        # Clipping range
+        self.clip_range = clip_range
+        # Learning rate
+        self.learning_rate = learning_rate
 
         # #### Initialize
 
@@ -204,7 +221,7 @@ class Trainer:
 
         return samples_flat
 
-    def train(self, samples: Dict[str, torch.Tensor], learning_rate: float, clip_range: float):
+    def train(self, samples: Dict[str, torch.Tensor]):
         """
         ### Train the model based on samples
         """
@@ -228,12 +245,11 @@ class Trainer:
                     mini_batch[k] = v[mini_batch_indexes]
 
                 # train
-                loss = self._calc_loss(clip_range=clip_range,
-                                       samples=mini_batch)
+                loss = self._calc_loss(mini_batch)
 
                 # Set learning rate
                 for pg in self.optimizer.param_groups:
-                    pg['lr'] = learning_rate
+                    pg['lr'] = self.learning_rate()
                 # Zero out the previously calculated gradients
                 self.optimizer.zero_grad()
                 # Calculate gradients
@@ -248,7 +264,7 @@ class Trainer:
         """#### Normalize advantage function"""
         return (adv - adv.mean()) / (adv.std() + 1e-8)
 
-    def _calc_loss(self, samples: Dict[str, torch.Tensor], clip_range: float) -> torch.Tensor:
+    def _calc_loss(self, samples: Dict[str, torch.Tensor]) -> torch.Tensor:
         """
         ### Calculate total loss
         """
@@ -270,7 +286,7 @@ class Trainer:
         log_pi = pi.log_prob(samples['actions'])
 
         # Calculate policy loss
-        policy_loss = self.ppo_loss(log_pi, samples['log_pis'], sampled_normalized_advantage, clip_range)
+        policy_loss = self.ppo_loss(log_pi, samples['log_pis'], sampled_normalized_advantage, self.clip_range())
 
         # Calculate Entropy Bonus
         #
@@ -280,12 +296,14 @@ class Trainer:
         entropy_bonus = entropy_bonus.mean()
 
         # Calculate value function loss
-        value_loss = self.value_loss(value, samples['values'], sampled_return, clip_range)
+        value_loss = self.value_loss(value, samples['values'], sampled_return, self.clip_range())
 
         # $\mathcal{L}^{CLIP+VF+EB} (\theta) =
         #  \mathcal{L}^{CLIP} (\theta) +
         #  c_1 \mathcal{L}^{VF} (\theta) - c_2 \mathcal{L}^{EB}(\theta)$
-        loss = policy_loss + 0.5 * value_loss - 0.01 * entropy_bonus
+        loss = (policy_loss
+                + self.value_loss_coef() * value_loss
+                - self.entropy_bonus_coef() * entropy_bonus)
 
         # for monitoring
         approx_kl_divergence = .5 * ((samples['log_pis'] - log_pi) ** 2).mean()
@@ -309,17 +327,11 @@ class Trainer:
         tracker.set_queue('length', 100, True)
 
         for update in monit.loop(self.updates):
-            progress = update / self.updates
-
-            # decreasing `learning_rate` and `clip_range` $\epsilon$
-            learning_rate = 2.5e-4 * (1 - progress)
-            clip_range = 0.1 * (1 - progress)
-
             # sample with current policy
             samples = self.sample()
 
             # train the model
-            self.train(samples, learning_rate, clip_range)
+            self.train(samples)
 
             # Save tracked indicators.
             tracker.save()
@@ -339,8 +351,43 @@ class Trainer:
 def main():
     # Create the experiment
     experiment.create(name='ppo')
+    # Configurations
+    configs = {
+        # number of updates
+        'updates': 10000,
+        # number of epochs to train the model with sampled data
+        'epochs': 4,
+        # number of worker processes
+        'n_workers': 8,
+        # number of steps to run on each process for a single update
+        'worker_steps': 128,
+        # number of mini batches
+        'batches': 4,
+        # Value loss coefficient
+        'value_loss_coef': FloatDynamicHyperParam(0.5),
+        # Entropy bonus coefficient
+        'entropy_bonus_coef': FloatDynamicHyperParam(0.01),
+        # Clip range
+        'clip_range': FloatDynamicHyperParam(0.1),
+        # Learning rate
+        'learning_rate': FloatDynamicHyperParam(2.5e-4, (0, 1e-3)),
+    }
+
+    experiment.configs(configs)
+
     # Initialize the trainer
-    m = Trainer()
+    m = Trainer(
+        updates=configs['updates'],
+        epochs=configs['epochs'],
+        n_workers=configs['n_workers'],
+        worker_steps=configs['worker_steps'],
+        batches=configs['batches'],
+        value_loss_coef=configs['value_loss_coef'],
+        entropy_bonus_coef=configs['entropy_bonus_coef'],
+        clip_range=configs['clip_range'],
+        learning_rate=configs['learning_rate'],
+    )
+
+    # Run and monitor the experiment
     with experiment.start():
         m.run_training_loop()
 
-- 
GitLab