diff --git a/examples/A2C/README.md b/examples/A2C/README.md old mode 100644 new mode 100755 index 9e54b1ae068ce2447563d0a69b82dc7591facae5..195a2af72708d9699d2d5e0b3baf191d44598b4a --- a/examples/A2C/README.md +++ b/examples/A2C/README.md @@ -24,22 +24,23 @@ Mean episode reward in training process after 10 million sample steps. ### Distributed Training -#### Learner -```sh -python train.py -``` +First, we can start a local cluster with 5 CPUs: -#### Actors (Suggest: 5 actors in 5 CPUs) -```sh -for i in $(seq 1 5); do - python actor.py & -done; -wait +```bash +xparl start --port 8010 --cpu_num 5 ``` -You can change training settings (e.g. `env_name`, `server_ip`) in `a2c_config.py`. -Training result will be saved in `log_dir/train/result.csv`. +Note that if you have started a master before, you don't have to run the above +command. For more information about the cluster, please refer to our +[documentation](https://parl.readthedocs.io/en/latest/parallel_training/setup.html). + +Then we can start the distributed training by running `train.py`. + +```bash +python train.py +``` ### Reference ++ [Parl](https://parl.readthedocs.io/en/latest/parallel_training/setup.html) + [Ray](https://github.com/ray-project/ray) + [OpenAI Baselines: ACKTR & A2C](https://openai.com/blog/baselines-acktr-a2c/) diff --git a/examples/A2C/a2c_config.py b/examples/A2C/a2c_config.py old mode 100644 new mode 100755 index 0fe0fd1cc20b833fba33df7617404989956890e3..fac9a61c673e48b9ffc135bfb2859fdefad9f644 --- a/examples/A2C/a2c_config.py +++ b/examples/A2C/a2c_config.py @@ -13,9 +13,9 @@ # limitations under the License. config = { + #========== remote config ========== - 'server_ip': 'localhost', - 'server_port': 8037, + 'master_address': 'localhost:8010', #========== env config ========== 'env_name': 'PongNoFrameskip-v4', diff --git a/examples/A2C/actor.py b/examples/A2C/actor.py old mode 100644 new mode 100755 index e85fbc00cc4a0ca962dd4315525e32f99fc68c61..f9cef1e540b44f132ded0c4e891e14b5b2d26170 --- a/examples/A2C/actor.py +++ b/examples/A2C/actor.py @@ -15,8 +15,6 @@ import gym import numpy as np import parl -import six -import parl from atari_model import AtariModel from collections import defaultdict from atari_agent import AtariAgent @@ -31,7 +29,7 @@ class Actor(object): self.config = config self.envs = [] - for _ in six.moves.range(config['env_num']): + for _ in range(config['env_num']): env = gym.make(config['env_name']) env = wrap_deepmind(env, dim=config['env_dim'], obs_format='NCHW') self.envs.append(env) @@ -54,16 +52,16 @@ class Actor(object): sample_data = defaultdict(list) env_sample_data = {} - for env_id in six.moves.range(self.config['env_num']): + for env_id in range(self.config['env_num']): env_sample_data[env_id] = defaultdict(list) - for i in six.moves.range(self.config['sample_batch_steps']): + for i in range(self.config['sample_batch_steps']): actions_batch, values_batch = self.agent.sample( np.stack(self.obs_batch)) next_obs_batch, reward_batch, done_batch, info_batch = \ self.vector_env.step(actions_batch) - for env_id in six.moves.range(self.config['env_num']): + for env_id in range(self.config['env_num']): env_sample_data[env_id]['obs'].append(self.obs_batch[env_id]) env_sample_data[env_id]['actions'].append( actions_batch[env_id]) @@ -115,10 +113,3 @@ class Actor(object): def set_weights(self, params): self.agent.set_weights(params) - - -if __name__ == '__main__': - from a2c_config import config - - actor = Actor(config) - actor.as_remote(config['server_ip'], config['server_port']) diff
--git a/examples/A2C/atari_agent.py b/examples/A2C/atari_agent.py old mode 100644 new mode 100755 diff --git a/examples/A2C/atari_model.py b/examples/A2C/atari_model.py old mode 100644 new mode 100755 diff --git a/examples/A2C/learner.py b/examples/A2C/learner.py old mode 100644 new mode 100755 index 49a77ddfb4d3da5319417e6349f132c1304ee042..6624be44be9365e69dc28be492e38156272cd426 --- a/examples/A2C/learner.py +++ b/examples/A2C/learner.py @@ -23,14 +23,16 @@ import parl from atari_model import AtariModel from atari_agent import AtariAgent from collections import defaultdict -from parl import RemoteManager + from parl.env.atari_wrappers import wrap_deepmind -from parl.utils import logger, CSVLogger, get_gpu_count +from parl.utils import logger, get_gpu_count, tensorboard from parl.utils.scheduler import PiecewiseScheduler from parl.utils.time_stat import TimeStat from parl.utils.window_stat import WindowStat from parl.utils import machine_info +from actor import Actor + class Learner(object): def __init__(self, config): @@ -78,20 +80,17 @@ class Learner(object): self.sample_total_steps = 0 self.params_queues = [] - self.run_remote_manager() + self.create_actors() - self.csv_logger = CSVLogger( - os.path.join(logger.get_dir(), 'result.csv')) - - def run_remote_manager(self): - """ Accept connection of new remote actor and start sampling of the remote actor. + def create_actors(self): + """ Connect to the cluster and start sampling of the remote actor. """ - remote_manager = RemoteManager(port=self.config['server_port']) + parl.connect(self.config['master_address']) + logger.info('Waiting for {} remote actors to connect.'.format( self.config['actor_num'])) for i in six.moves.range(self.config['actor_num']): - remote_actor = remote_manager.get_remote() params_queue = queue.Queue() self.params_queues.append(params_queue) @@ -99,17 +98,18 @@ class Learner(object): logger.info('Remote actor count: {}'.format(self.remote_count)) remote_thread = threading.Thread( - target=self.run_remote_sample, - args=(remote_actor, params_queue)) + target=self.run_remote_sample, args=(params_queue, )) remote_thread.setDaemon(True) remote_thread.start() logger.info('All remote actors are ready, begin to learn.') self.start_time = time.time() - def run_remote_sample(self, remote_actor, params_queue): + def run_remote_sample(self, params_queue): """ Sample data from remote actor and update parameters of remote actor. """ + remote_actor = Actor(self.config) + cnt = 0 while True: latest_params = params_queue.get() @@ -128,7 +128,7 @@ class Learner(object): """ 1. kick off all actors to synchronize parameters and sample data; 2. collect sample data of all actors; - 3. update parameters. + 3. update parameters. 
""" latest_params = self.agent.get_weights() @@ -208,11 +208,11 @@ class Learner(object): 'entropy_coeff': self.entropy_coeff, } + for key, value in metric.items(): + if value is not None: + tensorboard.add_scalar(key, value, self.sample_total_steps) + logger.info(metric) - self.csv_logger.log_dict(metric) def should_stop(self): return self.sample_total_steps >= self.config['max_sample_steps'] - - def close(self): - self.csv_logger.close() diff --git a/examples/A2C/run_actors.sh b/examples/A2C/run_actors.sh deleted file mode 100644 index 13518376cd8a1c041fde6d58270e5325ddf946e0..0000000000000000000000000000000000000000 --- a/examples/A2C/run_actors.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -export CUDA_VISIBLE_DEVICES="" - -for i in $(seq 1 5); do - python actor.py & -done; -wait diff --git a/examples/A2C/train.py b/examples/A2C/train.py old mode 100644 new mode 100755 index 8fd25fe35656f6c15cefb6efb2e5fed39e5833ac..e100e46677a5b5ca4d0b905587e800754b62d613 --- a/examples/A2C/train.py +++ b/examples/A2C/train.py @@ -14,22 +14,18 @@ import time from learner import Learner +import parl def main(config): learner = Learner(config) assert config['log_metrics_interval_s'] > 0 - try: - while not learner.should_stop(): - start = time.time() - while time.time() - start < config['log_metrics_interval_s']: - learner.step() - learner.log_metrics() - learner.close() - - except KeyboardInterrupt: - learner.close() + while not learner.should_stop(): + start = time.time() + while time.time() - start < config['log_metrics_interval_s']: + learner.step() + learner.log_metrics() if __name__ == '__main__': diff --git a/examples/GA3C/README.md b/examples/GA3C/README.md old mode 100644 new mode 100755 index 198cfa0087e567a7cfc9b444fb4a2db0b3e02abd..2a6c1cbfd7081c85a94afe423d62ef85507dcee8 --- a/examples/GA3C/README.md +++ b/examples/GA3C/README.md @@ -21,26 +21,28 @@ Results with one learner (in a P40 GPU) and 24 simulators (in 12 CPU) in 10 mill + gym + atari-py - ### Distributed Training -#### Learner -```sh -python train.py -``` +At first, We can start a local cluster with 24 CPUs: -#### Simulators (Suggest: 24 simulators in 12+ CPUs) -```sh -for i in $(seq 1 24); do - python simulator.py & -done; -wait +```bash +xparl start --port 8010 --cpu_num 24 ``` -You can change training settings (e.g. `env_name`, `server_ip`) in `ga3c_config.py`. -Training result will be saved in `log_dir/train/result.csv`. +Note that if you have started a master before, you don't have to run the above +command. For more information about the cluster, please refer to our +[documentation](https://parl.readthedocs.io/en/latest/parallel_training/setup.html) + +Then we can start the distributed training by running `train.py`. + +```bash +python train.py +``` -[Tips] The performance can be influenced dramatically in a slower computational environment, especially when training with low-speed CPUs. It may be caused by the policy-lag problem. +[Tips] The performance can be influenced dramatically in a slower computational +environment, especially when training with low-speed CPUs. It may be caused by +the policy-lag problem. 
### Reference ++ [Parl](https://parl.readthedocs.io/en/latest/parallel_training/setup.html) + [tensorpack](https://github.com/tensorpack/tensorpack) diff --git a/examples/GA3C/simulator.py b/examples/GA3C/actor.py old mode 100644 new mode 100755 similarity index 88% rename from examples/GA3C/simulator.py rename to examples/GA3C/actor.py index 2019460e05272026f2e4a41c4729dbb38cfd2c68..87e8e819993ae3002b926697f447059be2564432 --- a/examples/GA3C/simulator.py +++ b/examples/GA3C/actor.py @@ -15,13 +15,12 @@ import gym import numpy as np import parl -import six from parl.env.atari_wrappers import wrap_deepmind, MonitorEnv, get_wrapper_by_cls from collections import defaultdict @parl.remote_class -class Simulator(object): +class Actor(object): def __init__(self, config): self.config = config @@ -45,10 +44,3 @@ class Simulator(object): metrics['episode_rewards'].append(episode_rewards) metrics['episode_steps'].append(episode_steps) return metrics - - -if __name__ == '__main__': - from ga3c_config import config - - simulator = Simulator(config) - simulator.as_remote(config['server_ip'], config['server_port']) diff --git a/examples/GA3C/atari_agent.py b/examples/GA3C/atari_agent.py old mode 100644 new mode 100755 diff --git a/examples/GA3C/atari_model.py b/examples/GA3C/atari_model.py deleted file mode 120000 index 56a659a3d3bd87171b8393135d8929a831abce16..0000000000000000000000000000000000000000 --- a/examples/GA3C/atari_model.py +++ /dev/null @@ -1 +0,0 @@ -../A2C/atari_model.py \ No newline at end of file diff --git a/examples/GA3C/atari_model.py b/examples/GA3C/atari_model.py new file mode 100755 index 0000000000000000000000000000000000000000..5cc7bb4f3578e8f6c7c560be010a3b2d223a9e70 --- /dev/null +++ b/examples/GA3C/atari_model.py @@ -0,0 +1,96 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import parl +import paddle.fluid as fluid +from parl import layers + + +class AtariModel(parl.Model): + def __init__(self, act_dim): + + self.conv1 = layers.conv2d( + num_filters=32, filter_size=8, stride=4, padding=1, act='relu') + self.conv2 = layers.conv2d( + num_filters=64, filter_size=4, stride=2, padding=2, act='relu') + self.conv3 = layers.conv2d( + num_filters=64, filter_size=3, stride=1, padding=0, act='relu') + + self.fc = layers.fc(size=512, act='relu') + + self.policy_fc = layers.fc(size=act_dim) + self.value_fc = layers.fc(size=1) + + def policy(self, obs): + """ + Args: + obs: A float32 tensor of shape [B, C, H, W] + + Returns: + policy_logits: B * ACT_DIM + """ + obs = obs / 255.0 + conv1 = self.conv1(obs) + conv2 = self.conv2(conv1) + conv3 = self.conv3(conv2) + + flatten = layers.flatten(conv3, axis=1) + fc_output = self.fc(flatten) + + policy_logits = self.policy_fc(fc_output) + return policy_logits + + def value(self, obs): + """ + Args: + obs: A float32 tensor of shape [B, C, H, W] + + Returns: + values: B + """ + obs = obs / 255.0 + conv1 = self.conv1(obs) + conv2 = self.conv2(conv1) + conv3 = self.conv3(conv2) + + flatten = layers.flatten(conv3, axis=1) + fc_output = self.fc(flatten) + + values = self.value_fc(fc_output) + values = layers.squeeze(values, axes=[1]) + return values + + def policy_and_value(self, obs): + """ + Args: + obs: A float32 tensor of shape [B, C, H, W] + + Returns: + policy_logits: B * ACT_DIM + values: B + """ + obs = obs / 255.0 + conv1 = self.conv1(obs) + conv2 = self.conv2(conv1) + conv3 = self.conv3(conv2) + + flatten = layers.flatten(conv3, axis=1) + fc_output = self.fc(flatten) + + policy_logits = self.policy_fc(fc_output) + + values = self.value_fc(fc_output) + values = layers.squeeze(values, axes=[1]) + + return policy_logits, values diff --git a/examples/GA3C/ga3c_config.py b/examples/GA3C/ga3c_config.py old mode 100644 new mode 100755 index f0cfa759009cbb345d73eb373735b8430191d020..0fb7299ecd170b9bc31218faccc79c9847eed99a --- a/examples/GA3C/ga3c_config.py +++ b/examples/GA3C/ga3c_config.py @@ -14,14 +14,14 @@ config = { #========== remote config ========== - 'server_ip': 'localhost', - 'server_port': 8037, + 'master_address': 'localhost:8010', #========== env config ========== 'env_name': 'PongNoFrameskip-v4', 'env_dim': 42, #========== learner config ========== + 'actor_num': 24, 'train_batch_size': 128, 'max_predict_batch_size': 16, 'predict_thread_num': 2, diff --git a/examples/GA3C/learner.py b/examples/GA3C/learner.py old mode 100644 new mode 100755 index 50dc77d08e71456998b05a89b6ee3e422f849bc8..b875b7f73f254ac1e7a7e50ea94fc85f99c3c46a --- a/examples/GA3C/learner.py +++ b/examples/GA3C/learner.py @@ -23,15 +23,16 @@ import parl from atari_model import AtariModel from atari_agent import AtariAgent from collections import defaultdict -from parl import RemoteManager from parl.env.atari_wrappers import wrap_deepmind -from parl.utils import logger, CSVLogger, get_gpu_count +from parl.utils import logger, get_gpu_count, tensorboard from parl.utils.scheduler import PiecewiseScheduler from parl.utils.time_stat import TimeStat from parl.utils.window_stat import WindowStat from parl.utils.rl_utils import calc_gae from parl.utils import machine_info +from actor import Actor + class Learner(object): def __init__(self, config): @@ -104,13 +105,10 @@ class Learner(object): self.sample_total_steps = 0 self.remote_manager_thread = threading.Thread( - target=self.run_remote_manager) + target=self.create_actors) 
self.remote_manager_thread.setDaemon(True) self.remote_manager_thread.start() - self.csv_logger = CSVLogger( - os.path.join(logger.get_dir(), 'result.csv')) - def learn_data_provider(self): """ Data generator for fluid.layers.py_reader """ @@ -181,17 +179,18 @@ class Learner(object): self.vf_loss_stat.add(vf_loss) self.entropy_stat.add(entropy) - def run_remote_manager(self): - """ Accept connection of new remote simulator and start simulation. + def create_actors(self): + """ Connect to the cluster and start sampling of the remote actor. """ - remote_manager = RemoteManager(port=self.config['server_port']) - logger.info("Waiting for the remote simulator's connection.") + parl.connect(self.config['master_address']) + + logger.info('Waiting for {} remote actors to connect.'.format( + self.config['actor_num'])) ident = 0 self.predict_output_queues = [] - while True: - remote_simulator = remote_manager.get_remote() + for i in six.moves.range(self.config['actor_num']): self.remote_count += 1 logger.info('Remote simulator count: {}'.format(self.remote_count)) @@ -202,27 +201,23 @@ class Learner(object): self.predict_output_queues.append(q) remote_thread = threading.Thread( - target=self.run_remote_sample, - args=( - remote_simulator, - ident, - )) + target=self.run_remote_sample, args=(ident, )) remote_thread.setDaemon(True) remote_thread.start() - ident += 1 - def run_remote_sample(self, remote_simulator, ident): + def run_remote_sample(self, ident): """ Interacts with remote simulator. """ + remote_actor = Actor(self.config) mem = defaultdict(list) - obs = remote_simulator.reset() + obs = remote_actor.reset() while True: self.predict_input_queue.put((ident, obs)) action, value = self.predict_output_queues[ident].get() - next_obs, reward, done = remote_simulator.step(action) + next_obs, reward, done = remote_actor.step(action) mem['obs'].append(obs) mem['actions'].append(action) @@ -245,7 +240,7 @@ class Learner(object): mem = defaultdict(list) - next_obs = remote_simulator.reset() + next_obs = remote_actor.reset() elif len(mem['obs']) == self.config['t_max'] + 1: next_value = mem['values'][-1] @@ -267,7 +262,7 @@ class Learner(object): obs = next_obs if done: - metrics = remote_simulator.get_metrics() + metrics = remote_actor.get_metrics() if metrics: self.remote_metrics_queue.put(metrics) @@ -319,8 +314,8 @@ class Learner(object): 'entropy_coeff': self.entropy_coeff, } - logger.info(metric) - self.csv_logger.log_dict(metric) + for key, value in metric.items(): + if value is not None: + tensorboard.add_scalar(key, value, self.sample_total_steps) - def close(self): - self.csv_logger.close() + logger.info(metric) diff --git a/examples/GA3C/run_simulators.sh b/examples/GA3C/run_simulators.sh deleted file mode 100644 index 9ad2825f76ee0c2ac0ac3116d2c2b661b1fd93cf..0000000000000000000000000000000000000000 --- a/examples/GA3C/run_simulators.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -export CUDA_VISIBLE_DEVICES="" - -for i in $(seq 1 24); do - python simulator.py & -done; -wait diff --git a/examples/GA3C/train.py b/examples/GA3C/train.py old mode 100644 new mode 100755 index a84f377b73335b6a044c7fd00590f38574ab1fd6..0d351ee55b10cdc884afc8f2b8607deb0ddae222 --- a/examples/GA3C/train.py +++ b/examples/GA3C/train.py @@ -20,14 +20,10 @@ def main(config): learner = Learner(config) assert config['log_metrics_interval_s'] > 0 - try: - while True: - time.sleep(config['log_metrics_interval_s']) + while True: + time.sleep(config['log_metrics_interval_s']) - learner.log_metrics() - - except 
KeyboardInterrupt: - learner.close() + learner.log_metrics() if __name__ == '__main__': diff --git a/examples/IMPALA/README.md b/examples/IMPALA/README.md old mode 100644 new mode 100755 index 421131cbe2d1e5571d98a5bfd79ad457045fe24e..0fcf6bd85fa690969191a025724c2797c898f0a7 --- a/examples/IMPALA/README.md +++ b/examples/IMPALA/README.md @@ -28,22 +28,23 @@ Result with one learner (in a P40 GPU) and 32 actors (in 32 CPUs). ### Distributed Training: -#### Learner -```sh -python train.py -``` +First, we can start a local cluster with 32 CPUs: -#### Actors (Suggest: 32+ actors in 32+ CPUs) -```sh -for i in $(seq 1 32); do - python actor.py & -done; -wait +```bash +xparl start --port 8010 --cpu_num 32 ``` -You can change training settings (e.g. `env_name`, `server_ip`) in `impala_config.py`. -Training result will be saved in `log_dir/train/result.csv`. +Note that if you have started a master before, you don't have to run the above +command. For more information about the cluster, please refer to our +[documentation](https://parl.readthedocs.io/en/latest/parallel_training/setup.html). + +Then we can start the distributed training by running `train.py`. + +```bash +python train.py +``` ### Reference ++ [Parl Cluster Setup](https://parl.readthedocs.io/en/latest/parallel_training/setup.html) + [deepmind/scalable_agent](https://github.com/deepmind/scalable_agent) + [Ray](https://github.com/ray-project/ray) diff --git a/examples/IMPALA/actor.py b/examples/IMPALA/actor.py old mode 100644 new mode 100755 index 89a7c3bf5422ef537623eb88d96d7297d266c69c..92c04eb4eafe35dc2810c47df56dba00773ac463 --- a/examples/IMPALA/actor.py +++ b/examples/IMPALA/actor.py @@ -30,7 +30,7 @@ class Actor(object): self.config = config self.envs = [] - for _ in six.moves.range(config['env_num']): + for _ in range(config['env_num']): env = gym.make(config['env_name']) env = wrap_deepmind(env, dim=config['env_dim'], obs_format='NCHW') self.envs.append(env) @@ -53,16 +53,16 @@ class Actor(object): def sample(self): env_sample_data = {} - for env_id in six.moves.range(self.config['env_num']): + for env_id in range(self.config['env_num']): env_sample_data[env_id] = defaultdict(list) - for i in six.moves.range(self.config['sample_batch_steps']): + for i in range(self.config['sample_batch_steps']): actions, behaviour_logits = self.agent.sample( np.stack(self.obs_batch)) next_obs_batch, reward_batch, done_batch, info_batch = \ self.vector_env.step(actions) - for env_id in six.moves.range(self.config['env_num']): + for env_id in range(self.config['env_num']): env_sample_data[env_id]['obs'].append(self.obs_batch[env_id]) env_sample_data[env_id]['actions'].append(actions[env_id]) env_sample_data[env_id]['behaviour_logits'].append( behaviour_logits[env_id]) @@ -74,7 +74,7 @@ class Actor(object): # Merge data of envs sample_data = defaultdict(list) - for env_id in six.moves.range(self.config['env_num']): + for env_id in range(self.config['env_num']): for data_name in [ 'obs', 'actions', 'behaviour_logits', 'rewards', 'dones' ]: @@ -100,10 +100,3 @@ class Actor(object): def set_weights(self, weights): self.agent.set_weights(weights) - - -if __name__ == '__main__': - from impala_config import config - - actor = Actor(config) - actor.as_remote(config['server_ip'], config['server_port']) diff --git a/examples/IMPALA/atari_agent.py b/examples/IMPALA/atari_agent.py old mode 100644 new mode 100755 index 1a493c53e2189ffd57b13c63afad680b95cf872f..98d4a4c4fd3ea611f60f2d8da850265025541b4b --- a/examples/IMPALA/atari_agent.py +++ b/examples/IMPALA/atari_agent.py @@ -76,7
+76,8 @@ class AtariAgent(parl.Agent): vtrace_loss.total_loss, vtrace_loss.pi_loss, vtrace_loss.vf_loss, vtrace_loss.entropy, kl ] - self.learn_program = parl.compile(self.learn_program, total_loss) + self.learn_program = parl.compile(self.learn_program, + vtrace_loss.total_loss) def sample(self, obs_np): """ diff --git a/examples/IMPALA/atari_model.py b/examples/IMPALA/atari_model.py old mode 100644 new mode 100755 diff --git a/examples/IMPALA/impala_config.py b/examples/IMPALA/impala_config.py old mode 100644 new mode 100755 index dcfc2f5725f0a02b895e1426783274b2b944aefe..b28be1ea4451a6361e7eb466ee97468f35884a12 --- a/examples/IMPALA/impala_config.py +++ b/examples/IMPALA/impala_config.py @@ -16,14 +16,14 @@ config = { 'experiment_name': 'Pong', #========== remote config ========== - 'server_ip': 'localhost', - 'server_port': 8037, + 'master_address': 'localhost:8010', #========== env config ========== 'env_name': 'PongNoFrameskip-v4', 'env_dim': 42, #========== actor config ========== + 'actor_num': 32, 'env_num': 5, 'sample_batch_steps': 50, diff --git a/examples/IMPALA/learner.py b/examples/IMPALA/learner.py old mode 100644 new mode 100755 index 573f76ae9053ba6212e171c5026db06f2ecc3631..ab1957313e3ef93ad7e0fd06d9d24a3e3d88f0eb --- a/examples/IMPALA/learner.py +++ b/examples/IMPALA/learner.py @@ -21,13 +21,14 @@ import threading import parl from atari_model import AtariModel from atari_agent import AtariAgent -from parl import RemoteManager from parl.env.atari_wrappers import wrap_deepmind -from parl.utils import logger, CSVLogger +from parl.utils import logger, tensorboard from parl.utils.scheduler import PiecewiseScheduler from parl.utils.time_stat import TimeStat from parl.utils.window_stat import WindowStat +from actor import Actor + class Learner(object): def __init__(self, config): @@ -85,13 +86,10 @@ class Learner(object): self.sample_total_steps = 0 self.remote_manager_thread = threading.Thread( - target=self.run_remote_manager) + target=self.create_actors) self.remote_manager_thread.setDaemon(True) self.remote_manager_thread.start() - self.csv_logger = CSVLogger( - os.path.join(logger.get_dir(), 'result.csv')) - def learn_data_provider(self): """ Data generator for fluid.layers.py_reader """ @@ -139,26 +137,29 @@ class Learner(object): self.entropy_stat.add(entropy) self.kl_stat.add(kl) - def run_remote_manager(self): - """ Accept connection of new remote actor and start sampling of the remote actor. + def create_actors(self): + """ Connect to the cluster and start sampling of the remote actor. """ - remote_manager = RemoteManager(port=self.config['server_port']) - logger.info('Waiting for remote actors connecting.') - while True: - remote_actor = remote_manager.get_remote() + parl.connect(self.config['master_address']) + + logger.info('Waiting for {} remote actors to connect.'.format( + self.config['actor_num'])) + + for i in range(self.config['actor_num']): self.remote_count += 1 logger.info('Remote actor count: {}'.format(self.remote_count)) if self.start_time is None: self.start_time = time.time() - remote_thread = threading.Thread( - target=self.run_remote_sample, args=(remote_actor, )) + remote_thread = threading.Thread(target=self.run_remote_sample) remote_thread.setDaemon(True) remote_thread.start() - def run_remote_sample(self, remote_actor): + def run_remote_sample(self): """ Sample data from remote actor and update parameters of remote actor. 
""" + remote_actor = Actor(self.config) + cnt = 0 remote_actor.set_weights(self.cache_params) while True: @@ -237,8 +238,8 @@ class Learner(object): 'entropy_coeff': self.entropy_coeff, } - logger.info(metric) - self.csv_logger.log_dict(metric) + for key, value in metric.items(): + if value is not None: + tensorboard.add_scalar(key, value, self.sample_total_steps) - def close(self): - self.csv_logger.close() + logger.info(metric) diff --git a/examples/IMPALA/run_actors.sh b/examples/IMPALA/run_actors.sh deleted file mode 100644 index 99a0acb8678b3a7e44eae219916b1b62472d6999..0000000000000000000000000000000000000000 --- a/examples/IMPALA/run_actors.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash - -export CUDA_VISIBLE_DEVICES="" - -for i in $(seq 1 32); do - python actor.py & -done; -wait diff --git a/examples/IMPALA/train.py b/examples/IMPALA/train.py old mode 100644 new mode 100755 index bf38bb34281404975a620b04171cda53f7eaa224..a30b6f2c0383beea264b296c148b0ad3a1c541c3 --- a/examples/IMPALA/train.py +++ b/examples/IMPALA/train.py @@ -20,14 +20,10 @@ def main(config): learner = Learner(config) assert config['log_metrics_interval_s'] > 0 - try: - while True: - time.sleep(config['log_metrics_interval_s']) + while True: + time.sleep(config['log_metrics_interval_s']) - learner.log_metrics() - - except KeyboardInterrupt: - learner.close() + learner.log_metrics() if __name__ == '__main__': diff --git a/parl/__init__.py b/parl/__init__.py index cd86614906abc8e2f5921661dc7fc832c51d93b7..c23f0b91eea4378c0006687ab4b9ca8cef8fa92d 100644 --- a/parl/__init__.py +++ b/parl/__init__.py @@ -16,19 +16,13 @@ __version__ = "1.1.1" """ generates new PARL python API """ +import os -# trick to solve importing error from tensorboardX import SummaryWriter - from parl.utils.utils import _HAS_FLUID - if _HAS_FLUID: from parl.core.fluid import * from parl.core.fluid.plutils.compiler import compile -else: - print( - "WARNING:PARL: Failed to import paddle. Only APIs for parallelization are available." - ) -from parl.remote import remote_class, RemoteManager +from parl.remote import remote_class, connect from parl import algorithms diff --git a/parl/remote/__init__.py b/parl/remote/__init__.py index bfec0c3bffe1802d4821ac6fa87ae945113fa557..208e93be8e1f8722044993d63670d64fdebfe893 100644 --- a/parl/remote/__init__.py +++ b/parl/remote/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from parl.remote.master import * +from parl.remote.worker import * +from parl.remote.client import * from parl.remote.exceptions import * from parl.remote.remote_decorator import * -from parl.remote.remote_manager import * -from parl.remote.remote_object import * diff --git a/parl/remote/client.py b/parl/remote/client.py new file mode 100644 index 0000000000000000000000000000000000000000..fb28a2dbee57f78e6de4cf973979b6df83a92cdd --- /dev/null +++ b/parl/remote/client.py @@ -0,0 +1,203 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import cloudpickle +import os +import threading +import zmq +from parl.utils import to_str, to_byte, get_ip_address, logger +from parl.remote import remote_constants + + +class Client(object): + """Base class for the remote client. + + For each training task, there is a global client in the cluster which + submits jobs to the master node. Different `@parl.remote_class` objects + connect to the same global client in a training task. + + Attributes: + submit_job_socket (zmq.Context.socket): A socket which submits jobs to + the master node. + pyfiles (bytes): A serialized dictionary containing the code of python + files in the local working directory. + + """ + + def __init__(self, master_address): + """ + Args: + master_address (str): IP address of the master node. + """ + self.ctx = zmq.Context() + self.lock = threading.Lock() + self.heartbeat_socket_initialized = threading.Event() + self.master_is_alive = True + self.client_is_alive = True + + self._create_sockets(master_address) + self.pyfiles = self.read_local_files() + + def read_local_files(self): + """Read local python code and store them in a dictionary, which will + then be sent to the job. + + Returns: + A cloudpickled dictionary containing the python code in current + working directory. + """ + pyfiles = dict() + for file in os.listdir('./'): + if file.endswith('.py'): + with open(file, 'rb') as code_file: + code = code_file.read() + pyfiles[file] = code + return cloudpickle.dumps(pyfiles) + + def _create_sockets(self, master_address): + """ Each client has one socket when it starts: + + (1) submit_job_socket: submits jobs to the master node.
+ """ + + # submit_job_socket: submits job to master + self.submit_job_socket = self.ctx.socket(zmq.REQ) + self.submit_job_socket.linger = 0 + self.submit_job_socket.setsockopt( + zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000) + self.submit_job_socket.connect("tcp://{}".format(master_address)) + + thread = threading.Thread(target=self._reply_heartbeat, daemon=True) + thread.start() + self.heartbeat_socket_initialized.wait() + + # check if the master is connected properly + try: + self.submit_job_socket.send_multipart([ + remote_constants.CLIENT_CONNECT_TAG, + to_byte(self.heartbeat_master_address) + ]) + _ = self.submit_job_socket.recv_multipart() + except zmq.error.Again as e: + logger.warning("[Client] Can not connect to the master, please " + "check if master is started and ensure the input " + "address {} is correct.".format(master_address)) + self.master_is_alive = False + raise Exception("Client can not connect to the master, please " + "check if master is started and ensure the input " + "address {} is correct.".format(master_address)) + + def _reply_heartbeat(self): + """Reply heartbeat signals to the Master node.""" + + socket = self.ctx.socket(zmq.REP) + socket.linger = 0 + socket.setsockopt(zmq.RCVTIMEO, + remote_constants.HEARTBEAT_RCVTIMEO_S * 1000) + heartbeat_master_port =\ + socket.bind_to_random_port(addr="tcp://*") + self.heartbeat_master_address = "{}:{}".format(get_ip_address(), + heartbeat_master_port) + self.heartbeat_socket_initialized.set() + while self.client_is_alive and self.master_is_alive: + try: + message = socket.recv_multipart() + socket.send_multipart([remote_constants.HEARTBEAT_TAG]) + + except zmq.error.Again as e: + logger.warning("[Client] Cannot connect to the master." + "Please check if it is still alive.") + self.master_is_alive = False + socket.close(0) + logger.warning("Client exit replying heartbeat for master.") + + def submit_job(self): + """Send a job to the Master node. + + When a `@parl.remote_class` object is created, the global client + sends a job to the master node. Then the master node will allocate + a vacant job from its job pool to the remote object. + + Returns: + IP address of the job. + """ + if self.master_is_alive: + + # A lock to prevent multiple actor submit job at the same time. + self.lock.acquire() + self.submit_job_socket.send_multipart([ + remote_constants.CLIENT_SUBMIT_TAG, + to_byte(self.heartbeat_master_address) + ]) + message = self.submit_job_socket.recv_multipart() + self.lock.release() + + tag = message[0] + + if tag == remote_constants.NORMAL_TAG: + job_address = to_str(message[1]) + + # no vacant CPU resources, can not submit a new job + elif tag == remote_constants.CPU_TAG: + job_address = None + else: + raise NotImplementedError + else: + raise Exception("Client can not submit job to the master, " + "please check if master is connected.") + return job_address + + +GLOBAL_CLIENT = None + + +def connect(master_address): + """Create a global client which connects to the master node. + + .. code-block:: python + + parl.connect(master_address='localhost:1234') + + Args: + master_address (str): The address of the Master node to connect to. + + Raises: + Exception: An exception is raised if the master node is not started. + """ + + assert len(master_address.split(":")) == 2, "please input address in " +\ + "{ip}:{port} format" + global GLOBAL_CLIENT + if GLOBAL_CLIENT is None: + GLOBAL_CLIENT = Client(master_address) + + +def get_global_client(): + """Get the global client. + + Returns: + The global client. 
+ """ + global GLOBAL_CLIENT + assert GLOBAL_CLIENT is not None, "Cannot get the client to submit the" +\ + " job, have you connected to the cluster by calling " +\ + "parl.connect(master_ip, master_port)?" + return GLOBAL_CLIENT + + +def disconnect(): + """Disconnect the global client from the master node.""" + global GLOBAL_CLIENT + GLOBAL_CLIENT.client_is_alive = False + GLOBAL_CLIENT = None diff --git a/parl/remote/exceptions.py b/parl/remote/exceptions.py index 60b404836bd7014862abb03ade791bd05934c1ca..33d8cd85525bb1aace8f23fbb09a945b2553ae39 100644 --- a/parl/remote/exceptions.py +++ b/parl/remote/exceptions.py @@ -13,14 +13,26 @@ # limitations under the License. +class ResourceError(Exception): + """ + No available cpu resources error. + """ + + def __init__(self, error_info): + self.error_info = error_info + + def __str__(self): + return self.error_info + + class RemoteError(Exception): """ Super class of exceptions in remote module. """ def __init__(self, func_name, error_info): - self.error_info = "[PARL remote error when calling function `{}`]:\n{}".format( - func_name, error_info) + self.error_info = "[PARL remote error when calling " +\ + "function `{}`]:\n{}".format(func_name, error_info) def __str__(self): return self.error_info @@ -52,7 +64,7 @@ class RemoteDeserializeError(RemoteError): class RemoteAttributeError(RemoteError): """ - Attribute error from remote + Attribute error from remote """ def __init__(self, func_name, error_info): diff --git a/parl/remote/job.py b/parl/remote/job.py new file mode 100644 index 0000000000000000000000000000000000000000..1b0589338e434be5d714b41467111e9c52acf843 --- /dev/null +++ b/parl/remote/job.py @@ -0,0 +1,247 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import cloudpickle +import pickle +import sys +import tempfile +import threading +import time +import traceback +import zmq +from parl.utils import to_str, to_byte, get_ip_address, logger +from parl.utils.communication import loads_argument, loads_return,\ + dumps_argument, dumps_return +from parl.remote import remote_constants +from parl.utils.exceptions import SerializeError, DeserializeError + +import os +os.environ['CUDA_VISIBLE_DEVICES'] = '' + + +class Job(object): + """Base class for the job. + + After establishing connection with the remote object, the job will + create a remote class instance locally and enter an infinite loop, + waiting for commands from the remote object. + """ + + def __init__(self, worker_address): + self.job_is_alive = True + self.heartbeat_socket_initialized = threading.Event() + self.worker_address = worker_address + self._create_sockets() + + def _create_sockets(self): + """Create two sockets for each job. + + (1) reply_socket: receives the command(i.e, the function name and + args) from the actual class instance, and returns the result of + the function. + (2) job_socket: sends job_address and heartbeat_address to worker. 
+ + """ + + self.ctx = zmq.Context() + + # reply_socket: receives class, parameters and call function from + # @remote.class and send computed results to the @remote.class. + self.reply_socket = self.ctx.socket(zmq.REP) + self.reply_socket.linger = 0 + + job_port = self.reply_socket.bind_to_random_port(addr="tcp://*") + self.job_ip = get_ip_address() + self.job_address = "{}:{}".format(self.job_ip, job_port) + + reply_thread = threading.Thread( + target=self._reply_heartbeat, + args=("worker {}".format(self.worker_address), ), + daemon=True) + reply_thread.start() + self.heartbeat_socket_initialized.wait() + # job_socket: sends job_address and heartbeat_address to worker + self.job_socket = self.ctx.socket(zmq.REQ) + self.job_socket.connect("tcp://{}".format(self.worker_address)) + self.job_socket.send_multipart([ + remote_constants.NORMAL_TAG, + to_byte(self.job_address), + to_byte(self.heartbeat_worker_address) + ]) + _ = self.job_socket.recv_multipart() + + def _reply_heartbeat(self, target): + """reply heartbeat signals to the target""" + + socket = self.ctx.socket(zmq.REP) + socket.setsockopt(zmq.RCVTIMEO, + remote_constants.HEARTBEAT_RCVTIMEO_S * 1000) + socket.linger = 0 + heartbeat_worker_port = socket.bind_to_random_port(addr="tcp://*") + self.heartbeat_worker_address = "{}:{}".format(self.job_ip, + heartbeat_worker_port) + self.heartbeat_socket_initialized.set() + # a flag to decide when to exit heartbeat loop + self.worker_is_alive = True + while self.worker_is_alive and self.job_is_alive: + try: + message = socket.recv_multipart() + socket.send_multipart([remote_constants.HEARTBEAT_TAG]) + + except zmq.error.Again as e: + logger.warning("[Job] Cannot connect to {}. ".format(target) + + "Job will quit.") + self.worker_is_alive = False + self.job_is_alive = False + + def wait_for_files(self): + """Wait for python files from remote object. + + When a remote object receives the allocated job address, it will send + the python files to the job. Later, the job will save these files to a + temporary directory and add the temporary diretory to Python's working + directory. + + Returns: + A temporary directory containing the python files. + """ + + while True: + message = self.reply_socket.recv_multipart() + tag = message[0] + if tag == remote_constants.SEND_FILE_TAG: + pyfiles = pickle.loads(message[1]) + envdir = tempfile.mkdtemp() + for file in pyfiles: + code = pyfiles[file] + file = os.path.join(envdir, file) + with open(file, 'wb') as code_file: + code_file.write(code) + self.reply_socket.send_multipart([remote_constants.NORMAL_TAG]) + return envdir + else: + logger.warning(message) + raise NotImplementedError + + def wait_for_connection(self): + """Wait for connection from the remote object. + + The remote object will send its class information and initialization + arguments to the job, these parameters are then used to create a + local instance in the job process. + + Returns: + A local instance of the remote class object. + """ + + while True: + message = self.reply_socket.recv_multipart() + tag = message[0] + if tag == remote_constants.INIT_OBJECT_TAG: + cls = cloudpickle.loads(message[1]) + args, kwargs = cloudpickle.loads(message[2]) + obj = cls(*args, **kwargs) + self.reply_socket.send_multipart([remote_constants.NORMAL_TAG]) + return obj + else: + logger.error("Message from job {}".format(message)) + raise NotImplementedError + + def run(self): + """An infinite loop waiting for commands from the remote object. 
+ + Each job will receive two kinds of messages from the remote object: + + 1. When the remote object calls a function, job will run the + function on the local instance and return the results to the + remote object. + 2. When the remote object is deleted, the job will quit and release + related computation resources. + """ + + # receive files + envdir = self.wait_for_files() + sys.path.append(envdir) + + obj = self.wait_for_connection() + + while self.job_is_alive: + message = self.reply_socket.recv_multipart() + tag = message[0] + + if tag == remote_constants.CALL_TAG: + assert obj is not None + try: + function_name = to_str(message[1]) + data = message[2] + args, kwargs = loads_argument(data) + ret = getattr(obj, function_name)(*args, **kwargs) + ret = dumps_return(ret) + + self.reply_socket.send_multipart( + [remote_constants.NORMAL_TAG, ret]) + + except Exception as e: + error_str = str(e) + logger.error(error_str) + self.job_is_alive = False + + if type(e) == AttributeError: + self.reply_socket.send_multipart([ + remote_constants.ATTRIBUTE_EXCEPTION_TAG, + to_byte(error_str) + ]) + raise AttributeError + + elif type(e) == SerializeError: + self.reply_socket.send_multipart([ + remote_constants.SERIALIZE_EXCEPTION_TAG, + to_byte(error_str) + ]) + raise SerializeError + + elif type(e) == DeserializeError: + self.reply_socket.send_multipart([ + remote_constants.DESERIALIZE_EXCEPTION_TAG, + to_byte(error_str) + ]) + + else: + traceback_str = str(traceback.format_exc()) + logger.error("traceback:\n{}".format(traceback_str)) + self.reply_socket.send_multipart([ + remote_constants.EXCEPTION_TAG, + to_byte(error_str + "\ntraceback:\n" + + traceback_str) + ]) + + # receive KILLJOB_TAG from the actor, and stop replying worker heartbeat + elif tag == remote_constants.KILLJOB_TAG: + self.reply_socket.send_multipart([remote_constants.NORMAL_TAG]) + self.job_is_alive = False + logger.warning("An actor exited; quitting job {}.".format( + self.job_address)) + else: + logger.error("Job message: {}".format(message)) + raise NotImplementedError + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--worker_address", required=True, type=str, help="worker_address") + args = parser.parse_args() + job = Job(args.worker_address) + job.run() diff --git a/parl/remote/master.py b/parl/remote/master.py new file mode 100644 index 0000000000000000000000000000000000000000..42adc5819b2053625ade520be29ab5b5851dd784 --- /dev/null +++ b/parl/remote/master.py @@ -0,0 +1,303 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import pickle +import threading +import time +import zmq + +from collections import defaultdict +from parl.utils import to_str, to_byte, logger +from parl.remote import remote_constants + + +class Master(object): + """Base class for a master node, the control center for our cluster, which provides connections to workers and clients.
+ + There is only one master node in each cluster, and it is responsible for + receiving jobs from the clients and allocating computation resources to + run the jobs. + + To start a master node, we use the following xparl command line API: + + .. code-block:: bash + + xparl start --port 1234 + + At the same time, a local worker will be started and connect to the + master node. + + Attributes: + worker_pool (dict): A dict to store connected workers. + job_pool (list): A list storing the addresses of vacant jobs (idle + CPUs); when it is empty, the master will refuse to + create new remote objects. + client_job_dict (dict): A dict of lists recording the jobs submitted + by each client. + job_worker_dict (dict): A dict mapping each job to its worker. + client_socket (zmq.Context.socket): A socket which receives submitted + jobs from the client, and later sends + the job_address back to the client. + worker_socket (zmq.Context.socket): A socket which receives job + addresses from the worker node. + + Args: + port: the port that the master node binds to. + """ + + def __init__(self, port): + logger.set_dir(os.path.expanduser('~/.parl_data/master/')) + self.lock = threading.Lock() + self.ctx = zmq.Context() + + self.client_socket = self.ctx.socket(zmq.REP) + self.client_socket.bind("tcp://*:{}".format(port)) + self.client_socket.linger = 0 + self.port = port + + self.worker_pool = {} + self.worker_locks = {} + self.job_pool = [] + + self.client_job_dict = defaultdict(list) + self.worker_job_dict = defaultdict(list) + self.job_worker_dict = {} + + self.master_is_alive = True + + def _create_worker_monitor(self, worker_heartbeat_address, worker_address): + """When a new worker connects to the master, a socket is created to + send heartbeat signals to the worker. + """ + worker_heartbeat_socket = self.ctx.socket(zmq.REQ) + worker_heartbeat_socket.linger = 0 + worker_heartbeat_socket.setsockopt( + zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000) + worker_heartbeat_socket.connect("tcp://" + worker_heartbeat_address) + + connected = True + while connected and self.master_is_alive: + try: + worker_heartbeat_socket.send_multipart( + [remote_constants.HEARTBEAT_TAG]) + _ = worker_heartbeat_socket.recv_multipart() + time.sleep(remote_constants.HEARTBEAT_INTERVAL_S) + except zmq.error.Again as e: + for job in self.worker_job_dict[worker_address]: + if job in self.job_pool: + self.job_pool.remove(job) + self.job_worker_dict.pop(job) + self.worker_job_dict.pop(worker_address) + self.worker_pool.pop(worker_address) + logger.warning("\n[Master] Cannot connect to the worker " + + "{}. ".format(worker_address) + + "Worker_pool will drop this worker.") + self._print_workers() + connected = False + except zmq.error.ZMQError as e: + break + + worker_heartbeat_socket.close(0) + logger.warning("Exit worker monitor from master.") + + def _create_client_monitor(self, client_heartbeat_address): + """When a new client connects to the master, a socket is created to + send heartbeat signals to the client.
+ """ + + client_heartbeat_socket = self.ctx.socket(zmq.REQ) + client_heartbeat_socket.linger = 0 + client_heartbeat_socket.setsockopt( + zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000) + client_heartbeat_socket.connect("tcp://" + client_heartbeat_address) + + self.client_is_alive = True + while self.client_is_alive and self.master_is_alive: + try: + client_heartbeat_socket.send_multipart( + [remote_constants.HEARTBEAT_TAG]) + _ = client_heartbeat_socket.recv_multipart() + except zmq.error.Again as e: + self.client_is_alive = False + logger.warning("[Master] cannot connect to the client " + + "{}. ".format(client_heartbeat_address) + + "Please check if it is still alive.") + self._kill_client_jobs(client_heartbeat_address) + time.sleep(remote_constants.HEARTBEAT_INTERVAL_S) + logger.warning("Master exits client monitor for {}.\n".format( + client_heartbeat_address)) + logger.info( + "Master connects to {} workers and have {} vacant CPUs.\n".format( + len(self.worker_pool), len(self.job_pool))) + client_heartbeat_socket.close(0) + + def _kill_client_jobs(self, client_address): + """set timeout in case the worker and client quit at the same time. + """ + jobs = self.client_job_dict[client_address] + + for job_address in jobs: + if job_address in self.job_worker_dict: + worker_address = self.job_worker_dict[job_address] + worker_socket = self.worker_pool[worker_address].worker_socket + self.worker_locks[worker_address].acquire() + worker_socket.send_multipart( + [remote_constants.KILLJOB_TAG, + to_byte(job_address)]) + try: + _ = worker_socket.recv_multipart() + except zmq.error.Again as e: + logger.warning("Error in recv kill_client_job") + self.worker_locks[worker_address].release() + self.job_worker_dict.pop(job_address) + self.client_job_dict.pop(client_address) + + def _print_workers(self): + """Display `worker_pool` infomation.""" + logger.info( + "Master connects to {} workers and have {} vacant CPUs.\n".format( + len(self.worker_pool), len(self.job_pool))) + + def _receive_message(self): + """master node will receive four types of message: (1) worker + connection; (2) worker update; (3) client connection; (4) job + submittion. 
+ """ + message = self.client_socket.recv_multipart() + tag = message[0] + + # a new worker connects to the master + if tag == remote_constants.WORKER_CONNECT_TAG: + self.client_socket.send_multipart([remote_constants.NORMAL_TAG]) + + elif tag == remote_constants.WORKER_INITIALIZED_TAG: + worker = pickle.loads(message[1]) + worker_heartbeat_address = to_str(message[2]) + + # maintain job & worker relations + for job_address in worker.job_pool: + self.job_worker_dict[job_address] = worker.address + self.worker_job_dict[worker.address] = worker.job_pool + self.job_pool.extend(worker.job_pool) + + # a new socket for submitting job to the worker + worker_socket = self.ctx.socket(zmq.REQ) + worker_socket.linger = 0 + worker_socket.setsockopt(zmq.RCVTIMEO, 10000) + worker_socket.connect("tcp://{}".format(worker.address)) + worker.worker_socket = worker_socket + self.worker_pool[worker.address] = worker + self.worker_locks[worker.address] = threading.Lock() + + logger.info("A new worker {} is added, ".format(worker.address) + + "cluster has {} CPUs.\n".format(len(self.job_pool))) + + # a thread for sending heartbeat signals to `worker.address` + thread = threading.Thread( + target=self._create_worker_monitor, + args=( + worker_heartbeat_address, + worker.address, + ), + daemon=True) + thread.start() + + self.client_socket.send_multipart([remote_constants.NORMAL_TAG]) + + # a client connects to the master + elif tag == remote_constants.CLIENT_CONNECT_TAG: + client_heartbeat_address = to_str(message[1]) + logger.info( + "Client {} is connected.".format(client_heartbeat_address)) + + thread = threading.Thread( + target=self._create_client_monitor, + args=(client_heartbeat_address, ), + daemon=True) + thread.start() + self.client_socket.send_multipart([remote_constants.NORMAL_TAG]) + + # a client submits a job to the master + elif tag == remote_constants.CLIENT_SUBMIT_TAG: + client_address = to_str(message[1]) + done_flag = False + + # check available CPU resources + if len(self.job_pool): + logger.info("Submitting job...") + job_address = self.job_pool.pop(0) + worker_address = self.job_worker_dict[job_address] + self.worker_job_dict[worker_address].remove(job_address) + self.client_socket.send_multipart( + [remote_constants.NORMAL_TAG, + to_byte(job_address)]) + self.client_job_dict[client_address].append(job_address) + self._print_workers() + else: + self.client_socket.send_multipart([remote_constants.CPU_TAG]) + + # a worker updates + elif tag == remote_constants.NEW_JOB_TAG: + worker_address = to_str(message[1]) + new_job_address = to_str(message[2]) + killed_job_address = to_str(message[3]) + + self.client_socket.send_multipart([remote_constants.NORMAL_TAG]) + logger.info("A worker updated.") + + if killed_job_address in self.job_worker_dict: + self.job_worker_dict.pop(killed_job_address) + if killed_job_address in self.worker_job_dict[worker_address]: + self.worker_job_dict[worker_address].remove(killed_job_address) + if killed_job_address in self.job_pool: + self.job_pool.remove(killed_job_address) + + # add new job_address to job_pool + self.job_pool.append(new_job_address) + self.job_worker_dict[new_job_address] = worker_address + self.worker_job_dict[worker_address].append(new_job_address) + + self._print_workers() + + # check before start a worker + elif tag == remote_constants.NORMAL_TAG: + self.client_socket.send_multipart([remote_constants.NORMAL_TAG]) + + else: + raise NotImplementedError() + + def exit(self): + self.master_is_alive = False + self.ctx.destroy() + + def run(self): + 
"""An infinite loop waiting for messages from the workers and + clients. + + Master node will receive four types of messages: + + 1. A new worker connects to the master node. + 2. A connected worker sending new job address after it kills an old + job. + 3. A new client connects to the master node. + 4. A connected client submits a job after a remote object is created. + """ + while self.master_is_alive: + try: + self._receive_message() + except zmq.error.ContextTerminated as e: + pass + + logger.warning("[Master] Exit master.") diff --git a/parl/remote/remote_constants.py b/parl/remote/remote_constants.py index b30aad1c0c1076561af2dd41f634f9c891171290..06269e517efee28152a6745a32a16267ac03cd26 100644 --- a/parl/remote/remote_constants.py +++ b/parl/remote/remote_constants.py @@ -12,8 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +CPU_TAG = b'[CPU]' CONNECT_TAG = b'[CONNECT]' HEARTBEAT_TAG = b'[HEARTBEAT]' +KILLJOB_TAG = b'[KILLJOB]' + +WORKER_CONNECT_TAG = b'[WORKER_CONNECT]' +WORKER_INITIALIZED_TAG = b'[WORKER_INITIALIZED]' +CLIENT_CONNECT_TAG = b'[CLIENT_CONNECT]' +CLIENT_SUBMIT_TAG = b'[CLIENT_SUBMIT]' +SEND_FILE_TAG = b'[SEND_FILE]' +SUBMIT_JOB_TAG = b'[SUBMIT_JOB]' +NEW_JOB_TAG = b'[NEW_JOB]' + +INIT_OBJECT_TAG = b'[INIT_OBJECT]' +CALL_TAG = b'[CALL]' EXCEPTION_TAG = b'[EXCEPTION]' ATTRIBUTE_EXCEPTION_TAG = b'[ATTRIBUTE_EXCEPTION]' @@ -24,3 +37,5 @@ NORMAL_TAG = b'[NORMAL]' # interval of heartbeat mechanism in the unit of second HEARTBEAT_INTERVAL_S = 10 +HEARTBEAT_TIMEOUT_S = 10 +HEARTBEAT_RCVTIMEO_S = HEARTBEAT_INTERVAL_S + HEARTBEAT_TIMEOUT_S * 2 diff --git a/parl/remote/remote_decorator.py b/parl/remote/remote_decorator.py index d8cc44dd6d4bd659ec25f3dde54d90f7e7b7df84..785c2d55c9a267067ebb6cd92732b2c61f67d4d1 100644 --- a/parl/remote/remote_decorator.py +++ b/parl/remote/remote_decorator.py @@ -12,235 +12,174 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np -import pyarrow +import cloudpickle +import os import threading import time -import traceback import zmq -from parl.remote import remote_constants +import numpy as np + from parl.utils import get_ip_address, logger, to_str, to_byte -from parl.utils.exceptions import SerializeError, DeserializeError -from parl.utils.communication import loads_argument, dumps_return -""" -Three steps to create a remote class: -1. add a decroator(@parl.remote_class) before the definition of the class; -2. create an instance of remote class; -3. call function `as_remote` with server_ip and server_port. +from parl.utils.communication import loads_argument, loads_return,\ + dumps_argument, dumps_return +from parl.remote import remote_constants +from parl.remote.exceptions import RemoteError, RemoteAttributeError,\ + RemoteDeserializeError, RemoteSerializeError, ResourceError +from parl.remote.client import get_global_client -@parl.remote_class -Class Simulator(object): - ... -sim = Simulator() -sim.as_remote(server_ip='172.18.202.45', server_port=8001) +def remote_class(cls): + """A Python decorator that enables a class to run all its functions + remotely. -""" + Each instance of the remote class can be seemed as a task submitted + to the cluster by the global client, which is created automatically + when we call parl.connect(master_address). After global client + submits the task, the master node will send an available job address + to this remote instance. 
+    Then the remote object sends the local python files, the class
+    definition and the initialization arguments to the related job. In this
+    way, we can run distributed applications easily and efficiently.
 
-def remote_class(cls):
-    class ClientWrapper(object):
-        """
-        Wrapper for remote class in client side.
-        when as_remote function called, the object initialized in the client can
-        handle function call from server.
-        """
+    .. code-block:: python
 
-        def __init__(self, *args, **kwargs):
-            """
-            Args:
-                args, kwargs: arguments for the initialisation of the unwrapped class.
-            """
-            self.unwrapped = cls(*args, **kwargs)
+        @remote_class
+        class Actor(object):
+            def __init__(self, x):
+                self.x = x
 
-            self.zmq_context = None
-            self.poller = None
+            def step(self):
+                self.x += 1
+                return self.x
 
-            # socket for connecting server and telling ip and port of client to server
-            self.connect_socket = None
-            # socket for handle function call from server side
-            self.reply_socket = None
+        actor = Actor(0)
+        actor.step()
 
-        def _create_reply_socket(self, remote_ip, remote_port):
-            """
-            In fact, we also have a socket server in client side. This server keeps running
-            and waits for requests (e.g. call a function) from server side.
-            """
-            if remote_ip is None:
-                remote_ip = get_ip_address()
-
-            self.zmq_context = zmq.Context()
-            socket = self.zmq_context.socket(zmq.REP)
-
-            if remote_port is None:
-                try:
-                    remote_port = socket.bind_to_random_port(addr="tcp://*")
-                except zmq.ZMQBindError:
-                    logger.error(
-                        'Can not bind to a random port, please set remote_port manually.'
-                    )
-                    sys.exit(1)
-            else:
-                socket.bind("tcp://*:{}".format(remote_port))
+    Returns:
+        A remote wrapper for the remote class.
 
-            return socket, remote_ip, remote_port
+    Raises:
+        Exception: An exception is raised if the client is not created
+            by `parl.connect(master_address)` beforehand.
+    """
 
-        def _connect_server(self, server_ip, server_port, remote_ip,
-                            remote_port):
-            """
-            Create the connection between client side and server side.
+    class RemoteWrapper(object):
+        """
+        Wrapper for the remote class on the client side.
+        """
 
-            Args:
-                server_ip(str): the ip of the server.
-                server_port(int): the connection port of the server.
-                remote_ip: the ip of the client itself.
-                remote_port: the port of the client itself,
-                    which used to create reply socket.
-            """
-            self.reply_socket, local_ip, local_port = self._create_reply_socket(
-                remote_ip, remote_port)
-            self.reply_socket.linger = 0
+        def __init__(self, *args, **kwargs):
+            """
+            Args:
+                args, kwargs: arguments for the initialization of the
+                    unwrapped class.
+            """
+            self.GLOBAL_CLIENT = get_global_client()
 
-            socket = self.zmq_context.socket(zmq.REQ)
-            socket.connect("tcp://{}:{}".format(server_ip, server_port))
+            self.ctx = self.GLOBAL_CLIENT.ctx
 
-            logger.info("connecting {}:{}".format(server_ip, server_port))
+            # GLOBAL_CLIENT will set `master_is_alive` to False when the
+            # heartbeat finds the master is dead.
+            if self.GLOBAL_CLIENT.master_is_alive:
+                job_address = self.request_cpu_resource(self.GLOBAL_CLIENT)
+            else:
+                raise Exception("Cannot submit the job to the master. "
+                                "Please check if the master is still alive.")
+
+            if job_address is None:
+                raise ResourceError("Cannot submit the job to the master. "
+                                    "Please add more CPU resources to the "
+                                    "master or try again later.")
+
+            self.internal_lock = threading.Lock()
+
+            # Send actor commands like `init` and `call` to the job.
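+            # The client's python files are shipped to the job first so that
+            # the job can rebuild the class; the job then instantiates `cls`
+            # from the cloudpickled class definition and constructor
+            # arguments (see send_file and INIT_OBJECT_TAG below).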
+ self.job_socket = self.ctx.socket(zmq.REQ) + self.job_socket.linger = 0 + self.job_socket.connect("tcp://{}".format(job_address)) + self.job_address = job_address + + self.send_file(self.job_socket) + + try: + self.job_socket.send_multipart([ + remote_constants.INIT_OBJECT_TAG, + cloudpickle.dumps(cls), + cloudpickle.dumps([args, kwargs]) + ]) + _ = self.job_socket.recv_multipart() + except zmq.error.Again as e: + logger.error("Job socket failed.") + logger.info("[connect_job] job_address:{}".format(job_address)) + + def __del__(self): + """Delete the remote class object and release remote resources.""" + self.job_socket.send_multipart([remote_constants.KILLJOB_TAG]) + _ = self.job_socket.recv_multipart() + self.job_socket.close(0) + + def send_file(self, socket): + try: + socket.send_multipart([ + remote_constants.SEND_FILE_TAG, self.GLOBAL_CLIENT.pyfiles + ]) + _ = socket.recv_multipart() + except zmq.error.Again as e: + logger.error("Send python files failed.") + + def request_cpu_resource(self, global_client): + """Try to request cpu resource for 1 second/time for 300 times.""" + cnt = 300 + while cnt > 0: + job_address = global_client.submit_job() + if job_address is not None: + return job_address + if cnt % 30 == 0: + logger.warning("No vacant cpu resources at present, " + "will try {} times later.".format(cnt)) + cnt -= 1 + time.sleep(1) + return None - client_addr = '{}:{}'.format(local_ip, local_port) - socket.send_multipart( - [remote_constants.CONNECT_TAG, - to_byte(client_addr)]) + def __getattr__(self, attr): + """Call the function of the unwrapped class.""" - message = socket.recv_multipart() - self.client_id = message[1] - logger.info("connect server done, client_id: {}".format( - self.client_id)) - self.connect_socket = socket - self.connect_socket.linger = 0 + def wrapper(*args, **kwargs): + self.internal_lock.acquire() + data = dumps_argument(*args, **kwargs) - def _exit_remote(self): - self.poller.unregister(self.connect_socket) + self.job_socket.send_multipart( + [remote_constants.CALL_TAG, + to_byte(attr), data]) - self.connect_socket.close() - self.reply_socket.close() + message = self.job_socket.recv_multipart() + tag = message[0] - # The program may hang when destroying zmq context manually. - # It will be destroyed automatically by the garbage collection mechanism of python, - # though it may raise some exceptions in C++. 
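+                # The first frame of the job's reply is a status tag:
+                # NORMAL_TAG is followed by the serialized return value,
+                # while the exception tags carry an error string.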
+ if tag == remote_constants.NORMAL_TAG: + ret = loads_return(message[1]) - #self.zmq_context.destroy() + elif tag == remote_constants.EXCEPTION_TAG: + error_str = to_str(message[1]) + raise RemoteError(attr, error_str) - def _heartbeat_loop(self): - """ - Periodically detect whether the server is alive or not - """ - self.poller = zmq.Poller() - self.poller.register(self.connect_socket, zmq.POLLIN) + elif tag == remote_constants.ATTRIBUTE_EXCEPTION_TAG: + error_str = to_str(message[1]) + raise RemoteAttributeError(attr, error_str) - while True: - self.connect_socket.send_multipart( - [remote_constants.HEARTBEAT_TAG, self.client_id]) + elif tag == remote_constants.SERIALIZE_EXCEPTION_TAG: + error_str = to_str(message[1]) + raise RemoteSerializeError(attr, error_str) - # wait for at most 10s to receive response - socks = dict(self.poller.poll(10000)) + elif tag == remote_constants.DESERIALIZE_EXCEPTION_TAG: + error_str = to_str(message[1]) + raise RemoteDeserializeError(attr, error_str) - if socks.get(self.connect_socket) == zmq.POLLIN: - _ = self.connect_socket.recv_multipart() else: - logger.warning( - '[HeartBeat] Server no response, will exit now!') - self._exit_remote() - - break - - # HeartBeat interval 10s - time.sleep(remote_constants.HEARTBEAT_INTERVAL_S) - - def __getattr__(self, attr): - """ - Call the function of the unwrapped class. - """ + raise NotImplementedError() - def wrapper(*args, **kwargs): - return getattr(self.unwrapped, attr)(*args, **kwargs) + self.internal_lock.release() + return ret return wrapper - def _reply_loop(self): - while True: - message = self.reply_socket.recv_multipart() - - try: - function_name = to_str(message[1]) - data = message[2] - args, kwargs = loads_argument(data) - ret = getattr(self.unwrapped, function_name)(*args, - **kwargs) - ret = dumps_return(ret) - - except Exception as e: - error_str = str(e) - logger.error(error_str) - - if type(e) == AttributeError: - self.reply_socket.send_multipart([ - remote_constants.ATTRIBUTE_EXCEPTION_TAG, - to_byte(error_str) - ]) - elif type(e) == SerializeError: - self.reply_socket.send_multipart([ - remote_constants.SERIALIZE_EXCEPTION_TAG, - to_byte(error_str) - ]) - elif type(e) == DeserializeError: - self.reply_socket.send_multipart([ - remote_constants.DESERIALIZE_EXCEPTION_TAG, - to_byte(error_str) - ]) - else: - traceback_str = str(traceback.format_exc()) - logger.error('traceback:\n{}'.format(traceback_str)) - self.reply_socket.send_multipart([ - remote_constants.EXCEPTION_TAG, - to_byte(error_str + '\ntraceback:\n' + - traceback_str) - ]) - - continue - - self.reply_socket.send_multipart( - [remote_constants.NORMAL_TAG, ret]) - - def as_remote(self, - server_ip, - server_port, - remote_ip=None, - remote_port=None): - """ - Client will connect server and wait for function calls from server side. - - Args: - server_ip(str): server's ip - server_port(int): server's port - remote_ip: the ip of the client itself. - remote_port: the port of the client itself, - which used to create reply socket. - """ - self._connect_server(server_ip, server_port, remote_ip, - remote_port) - - reply_thread = threading.Thread(target=self._reply_loop) - reply_thread.setDaemon(True) - reply_thread.start() - - self._heartbeat_loop() - - def remote_closed(self): - """ - Check whether as_remote mode is closed - """ - assert self.reply_socket is not None, 'as_remote function should be called first!' - assert self.connect_socket is not None, 'as_remote function should be called first!' 
- return self.reply_socket.closed and self.connect_socket.closed - - return ClientWrapper + return RemoteWrapper diff --git a/parl/remote/remote_manager.py b/parl/remote/remote_manager.py deleted file mode 100644 index bb6aeb71f17c6dc7b0efe35072117b7f9f6d8f47..0000000000000000000000000000000000000000 --- a/parl/remote/remote_manager.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from six.moves import queue -import threading -import time -import zmq -from parl.utils import logger, to_byte, to_str -from parl.remote import remote_constants -from parl.remote.remote_object import RemoteObject -""" -Two steps to build the communication with remote clients: -1. Create a RemoteManager; -2. Get remote objects by calling the function get_remote. - -```python - remote_manager = RemoteManager(port=[port]) - remote_obj = remote_manager.get_remote() -``` - -""" - - -class RemoteManager(object): - """ - Base class for network communcation. - """ - - def __init__(self, port): - """ - Args: - port(int): a local port used for connections from remote clients. - """ - self.zmq_context = zmq.Context() - socket = self.zmq_context.socket(zmq.REP) - socket.bind("tcp://*:{}".format(port)) - self.socket = socket - self.socket.linger = 0 - - self.remote_pool = queue.Queue() - self.remote_latest_timestamp = {} - - t = threading.Thread(target=self._wait_for_connection) - t.setDaemon(True) # The thread will exit when main thread exited - t.start() - - t = threading.Thread(target=self._check_remote_status) - t.setDaemon(True) # The thread will exit when main thread exited - t.start() - - def _wait_for_connection(self): - """ - A never-ending function keeps waiting for the connections from remote client. - It will put an available remote object in an internel pool, and remote object - can be obtained by calling `get_remote`. - - Note that this function has been called inside the `__init__` function. 
- """ - remote_id = 0 - while True: - try: - message = self.socket.recv_multipart() - tag = message[0] - - if tag == remote_constants.CONNECT_TAG: - self.socket.send_multipart([ - remote_constants.NORMAL_TAG, - to_byte('{}'.format(remote_id)) - ]) - remote_client_address = to_str(message[1]) - remote_obj = RemoteObject(remote_client_address, - self.zmq_context) - - self.remote_latest_timestamp[to_byte( - str(remote_id))] = time.time() - - logger.info('[RemoteManager] Added a new remote object.') - self.remote_pool.put(remote_obj) - - remote_id += 1 - - elif tag == remote_constants.HEARTBEAT_TAG: - self.remote_latest_timestamp[message[1]] = time.time() - self.socket.send_multipart( - [remote_constants.NORMAL_TAG, b'Server is alive.']) - else: - raise NotImplementedError() - - except zmq.ZMQError: - logger.warning('Zmq error, exiting server.') - break - - def _check_remote_status(self): - while True: - for remote_id in list(self.remote_latest_timestamp.keys()): - if time.time() - self.remote_latest_timestamp[ - remote_id] > 3 * remote_constants.HEARTBEAT_INTERVAL_S: - logger.error( - 'Remote object {} is lost, please check if anything wrong happens in the remote client' - .format(remote_id)) - self.remote_latest_timestamp.pop(remote_id) - time.sleep(3 * remote_constants.HEARTBEAT_INTERVAL_S) - - def get_remote(self): - """ - A blocking function to obtain a remote object. - - Returns: - RemoteObject - """ - return self.remote_pool.get() - - def close(self): - """ - Close RemoteManager. - """ - - self.zmq_context.destroy() diff --git a/parl/remote/remote_object.py b/parl/remote/remote_object.py deleted file mode 100644 index 319e905f9ff4c71c0ccf8251b7f8a4b86c8e2865..0000000000000000000000000000000000000000 --- a/parl/remote/remote_object.py +++ /dev/null @@ -1,98 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import threading -import zmq -from parl.remote import remote_constants -from parl.remote.exceptions import * -from parl.utils import logger, to_str, to_byte -from parl.utils.communication import dumps_argument, loads_return - - -class RemoteObject(object): - """ - Provides interface to call functions of object in remote client. - """ - - def __init__(self, remote_client_address, zmq_context=None): - """ - Args: - remote_client_address: address(ip:port) of remote client - zmq_context: zmq.Context() - """ - if zmq_context is None: - self.zmq_context = zmq.Context() - else: - self.zmq_context = zmq_context - - # socket for sending function call to remote object and receiving result - self.command_socket = None - # lock for thread safety - self.internal_lock = threading.Lock() - self._connect_remote_client(remote_client_address) - - def _connect_remote_client(self, remote_client_address): - """ - Build connection with the remote client to send function call. 
- """ - socket = self.zmq_context.socket(zmq.REQ) - logger.info("[connect_remote_client] client_address:{}".format( - remote_client_address)) - socket.connect("tcp://{}".format(remote_client_address)) - self.command_socket = socket - self.command_socket.linger = 0 - - def __getattr__(self, attr): - """ - Provides interface to call functions of object in remote client. - 1. send fucntion name and packed auguments to remote client; - 2. remote clinet execute the function of the object really; - 3. receive function return from remote client. - - Args: - attr(str): a function name specify which function to run. - """ - - def wrapper(*args, **kwargs): - self.internal_lock.acquire() - - data = dumps_argument(*args, **kwargs) - - self.command_socket.send_multipart( - [remote_constants.NORMAL_TAG, - to_byte(attr), data]) - - message = self.command_socket.recv_multipart() - tag = message[0] - if tag == remote_constants.NORMAL_TAG: - ret = loads_return(message[1]) - elif tag == remote_constants.EXCEPTION_TAG: - error_str = to_str(message[1]) - raise RemoteError(attr, error_str) - elif tag == remote_constants.ATTRIBUTE_EXCEPTION_TAG: - error_str = to_str(message[1]) - raise RemoteAttributeError(attr, error_str) - elif tag == remote_constants.SERIALIZE_EXCEPTION_TAG: - error_str = to_str(message[1]) - raise RemoteSerializeError(attr, error_str) - elif tag == remote_constants.DESERIALIZE_EXCEPTION_TAG: - error_str = to_str(message[1]) - raise RemoteDeserializeError(attr, error_str) - else: - raise NotImplementedError() - - self.internal_lock.release() - return ret - - return wrapper diff --git a/parl/remote/scripts.py b/parl/remote/scripts.py new file mode 100644 index 0000000000000000000000000000000000000000..b361d83a7f812ee2f0765216c85d4912d5a200e5 --- /dev/null +++ b/parl/remote/scripts.py @@ -0,0 +1,132 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import click +import locale +import os +import subprocess +import threading +import warnings +from multiprocessing import Process + +# A flag to mark if parl is started from a command line +os.environ['XPARL'] = 'True' + +# Solve `Click will abort further execution because Python 3 was configured +# to use ASCII as encoding for the environment` error. +locale.setlocale(locale.LC_ALL, "en_US.UTF-8") + +warnings.simplefilter("ignore", ResourceWarning) + + +def is_port_in_use(port): + """ Check if a port is used. + + True if the port is not available. Otherwise, this port can be used for + connection. 
+ """ + import socket + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(('localhost', int(port))) == 0 + + +def is_master_started(address): + import zmq + ctx = zmq.Context() + socket = ctx.socket(zmq.REQ) + socket.linger = 0 + socket.setsockopt(zmq.RCVTIMEO, 500) + socket.connect("tcp://{}".format(address)) + socket.send_multipart([b'[NORMAL]']) + try: + _ = socket.recv_multipart() + socket.close(0) + return True + except zmq.error.Again as e: + socket.close(0) + return False + + +@click.group() +def cli(): + pass + + +@click.command("start", short_help="Start a master node.") +@click.option("--port", help="The port to bind to.", type=str, required=True) +@click.option( + "--cpu_num", + type=int, + help="Set number of cpu manually. If not set, it will use all " + "cpus of this machine.") +def start_master(port, cpu_num): + if is_port_in_use(port): + raise Exception( + "The master address localhost:{} already in use.".format(port)) + cpu_num = str(cpu_num) if cpu_num else '' + command = [ + "python", "{}/start.py".format(__file__[:-11]), "--name", "master", + "--port", port + ] + p = subprocess.Popen(command) + + command = [ + "python", "{}/start.py".format(__file__[:-11]), "--name", "worker", + "--address", "localhost:" + str(port), "--cpu_num", + str(cpu_num) + ] + p = subprocess.Popen(command) + + +@click.command("connect", short_help="Start a worker node.") +@click.option( + "--address", help="IP address of the master node.", required=True) +@click.option( + "--cpu_num", + type=int, + help="Set number of cpu manually. If not set, it will use all " + "cpus of this machine.") +def start_worker(address, cpu_num): + if not is_master_started(address): + raise Exception("Worker can not connect to the master node, " + + "please check if the input address {} ".format( + address) + "is correct.") + cpu_num = str(cpu_num) if cpu_num else '' + command = [ + "python", "{}/start.py".format(__file__[:-11]), "--name", "worker", + "--address", address, "--cpu_num", + str(cpu_num) + ] + p = subprocess.Popen(command) + + +@click.command("stop", help="Exit the cluster.") +def stop(): + command = ("pkill -f remote/start.py") + subprocess.call([command], shell=True) + command = ("pkill -f job.py") + p = subprocess.call([command], shell=True) + + +cli.add_command(start_worker) +cli.add_command(start_master) +cli.add_command(stop) + + +def main(): + return cli() + + +if __name__ == "__main__": + main() diff --git a/parl/remote/start.py b/parl/remote/start.py new file mode 100644 index 0000000000000000000000000000000000000000..d9aa231db65a04ee410d7df6660d9eaa75150828 --- /dev/null +++ b/parl/remote/start.py @@ -0,0 +1,52 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import threading +from parl.remote import Master, Worker + + +def main(args): + """Start a master or a worker through: + + 1. xparl start --port 1234 + 2. 
xparl connect --address localhost:1234 --cpu_num 8 + + """ + + if args.name == 'master': + port = args.port + master = Master(port) + master.run() + + elif args.name == 'worker': + address = args.address + cpu_num = int(args.cpu_num) if args.cpu_num else None + worker = Worker(address, cpu_num) + worker.run() + + else: + raise NotImplementedError + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + '--name', default='master', type=str, help='master/worker') + parser.add_argument('--port', default='1234', type=str) + parser.add_argument('--address', default='localhost:1234', type=str) + parser.add_argument('--cpu_num', default='', type=str) + args = parser.parse_args() + main(args) diff --git a/parl/remote/tests/remote_decorator_test.py b/parl/remote/tests/remote_decorator_test.py deleted file mode 100644 index c5d70b2afbf446e9f4f4b78210d06d5d1365dd75..0000000000000000000000000000000000000000 --- a/parl/remote/tests/remote_decorator_test.py +++ /dev/null @@ -1,73 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import parl -import unittest - - -@parl.remote_class -class Simulator: - def __init__(self, arg1, arg2=None): - self.arg1 = arg1 - self.arg2 = arg2 - - def get_arg1(self): - return self.arg1 - - def get_arg2(self): - return self.arg2 - - def set_arg1(self, value): - self.arg1 = value - - def set_arg2(self, value): - self.arg2 = value - - -class TestRemoteDecorator(unittest.TestCase): - def test_instance_in_local(self): - local_sim = Simulator(1, 2) - - self.assertEqual(local_sim.get_arg1(), 1) - self.assertEqual(local_sim.get_arg2(), 2) - - local_sim.set_arg1(3) - local_sim.set_arg2(4) - - self.assertEqual(local_sim.get_arg1(), 3) - self.assertEqual(local_sim.get_arg2(), 4) - - def test_instance_in_local_with_wrong_getattr_get_variable(self): - local_sim = Simulator(1, 2) - - try: - local_sim.get_arg3() - except AttributeError: - return - - assert False # This line should not be executed. - - def test_instance_in_local_with_wrong_getattr_set_variable(self): - local_sim = Simulator(1, 2) - - try: - local_sim.set_arg3(3) - except AttributeError: - return - - assert False # This line should not be executed. - - -if __name__ == '__main__': - unittest.main() diff --git a/parl/remote/tests/remote_test.py b/parl/remote/tests/remote_test.py deleted file mode 100644 index 64ba62c756f77842eaa8c46f991df4cf3ebc6fdb..0000000000000000000000000000000000000000 --- a/parl/remote/tests/remote_test.py +++ /dev/null @@ -1,310 +0,0 @@ -# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import inspect -import parl -import threading -import unittest -from parl.remote import * - - -class UnableSerializeObject(object): - def __init__(self): - # threading.Lock() can not be serialized - self.lock = threading.Lock() - - -@parl.remote_class -class Simulator: - def __init__(self, arg1, arg2=None): - self.arg1 = arg1 - self.arg2 = arg2 - - def get_arg1(self): - return self.arg1 - - def get_arg2(self): - return self.arg2 - - def set_arg1(self, value): - self.arg1 = value - - def set_arg2(self, value): - self.arg2 = value - - def get_unable_serialize_object(self): - return UnableSerializeObject() - - def add_one(self, value): - value += 1 - return value - - def will_raise_exeception_func(self): - x = 1 / 0 - - -class TestRemote(unittest.TestCase): - def _setUp(self, server_port): - self.sim = Simulator(1, arg2=2) - - # run client in a new thread to fake a remote client - self.client_thread = threading.Thread( - target=self.sim.as_remote, args=( - 'localhost', - server_port, - )) - self.client_thread.setDaemon(True) - self.client_thread.start() - - self.remote_manager = RemoteManager(port=server_port) - - def test_remote_object(self): - server_port = 17770 - self._setUp(server_port) - - remote_sim = self.remote_manager.get_remote() - - self.assertEqual(remote_sim.get_arg1(), 1) - self.assertEqual(remote_sim.get_arg2(), 2) - - ret = remote_sim.set_arg1(3) - self.assertIsNone(ret) - ret = remote_sim.set_arg2(4) - self.assertIsNone(ret) - - self.assertEqual(remote_sim.get_arg1(), 3) - self.assertEqual(remote_sim.get_arg2(), 4) - - def test_remote_object_with_wrong_getattr_get_variable(self): - server_port = 17771 - self._setUp(server_port) - - remote_sim = self.remote_manager.get_remote() - - try: - remote_sim.get_arg3() - except RemoteAttributeError as e: - logger.info('Expected exception: {}'.format(e)) - # expected - return - - assert False - - def test_remote_object_with_wrong_getattr_set_variable(self): - server_port = 17772 - self._setUp(server_port) - - remote_sim = self.remote_manager.get_remote() - - try: - remote_sim.set_arg3(3) - except RemoteAttributeError as e: - logger.info('Expected exception: {}'.format(e)) - # expected - return - - assert False - - def test_remote_object_with_wrong_argument(self): - server_port = 17773 - self._setUp(server_port) - - remote_sim = self.remote_manager.get_remote() - - try: - remote_sim.set_arg1(wrong_arg=1) - except RemoteError as e: - logger.info('Expected exception: {}'.format(e)) - # expected - return - - assert False - - def test_remote_object_with_unable_serialize_argument(self): - server_port = 17774 - self._setUp(server_port) - - remote_sim = self.remote_manager.get_remote() - - try: - remote_sim.set_arg1(UnableSerializeObject()) - except SerializeError as e: - logger.info('Expected exception: {}'.format(e)) - # expected - return - - assert False - - def test_remote_object_with_unable_serialize_return(self): - server_port = 17775 - self._setUp(server_port) - - remote_sim = self.remote_manager.get_remote() - - try: - remote_sim.get_unable_serialize_object() - except RemoteSerializeError as e: - # expected - 
logger.info('Expected exception: {}'.format(e)) - return - - assert False - - def test_multi_remote_object(self): - server_port = 17776 - self._setUp(server_port) - - time.sleep(1) - # run second client - sim2 = Simulator(11, arg2=22) - client_thread2 = threading.Thread( - target=sim2.as_remote, args=( - 'localhost', - server_port, - )) - client_thread2.setDaemon(True) - client_thread2.start() - - time.sleep(1) - remote_sim1 = self.remote_manager.get_remote() - remote_sim2 = self.remote_manager.get_remote() - - self.assertEqual(remote_sim1.get_arg1(), 1) - self.assertEqual(remote_sim2.get_arg1(), 11) - - def test_multi_remote_object_with_one_failed(self): - server_port = 17777 - self._setUp(server_port) - - time.sleep(1) - # run second client - sim2 = Simulator(11, arg2=22) - client_thread2 = threading.Thread( - target=sim2.as_remote, args=( - 'localhost', - server_port, - )) - client_thread2.setDaemon(True) - client_thread2.start() - - time.sleep(1) - remote_sim1 = self.remote_manager.get_remote() - remote_sim2 = self.remote_manager.get_remote() - - try: - # make remote sim1 failed - remote_sim1.get_arg3() - except: - pass - - self.assertEqual(remote_sim2.get_arg1(), 11) - - # Todo(@zenghongsheng): - # zmq will raise unexpected C++ exception when closing context, - # remove this unittest for now. - #def test_heartbeat_after_server_closed(self): - # server_port = 17778 - # self._setUp(server_port) - - # remote_sim = self.remote_manager.get_remote() - - # time.sleep(1) - # self.remote_manager.close() - - # # heartbeat interval (10s) + max waiting reply (10s) - # time.sleep(20) - - # logger.info('check self.sim.remote_closed') - # self.assertTrue(self.sim.remote_closed()) - - def test_set_client_ip_port_manually(self): - server_port = 17779 - self._setUp(server_port) - - time.sleep(1) - # run second client - sim2 = Simulator(11, arg2=22) - client_thread2 = threading.Thread( - target=sim2.as_remote, - args=( - 'localhost', - server_port, - 'localhost', - 6666, - )) - client_thread2.setDaemon(True) - client_thread2.start() - - time.sleep(1) - remote_sim1 = self.remote_manager.get_remote() - remote_sim2 = self.remote_manager.get_remote() - - self.assertEqual(remote_sim1.get_arg1(), 1) - self.assertEqual(remote_sim2.get_arg1(), 11) - - def test_thread_safe_of_remote_module(self): - server_port = 17780 - self._setUp(server_port) - - time.sleep(1) - - thread_num = 10 - for _ in range(thread_num): - # run clients in backend - sim = Simulator(11, arg2=22) - client_thread = threading.Thread( - target=sim.as_remote, args=( - 'localhost', - server_port, - )) - client_thread.setDaemon(True) - client_thread.start() - - time.sleep(1) - threads = [] - for _ in range(thread_num): - remote_sim = self.remote_manager.get_remote() - t = threading.Thread( - target=self._run_remote_add, args=(remote_sim, )) - t.start() - threads.append(t) - - for t in threads: - t.join() - - def test_remote_object_with_call_raise_exception_function(self): - server_port = 17781 - self._setUp(server_port) - - remote_sim = self.remote_manager.get_remote() - - try: - remote_sim.will_raise_exeception_func() - except RemoteError as e: - assert 'Traceback (most recent call last)' in str(e) - logger.info('Expected exception: {}'.format(e)) - # expected - return - - assert False - - def _run_remote_add(self, remote_sim): - value = 0 - for i in range(1000): - value = remote_sim.add_one(value) - assert value == i + 1 - - -if __name__ == '__main__': - unittest.main() diff --git a/parl/remote/worker.py b/parl/remote/worker.py new file mode 
100644
index 0000000000000000000000000000000000000000..2001db2dc8279b42bb57173be2b1a39fb0d50d83
--- /dev/null
+++ b/parl/remote/worker.py
@@ -0,0 +1,314 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cloudpickle
+import multiprocessing
+import os
+import subprocess
+import sys
+import time
+import threading
+import zmq
+
+from parl.utils import get_ip_address, to_byte, to_str, logger
+from parl.remote import remote_constants
+
+
+class WorkerInfo(object):
+    """A WorkerInfo object records the computation resources of a worker.
+    """
+
+    def __init__(self, address, cpu_num, job_pool):
+        self.address = address
+        self.cpu_num = cpu_num
+        self.job_pool = job_pool
+        self.worker_socket = None
+
+
+class Worker(object):
+    """A Worker provides the CPU computation resources for the cluster.
+
+    A worker node is connected to the master node and sends its
+    computation resource information to the master node. When a worker
+    node is created, it starts `cpu_num` empty jobs, and these jobs'
+    ip addresses will be sent to the master node. Further, when an old
+    job is killed, the worker starts a new job and sends the new job's
+    ip address to the master node.
+
+    To start a worker, we use the following xparl command-line API:
+
+    .. code-block:: bash
+
+        xparl connect --address localhost:1234 --cpu_num 8
+
+    Attributes:
+        job_pid (dict): A dict that maps the address of each job to its
+                        subprocess.
+        master_address (str): Master's ip address.
+        request_master_socket (zmq.Context.socket): A socket which sends job
+                                                    addresses to the master
+                                                    node.
+        reply_master_socket (zmq.Context.socket): A socket which accepts
+                                                  submitted jobs from the
+                                                  master node.
+        reply_job_socket (zmq.Context.socket): A socket which receives
+                                               job_address from the job.
+
+    Args:
+        master_address (str): IP address of the master node.
+        cpu_num (int): Number of CPUs to be used on the worker.
+    """
+
+    def __init__(self, master_address, cpu_num=None):
+        self.lock = threading.Lock()
+        self.heartbeat_socket_initialized = threading.Event()
+        self.ctx = zmq.Context.instance()
+        self.job_pid = {}
+        self.master_address = master_address
+        self.master_is_alive = True
+        self.worker_is_alive = True
+        self._set_cpu_num(cpu_num)
+        self._create_sockets()
+        self._create_worker()
+
+    def _set_cpu_num(self, cpu_num=None):
+        """Set the number of usable CPUs for the worker."""
+        if cpu_num is not None:
+            assert isinstance(
+                cpu_num, int
+            ), "cpu_num should be INT type, please check the input type."
+            self.cpu_num = cpu_num
+        else:
+            self.cpu_num = multiprocessing.cpu_count()
+
+    def _create_sockets(self):
+        """ Each worker has three sockets at start:
+
+        (1) request_master_socket: sends job address to master node.
+        (2) reply_master_socket: accepts submitted job from master node.
+        (3) reply_job_socket: receives job_address from subprocess.
+
+        When a job is started, a new heartbeat socket is created to receive
+        heartbeat signals from the job.
+ + """ + + # request_master_socket: sends job address to master + self.request_master_socket = self.ctx.socket(zmq.REQ) + self.request_master_socket.linger = 0 + + # wait for 0.5 second to check whether master is started + self.request_master_socket.setsockopt(zmq.RCVTIMEO, 500) + self.request_master_socket.connect("tcp://" + self.master_address) + + # reply_master_socket: receives submitted job from master + self.reply_master_socket = self.ctx.socket(zmq.REP) + self.reply_master_socket.linger = 0 + self.worker_ip = get_ip_address() + reply_master_port = self.reply_master_socket.bind_to_random_port( + "tcp://*") + self.reply_master_address = "{}:{}".format(self.worker_ip, + reply_master_port) + logger.set_dir( + os.path.expanduser('~/.parl_data/worker/{}'.format( + self.reply_master_address))) + # reply_job_socket: receives job_address from subprocess + self.reply_job_socket = self.ctx.socket(zmq.REP) + self.reply_job_socket.linger = 0 + reply_job_port = self.reply_job_socket.bind_to_random_port("tcp://*") + self.reply_job_address = "{}:{}".format(self.worker_ip, reply_job_port) + + def _create_worker(self): + """create a WorkerInfo instance and send it to the master.""" + try: + self.request_master_socket.send_multipart( + [remote_constants.WORKER_CONNECT_TAG]) + _ = self.request_master_socket.recv_multipart() + except zmq.error.Again as e: + logger.error("Can not connect to the master, " + "please check if master is started.") + self.master_is_alive = False + return + + self._init_jobs() + self.request_master_socket.setsockopt( + zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000) + + self.worker = WorkerInfo(self.reply_master_address, self.cpu_num, + list(self.job_pid.keys())) + reply_thread = threading.Thread( + target=self._reply_heartbeat, + args=("master {}".format(self.master_address), ), + daemon=True) + reply_thread.start() + self.heartbeat_socket_initialized.wait() + + self.request_master_socket.send_multipart([ + remote_constants.WORKER_INITIALIZED_TAG, + cloudpickle.dumps(self.worker), + to_byte(self.heartbeat_master_address) + ]) + _ = self.request_master_socket.recv_multipart() + + def _init_job(self): + """Create one job.""" + command = [ + "python", "{}/job.py".format(__file__[:-10]), "--worker_address", + self.reply_job_address + ] + + with open(os.devnull, "w") as null: + pid = subprocess.Popen(command, stdout=null, stderr=null) + + self.lock.acquire() + job_message = self.reply_job_socket.recv_multipart() + self.reply_job_socket.send_multipart([remote_constants.NORMAL_TAG]) + job_address = to_str(job_message[1]) + heartbeat_job_address = to_str(job_message[2]) + self.job_pid[job_address] = pid + self.lock.release() + + # a thread for sending heartbeat signals to job + thread = threading.Thread( + target=self._create_job_monitor, + args=( + job_address, + heartbeat_job_address, + ), + daemon=True) + thread.start() + return job_address + + def _init_jobs(self): + """Create cpu_num jobs when the worker is created.""" + job_threads = [] + for _ in range(self.cpu_num): + t = threading.Thread(target=self._init_job, daemon=True) + t.start() + job_threads.append(t) + for th in job_threads: + th.join() + + def _kill_job(self, job_address): + """kill problematic job process and update worker information""" + if job_address in self.job_pid: + self.job_pid[job_address].kill() + self.job_pid.pop(job_address) + logger.warning("Worker kills job process {},".format(job_address)) + + # When a old job is killed, the worker will create a new job. 
+        if self.master_is_alive:
+            new_job_address = self._init_job()
+
+            self.lock.acquire()
+            self.request_master_socket.send_multipart([
+                remote_constants.NEW_JOB_TAG,
+                to_byte(self.reply_master_address),
+                to_byte(new_job_address),
+                to_byte(job_address)
+            ])
+            _ = self.request_master_socket.recv_multipart()
+            self.lock.release()
+
+    def _create_job_monitor(self, job_address, heartbeat_job_address):
+        """Send heartbeat signals to check the target job's status."""
+
+        # job_heartbeat_socket: sends heartbeat signal to the job
+        job_heartbeat_socket = self.ctx.socket(zmq.REQ)
+        job_heartbeat_socket.linger = 0
+        job_heartbeat_socket.setsockopt(
+            zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
+        job_heartbeat_socket.connect("tcp://" + heartbeat_job_address)
+
+        job_is_alive = True
+        while job_is_alive and self.master_is_alive:
+            try:
+                job_heartbeat_socket.send_multipart(
+                    [remote_constants.HEARTBEAT_TAG])
+                _ = job_heartbeat_socket.recv_multipart()
+                time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
+
+            except zmq.error.Again as e:
+                job_is_alive = False
+                if job_address in self.job_pid:
+                    logger.warning("[Worker] No heartbeat reply from the job, "
+                                   "will kill {}.".format(job_address))
+                    self._kill_job(job_address)
+
+            except zmq.error.ZMQError as e:
+                break
+
+        job_heartbeat_socket.close(0)
+
+    def _reply_heartbeat(self, target):
+        """The worker kills its jobs when it loses the connection with the
+        master.
+        """
+
+        socket = self.ctx.socket(zmq.REP)
+        socket.linger = 0
+        socket.setsockopt(zmq.RCVTIMEO,
+                          remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
+        heartbeat_master_port =\
+            socket.bind_to_random_port("tcp://*")
+        self.heartbeat_master_address = "{}:{}".format(self.worker_ip,
+                                                       heartbeat_master_port)
+        self.heartbeat_socket_initialized.set()
+        logger.info("[Worker] Connected to the master node successfully. "
+                    "({} CPUs)".format(self.cpu_num))
+        while self.master_is_alive:
+            try:
+                message = socket.recv_multipart()
+                socket.send_multipart([remote_constants.HEARTBEAT_TAG])
+            except zmq.error.Again as e:
+                self.master_is_alive = False
+                for job_address in list(self.job_pid.keys()):
+                    self._kill_job(job_address)
+            except zmq.error.ContextTerminated as e:
+                break
+        socket.close(0)
+        logger.warning("[Worker] Stop replying heartbeat to the master.")
+        if self.worker_is_alive:
+            self.exit()
+
+    def exit(self):
+        """Close all zmq sockets related to the worker."""
+        self.worker_is_alive = False
+        self.ctx.destroy()
+
+    def run(self):
+        """An infinite loop waiting for job-killing commands from
+        the master node.
+
+        After creating `cpu_num` jobs and sending the job addresses to the
+        master node, a worker keeps waiting for job-killing commands from the
+        master node to release computation resources occupied by a dead
+        client. The worker then kills the jobs related to the dead client,
+        creates new jobs, and updates the job addresses on the master node.
+ """ + + while self.master_is_alive and self.worker_is_alive: + try: + message = self.reply_master_socket.recv_multipart() + tag = message[0] + + if tag == remote_constants.KILLJOB_TAG: + job_address = to_str(message[1]) + self.reply_master_socket.send_multipart( + [remote_constants.NORMAL_TAG]) + self._kill_job(job_address) + + else: + raise NotImplementedError + except zmq.error.ZMQError as e: + self.worker_is_alive = False + + logger.warning("[Worker] Exit Worker {}.".format( + self.reply_master_address)) diff --git a/parl/utils/logger.py b/parl/utils/logger.py index 61353a4f4ab02e99fc9e1983a95020c9594e9427..3e84cefc19b2867bee5da9317a2c497e11f389e6 100644 --- a/parl/utils/logger.py +++ b/parl/utils/logger.py @@ -76,12 +76,34 @@ def _getlogger(): logger = logging.getLogger('PARL') logger.propagate = False logger.setLevel(logging.DEBUG) - handler = logging.StreamHandler(sys.stdout) - handler.setFormatter(_Formatter(datefmt='%m-%d %H:%M:%S')) - logger.addHandler(handler) + if 'XPARL' not in os.environ: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(_Formatter(datefmt='%m-%d %H:%M:%S')) + logger.addHandler(handler) return logger +def create_file_after_first_call(func_name): + def call(*args, **kwargs): + global _logger + if LOG_DIR is None: + + basename = os.path.basename(mod.__file__) + if basename.rfind('.') == -1: + basename = basename + else: + basename = basename[:basename.rfind('.')] + auto_dirname = os.path.join('log_dir', basename) + + shutil.rmtree(auto_dirname, ignore_errors=True) + set_dir(auto_dirname) + + func = getattr(_logger, func_name) + func(*args, **kwargs) + + return call + + _logger = _getlogger() _LOGGING_METHOD = [ 'info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug', @@ -90,8 +112,9 @@ _LOGGING_METHOD = [ # export logger functions for func in _LOGGING_METHOD: - locals()[func] = getattr(_logger, func) + locals()[func] = create_file_after_first_call(func) __all__.append(func) + # export Level information _LOGGING_LEVEL = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] for level in _LOGGING_LEVEL: @@ -100,7 +123,7 @@ for level in _LOGGING_LEVEL: def _set_file(path): - global _FILE_HANDLER + global _FILE_HANDLER, _logger if os.path.isfile(path): try: os.remove(path) @@ -114,16 +137,19 @@ def _set_file(path): def set_level(level): + global _logger, LOG_DIR # To set level, need create new handler - set_dir(get_dir()) + if LOG_DIR is not None: + set_dir(get_dir()) _logger.setLevel(level) def set_dir(dirname): - global LOG_DIR, _FILE_HANDLER + global LOG_DIR, _FILE_HANDLER, _logger if _FILE_HANDLER: # unload and close the old file handler, so that we may safely delete the logger directory _logger.removeHandler(_FILE_HANDLER) + _FILE_HANDLER.close() del _FILE_HANDLER if not os.path.isdir(dirname): @@ -137,10 +163,6 @@ def get_dir(): # Will save log to log_dir/main_file_name/log.log by default + mod = sys.modules['__main__'] -if hasattr(mod, '__file__'): - basename = os.path.basename(mod.__file__) - auto_dirname = os.path.join('log_dir', basename[:basename.rfind('.')]) - shutil.rmtree(auto_dirname, ignore_errors=True) - set_dir(auto_dirname) - _logger.info("Argv: " + ' '.join(sys.argv)) +_logger.info("Argv: " + ' '.join(sys.argv)) diff --git a/parl/utils/machine_info.py b/parl/utils/machine_info.py index 754af8b93e7c49650e248e2883d03b790188cbea..8a04ee124c5db475d01c3d870097eb6ba9d3456b 100644 --- a/parl/utils/machine_info.py +++ b/parl/utils/machine_info.py @@ -61,7 +61,7 @@ def get_gpu_count(): """get avaliable gpu count Returns: - 
gpu_count: int + gpu_count: int """ gpu_count = 0 @@ -77,7 +77,7 @@ def get_gpu_count(): logger.info( 'CUDA_VISIBLE_DEVICES found gpu count: {}'.format(gpu_count)) except: - logger.warn('Cannot find available GPU devices, using CPU now.') + logger.warning('Cannot find available GPU devices, using CPU now.') gpu_count = 0 else: try: @@ -85,7 +85,7 @@ def get_gpu_count(): "-L"])).count('UUID') logger.info('nvidia-smi -L found gpu count: {}'.format(gpu_count)) except: - logger.warn('Cannot find available GPU devices, using CPU now.') + logger.warning('Cannot find available GPU devices, using CPU now.') gpu_count = 0 return gpu_count @@ -100,7 +100,7 @@ def is_gpu_available(): if utils._HAS_FLUID: from paddle import fluid if ret is True and not fluid.is_compiled_with_cuda(): - logger.warn("Found non-empty CUDA_VISIBLE_DEVICES. \ + logger.warning("Found non-empty CUDA_VISIBLE_DEVICES. \ But PARL found that Paddle was not complied with CUDA, which may cause issues." - ) + ) return ret
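Taken together, this patch replaces the old `as_remote`/`RemoteManager` workflow with a cluster workflow (master, workers, jobs). Below is a minimal usage sketch of the new API as documented in the diff above; the port, CPU count, and the `Counter` class are illustrative only, not part of the patch:

```python
# Minimal end-to-end sketch of the cluster workflow introduced in this patch.
# Assumes a master (plus a local worker) was started beforehand, e.g.:
#
#   xparl start --port 8010 --cpu_num 5
#
import parl


@parl.remote_class
class Counter(object):
    def __init__(self, start):
        self.value = start

    def step(self):
        self.value += 1
        return self.value


# Creates the global client; every remote instance created afterwards is
# submitted to the cluster as a job.
parl.connect('localhost:8010')

counters = [Counter(0) for _ in range(5)]  # each instance runs in its own job
print([c.step() for c in counters])        # -> [1, 1, 1, 1, 1]
```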