Commit b29a1ec1 authored by F fuyw, committed by Bo Zhou

first pr (#113)

* first pr

* start a worker when the master is started.

* First PR & Fix logger bugs.

* update docs for a2c, impala and ga3c

* update doc

* yapf modification

* update logger

* yapf correct

* yapf

* setup.py

* old setup.py

* worker 86
Parent 7cc9677c
...@@ -24,22 +24,23 @@ Mean episode reward in training process after 10 million sample steps.
### Distributed Training

First, we can start a local cluster with 5 CPUs:

```bash
xparl start --port 8010 --cpu_num 5
```

Note that if you have started a master before, you don't have to run the above
command. For more information about the cluster, please refer to our
[documentation](https://parl.readthedocs.io/en/latest/parallel_training/setup.html).

Then we can start the distributed training by running `train.py`:

```bash
python train.py
```

### Reference
+ [Parl](https://parl.readthedocs.io/en/latest/parallel_training/setup.html)
+ [Ray](https://github.com/ray-project/ray)
+ [OpenAI Baselines: ACKTR & A2C](https://openai.com/blog/baselines-acktr-a2c/)
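For readers tracking the API change in this commit, the sketch below (not part of the repository) shows roughly how `train.py` now reaches the cluster started by `xparl`; it assumes `Actor` is the `@parl.remote_class` class from `actor.py` and that `a2c_config.py` provides `master_address` and `actor_num`, as in the diffs further down.

```python
# Rough sketch only: connect once, then instantiate remote actors.
import parl
from a2c_config import config   # assumed to provide 'master_address' and 'actor_num'
from actor import Actor         # decorated with @parl.remote_class

parl.connect(config['master_address'])   # e.g. 'localhost:8010', started by xparl
# Each instantiation submits a job to the cluster; the object then runs in a
# remote job but is called like a normal local instance.
actors = [Actor(config) for _ in range(config['actor_num'])]
```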
...@@ -13,9 +13,9 @@ ...@@ -13,9 +13,9 @@
# limitations under the License. # limitations under the License.
config = { config = {
#========== remote config ========== #========== remote config ==========
'server_ip': 'localhost', 'master_address': 'localhost:8010',
'server_port': 8037,
#========== env config ========== #========== env config ==========
'env_name': 'PongNoFrameskip-v4', 'env_name': 'PongNoFrameskip-v4',
......
...@@ -15,8 +15,6 @@ ...@@ -15,8 +15,6 @@
import gym import gym
import numpy as np import numpy as np
import parl import parl
import six
import parl
from atari_model import AtariModel from atari_model import AtariModel
from collections import defaultdict from collections import defaultdict
from atari_agent import AtariAgent from atari_agent import AtariAgent
...@@ -31,7 +29,7 @@ class Actor(object): ...@@ -31,7 +29,7 @@ class Actor(object):
self.config = config self.config = config
self.envs = [] self.envs = []
for _ in six.moves.range(config['env_num']): for _ in range(config['env_num']):
env = gym.make(config['env_name']) env = gym.make(config['env_name'])
env = wrap_deepmind(env, dim=config['env_dim'], obs_format='NCHW') env = wrap_deepmind(env, dim=config['env_dim'], obs_format='NCHW')
self.envs.append(env) self.envs.append(env)
...@@ -54,16 +52,16 @@ class Actor(object): ...@@ -54,16 +52,16 @@ class Actor(object):
sample_data = defaultdict(list) sample_data = defaultdict(list)
env_sample_data = {} env_sample_data = {}
for env_id in six.moves.range(self.config['env_num']): for env_id in range(self.config['env_num']):
env_sample_data[env_id] = defaultdict(list) env_sample_data[env_id] = defaultdict(list)
for i in six.moves.range(self.config['sample_batch_steps']): for i in range(self.config['sample_batch_steps']):
actions_batch, values_batch = self.agent.sample( actions_batch, values_batch = self.agent.sample(
np.stack(self.obs_batch)) np.stack(self.obs_batch))
next_obs_batch, reward_batch, done_batch, info_batch = \ next_obs_batch, reward_batch, done_batch, info_batch = \
self.vector_env.step(actions_batch) self.vector_env.step(actions_batch)
for env_id in six.moves.range(self.config['env_num']): for env_id in range(self.config['env_num']):
env_sample_data[env_id]['obs'].append(self.obs_batch[env_id]) env_sample_data[env_id]['obs'].append(self.obs_batch[env_id])
env_sample_data[env_id]['actions'].append( env_sample_data[env_id]['actions'].append(
actions_batch[env_id]) actions_batch[env_id])
...@@ -115,10 +113,3 @@ class Actor(object): ...@@ -115,10 +113,3 @@ class Actor(object):
def set_weights(self, params): def set_weights(self, params):
self.agent.set_weights(params) self.agent.set_weights(params)
if __name__ == '__main__':
from a2c_config import config
actor = Actor(config)
actor.as_remote(config['server_ip'], config['server_port'])
File mode changed from 100644 to 100755
File mode changed from 100644 to 100755
...@@ -23,14 +23,16 @@ import parl ...@@ -23,14 +23,16 @@ import parl
from atari_model import AtariModel from atari_model import AtariModel
from atari_agent import AtariAgent from atari_agent import AtariAgent
from collections import defaultdict from collections import defaultdict
from parl import RemoteManager
from parl.env.atari_wrappers import wrap_deepmind from parl.env.atari_wrappers import wrap_deepmind
from parl.utils import logger, CSVLogger, get_gpu_count from parl.utils import logger, get_gpu_count, tensorboard
from parl.utils.scheduler import PiecewiseScheduler from parl.utils.scheduler import PiecewiseScheduler
from parl.utils.time_stat import TimeStat from parl.utils.time_stat import TimeStat
from parl.utils.window_stat import WindowStat from parl.utils.window_stat import WindowStat
from parl.utils import machine_info from parl.utils import machine_info
from actor import Actor
class Learner(object): class Learner(object):
def __init__(self, config): def __init__(self, config):
...@@ -78,20 +80,17 @@ class Learner(object): ...@@ -78,20 +80,17 @@ class Learner(object):
self.sample_total_steps = 0 self.sample_total_steps = 0
self.params_queues = [] self.params_queues = []
self.run_remote_manager() self.create_actors()
self.csv_logger = CSVLogger( def create_actors(self):
os.path.join(logger.get_dir(), 'result.csv')) """ Connect to the cluster and start sampling of the remote actor.
def run_remote_manager(self):
""" Accept connection of new remote actor and start sampling of the remote actor.
""" """
remote_manager = RemoteManager(port=self.config['server_port']) parl.connect(self.config['master_address'])
logger.info('Waiting for {} remote actors to connect.'.format( logger.info('Waiting for {} remote actors to connect.'.format(
self.config['actor_num'])) self.config['actor_num']))
for i in six.moves.range(self.config['actor_num']): for i in six.moves.range(self.config['actor_num']):
remote_actor = remote_manager.get_remote()
params_queue = queue.Queue() params_queue = queue.Queue()
self.params_queues.append(params_queue) self.params_queues.append(params_queue)
...@@ -99,17 +98,18 @@ class Learner(object): ...@@ -99,17 +98,18 @@ class Learner(object):
logger.info('Remote actor count: {}'.format(self.remote_count)) logger.info('Remote actor count: {}'.format(self.remote_count))
remote_thread = threading.Thread( remote_thread = threading.Thread(
target=self.run_remote_sample, target=self.run_remote_sample, args=(params_queue, ))
args=(remote_actor, params_queue))
remote_thread.setDaemon(True) remote_thread.setDaemon(True)
remote_thread.start() remote_thread.start()
logger.info('All remote actors are ready, begin to learn.') logger.info('All remote actors are ready, begin to learn.')
self.start_time = time.time() self.start_time = time.time()
def run_remote_sample(self, remote_actor, params_queue): def run_remote_sample(self, params_queue):
""" Sample data from remote actor and update parameters of remote actor. """ Sample data from remote actor and update parameters of remote actor.
""" """
remote_actor = Actor(self.config)
cnt = 0 cnt = 0
while True: while True:
latest_params = params_queue.get() latest_params = params_queue.get()
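Condensed for clarity: the per-actor thread started by `create_actors` now constructs its own remote `Actor` instead of waiting for one to register. The sketch below follows the diff above; the `sample_data_queue` consumer is an assumption, since that part of the learner is not shown in this hunk.

```python
def run_remote_sample(self, params_queue):
    # Sketch of the thread body: one remote actor per thread after parl.connect().
    remote_actor = Actor(self.config)          # submits a job to the cluster
    while True:
        latest_params = params_queue.get()     # weights pushed by the learner
        remote_actor.set_weights(latest_params)
        batch = remote_actor.sample()          # executed remotely inside the job
        self.sample_data_queue.put(batch)      # assumed learner-side consumer
```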
...@@ -128,7 +128,7 @@ class Learner(object): ...@@ -128,7 +128,7 @@ class Learner(object):
""" """
1. kick off all actors to synchronize parameters and sample data; 1. kick off all actors to synchronize parameters and sample data;
2. collect sample data of all actors; 2. collect sample data of all actors;
3. update parameters. 3. update parameters.
""" """
latest_params = self.agent.get_weights() latest_params = self.agent.get_weights()
...@@ -208,11 +208,11 @@ class Learner(object): ...@@ -208,11 +208,11 @@ class Learner(object):
'entropy_coeff': self.entropy_coeff, 'entropy_coeff': self.entropy_coeff,
} }
for key, value in metric.items():
if value is not None:
tensorboard.add_scalar(key, value, self.sample_total_steps)
logger.info(metric) logger.info(metric)
self.csv_logger.log_dict(metric)
def should_stop(self): def should_stop(self):
return self.sample_total_steps >= self.config['max_sample_steps'] return self.sample_total_steps >= self.config['max_sample_steps']
def close(self):
self.csv_logger.close()
#!/bin/bash
export CUDA_VISIBLE_DEVICES=""
for i in $(seq 1 5); do
python actor.py &
done;
wait
...@@ -14,22 +14,18 @@ ...@@ -14,22 +14,18 @@
import time import time
from learner import Learner from learner import Learner
import parl
def main(config): def main(config):
learner = Learner(config) learner = Learner(config)
assert config['log_metrics_interval_s'] > 0 assert config['log_metrics_interval_s'] > 0
try: while not learner.should_stop():
while not learner.should_stop(): start = time.time()
start = time.time() while time.time() - start < config['log_metrics_interval_s']:
while time.time() - start < config['log_metrics_interval_s']: learner.step()
learner.step() learner.log_metrics()
learner.log_metrics()
learner.close()
except KeyboardInterrupt:
learner.close()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -21,26 +21,28 @@ Results with one learner (in a P40 GPU) and 24 simulators (in 12 CPU) in 10 mill
+ gym
+ atari-py

### Distributed Training

First, we can start a local cluster with 24 CPUs:

```bash
xparl start --port 8010 --cpu_num 24
```

Note that if you have started a master before, you don't have to run the above
command. For more information about the cluster, please refer to our
[documentation](https://parl.readthedocs.io/en/latest/parallel_training/setup.html).

Then we can start the distributed training by running `train.py`:

```bash
python train.py
```

[Tips] Performance can degrade dramatically in a slower computational
environment, especially when training with low-speed CPUs. This may be caused by
the policy-lag problem.

### Reference
+ [Parl](https://parl.readthedocs.io/en/latest/parallel_training/setup.html)
+ [tensorpack](https://github.com/tensorpack/tensorpack)
...@@ -15,13 +15,12 @@ ...@@ -15,13 +15,12 @@
import gym import gym
import numpy as np import numpy as np
import parl import parl
import six
from parl.env.atari_wrappers import wrap_deepmind, MonitorEnv, get_wrapper_by_cls from parl.env.atari_wrappers import wrap_deepmind, MonitorEnv, get_wrapper_by_cls
from collections import defaultdict from collections import defaultdict
@parl.remote_class @parl.remote_class
class Simulator(object): class Actor(object):
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
...@@ -45,10 +44,3 @@ class Simulator(object): ...@@ -45,10 +44,3 @@ class Simulator(object):
metrics['episode_rewards'].append(episode_rewards) metrics['episode_rewards'].append(episode_rewards)
metrics['episode_steps'].append(episode_steps) metrics['episode_steps'].append(episode_steps)
return metrics return metrics
if __name__ == '__main__':
from ga3c_config import config
simulator = Simulator(config)
simulator.as_remote(config['server_ip'], config['server_port'])
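With `as_remote` gone, an actor file no longer needs a `__main__` entry point; the learner creates instances after `parl.connect`. A minimal, self-contained sketch of the new pattern (hypothetical `Counter` class; assumes a cluster started at `localhost:8010`):

```python
import parl

@parl.remote_class
class Counter(object):
    def __init__(self):
        self.value = 0

    def step(self):
        self.value += 1
        return self.value

parl.connect('localhost:8010')   # connect once per training task
counter = Counter()              # the instance now lives in a remote job
print(counter.step())            # remote call, prints 1
```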
File mode changed from 100644 to 100755
../A2C/atari_model.py
\ No newline at end of file
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
import paddle.fluid as fluid
from parl import layers
class AtariModel(parl.Model):
def __init__(self, act_dim):
self.conv1 = layers.conv2d(
num_filters=32, filter_size=8, stride=4, padding=1, act='relu')
self.conv2 = layers.conv2d(
num_filters=64, filter_size=4, stride=2, padding=2, act='relu')
self.conv3 = layers.conv2d(
num_filters=64, filter_size=3, stride=1, padding=0, act='relu')
self.fc = layers.fc(size=512, act='relu')
self.policy_fc = layers.fc(size=act_dim)
self.value_fc = layers.fc(size=1)
def policy(self, obs):
"""
Args:
obs: A float32 tensor of shape [B, C, H, W]
Returns:
policy_logits: B * ACT_DIM
"""
obs = obs / 255.0
conv1 = self.conv1(obs)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
flatten = layers.flatten(conv3, axis=1)
fc_output = self.fc(flatten)
policy_logits = self.policy_fc(fc_output)
return policy_logits
def value(self, obs):
"""
Args:
obs: A float32 tensor of shape [B, C, H, W]
Returns:
values: B
"""
obs = obs / 255.0
conv1 = self.conv1(obs)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
flatten = layers.flatten(conv3, axis=1)
fc_output = self.fc(flatten)
values = self.value_fc(fc_output)
values = layers.squeeze(values, axes=[1])
return values
def policy_and_value(self, obs):
"""
Args:
obs: A float32 tensor of shape [B, C, H, W]
Returns:
policy_logits: B * ACT_DIM
values: B
"""
obs = obs / 255.0
conv1 = self.conv1(obs)
conv2 = self.conv2(conv1)
conv3 = self.conv3(conv2)
flatten = layers.flatten(conv3, axis=1)
fc_output = self.fc(flatten)
policy_logits = self.policy_fc(fc_output)
values = self.value_fc(fc_output)
values = layers.squeeze(values, axes=[1])
return policy_logits, values
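The model above targets the static-graph fluid API; a minimal usage sketch is shown below. The input shape is an assumption (four stacked 42x42 frames in NCHW layout, matching `env_dim: 42` in the configs), and `act_dim=6` is Pong's action count.

```python
import paddle.fluid as fluid

model = AtariModel(act_dim=6)
main_program, startup_program = fluid.Program(), fluid.Program()
with fluid.program_guard(main_program, startup_program):
    # Batch dimension is implicit in fluid.layers.data.
    obs = fluid.layers.data(name='obs', shape=[4, 42, 42], dtype='float32')
    policy_logits, values = model.policy_and_value(obs)
```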
...@@ -14,14 +14,14 @@ ...@@ -14,14 +14,14 @@
config = { config = {
#========== remote config ========== #========== remote config ==========
'server_ip': 'localhost', 'master_address': 'localhost:8010',
'server_port': 8037,
#========== env config ========== #========== env config ==========
'env_name': 'PongNoFrameskip-v4', 'env_name': 'PongNoFrameskip-v4',
'env_dim': 42, 'env_dim': 42,
#========== learner config ========== #========== learner config ==========
'actor_num': 24,
'train_batch_size': 128, 'train_batch_size': 128,
'max_predict_batch_size': 16, 'max_predict_batch_size': 16,
'predict_thread_num': 2, 'predict_thread_num': 2,
......
...@@ -23,15 +23,16 @@ import parl ...@@ -23,15 +23,16 @@ import parl
from atari_model import AtariModel from atari_model import AtariModel
from atari_agent import AtariAgent from atari_agent import AtariAgent
from collections import defaultdict from collections import defaultdict
from parl import RemoteManager
from parl.env.atari_wrappers import wrap_deepmind from parl.env.atari_wrappers import wrap_deepmind
from parl.utils import logger, CSVLogger, get_gpu_count from parl.utils import logger, get_gpu_count, tensorboard
from parl.utils.scheduler import PiecewiseScheduler from parl.utils.scheduler import PiecewiseScheduler
from parl.utils.time_stat import TimeStat from parl.utils.time_stat import TimeStat
from parl.utils.window_stat import WindowStat from parl.utils.window_stat import WindowStat
from parl.utils.rl_utils import calc_gae from parl.utils.rl_utils import calc_gae
from parl.utils import machine_info from parl.utils import machine_info
from actor import Actor
class Learner(object): class Learner(object):
def __init__(self, config): def __init__(self, config):
...@@ -104,13 +105,10 @@ class Learner(object): ...@@ -104,13 +105,10 @@ class Learner(object):
self.sample_total_steps = 0 self.sample_total_steps = 0
self.remote_manager_thread = threading.Thread( self.remote_manager_thread = threading.Thread(
target=self.run_remote_manager) target=self.create_actors)
self.remote_manager_thread.setDaemon(True) self.remote_manager_thread.setDaemon(True)
self.remote_manager_thread.start() self.remote_manager_thread.start()
self.csv_logger = CSVLogger(
os.path.join(logger.get_dir(), 'result.csv'))
def learn_data_provider(self): def learn_data_provider(self):
""" Data generator for fluid.layers.py_reader """ Data generator for fluid.layers.py_reader
""" """
...@@ -181,17 +179,18 @@ class Learner(object): ...@@ -181,17 +179,18 @@ class Learner(object):
self.vf_loss_stat.add(vf_loss) self.vf_loss_stat.add(vf_loss)
self.entropy_stat.add(entropy) self.entropy_stat.add(entropy)
def run_remote_manager(self): def create_actors(self):
""" Accept connection of new remote simulator and start simulation. """ Connect to the cluster and start sampling of the remote actor.
""" """
remote_manager = RemoteManager(port=self.config['server_port']) parl.connect(self.config['master_address'])
logger.info("Waiting for the remote simulator's connection.")
logger.info('Waiting for {} remote actors to connect.'.format(
self.config['actor_num']))
ident = 0 ident = 0
self.predict_output_queues = [] self.predict_output_queues = []
while True: for i in six.moves.range(self.config['actor_num']):
remote_simulator = remote_manager.get_remote()
self.remote_count += 1 self.remote_count += 1
logger.info('Remote simulator count: {}'.format(self.remote_count)) logger.info('Remote simulator count: {}'.format(self.remote_count))
...@@ -202,27 +201,23 @@ class Learner(object): ...@@ -202,27 +201,23 @@ class Learner(object):
self.predict_output_queues.append(q) self.predict_output_queues.append(q)
remote_thread = threading.Thread( remote_thread = threading.Thread(
target=self.run_remote_sample, target=self.run_remote_sample, args=(ident, ))
args=(
remote_simulator,
ident,
))
remote_thread.setDaemon(True) remote_thread.setDaemon(True)
remote_thread.start() remote_thread.start()
ident += 1 ident += 1
def run_remote_sample(self, remote_simulator, ident): def run_remote_sample(self, ident):
""" Interacts with remote simulator. """ Interacts with remote simulator.
""" """
remote_actor = Actor(self.config)
mem = defaultdict(list) mem = defaultdict(list)
obs = remote_simulator.reset() obs = remote_actor.reset()
while True: while True:
self.predict_input_queue.put((ident, obs)) self.predict_input_queue.put((ident, obs))
action, value = self.predict_output_queues[ident].get() action, value = self.predict_output_queues[ident].get()
next_obs, reward, done = remote_simulator.step(action) next_obs, reward, done = remote_actor.step(action)
mem['obs'].append(obs) mem['obs'].append(obs)
mem['actions'].append(action) mem['actions'].append(action)
...@@ -245,7 +240,7 @@ class Learner(object): ...@@ -245,7 +240,7 @@ class Learner(object):
mem = defaultdict(list) mem = defaultdict(list)
next_obs = remote_simulator.reset() next_obs = remote_actor.reset()
elif len(mem['obs']) == self.config['t_max'] + 1: elif len(mem['obs']) == self.config['t_max'] + 1:
next_value = mem['values'][-1] next_value = mem['values'][-1]
...@@ -267,7 +262,7 @@ class Learner(object): ...@@ -267,7 +262,7 @@ class Learner(object):
obs = next_obs obs = next_obs
if done: if done:
metrics = remote_simulator.get_metrics() metrics = remote_actor.get_metrics()
if metrics: if metrics:
self.remote_metrics_queue.put(metrics) self.remote_metrics_queue.put(metrics)
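To summarize the control flow above: each sampling thread now owns one remote actor and talks to the shared prediction threads through the queues shown in the diff. A condensed sketch, with the GA3C memory and GAE bookkeeping omitted:

```python
def run_remote_sample(self, ident):
    # Sketch: one remote environment per thread; predictions are batched centrally.
    remote_actor = Actor(self.config)
    obs = remote_actor.reset()
    while True:
        self.predict_input_queue.put((ident, obs))            # request an action
        action, value = self.predict_output_queues[ident].get()
        obs, reward, done = remote_actor.step(action)
        if done:
            metrics = remote_actor.get_metrics()
            if metrics:
                self.remote_metrics_queue.put(metrics)
            obs = remote_actor.reset()
```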
...@@ -319,8 +314,8 @@ class Learner(object): ...@@ -319,8 +314,8 @@ class Learner(object):
'entropy_coeff': self.entropy_coeff, 'entropy_coeff': self.entropy_coeff,
} }
logger.info(metric) for key, value in metric.items():
self.csv_logger.log_dict(metric) if value is not None:
tensorboard.add_scalar(key, value, self.sample_total_steps)
def close(self): logger.info(metric)
self.csv_logger.close()
#!/bin/bash
export CUDA_VISIBLE_DEVICES=""
for i in $(seq 1 24); do
python simulator.py &
done;
wait
...@@ -20,14 +20,10 @@ def main(config): ...@@ -20,14 +20,10 @@ def main(config):
learner = Learner(config) learner = Learner(config)
assert config['log_metrics_interval_s'] > 0 assert config['log_metrics_interval_s'] > 0
try: while True:
while True: time.sleep(config['log_metrics_interval_s'])
time.sleep(config['log_metrics_interval_s'])
learner.log_metrics() learner.log_metrics()
except KeyboardInterrupt:
learner.close()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -28,22 +28,23 @@ Result with one learner (in a P40 GPU) and 32 actors (in 32 CPUs).
### Distributed Training:

First, we can start a local cluster with 32 CPUs:

```bash
xparl start --port 8010 --cpu_num 32
```

Note that if you have started a master before, you don't have to run the above
command. For more information about the cluster, please refer to our
[documentation](https://parl.readthedocs.io/en/latest/parallel_training/setup.html).

Then we can start the distributed training by running `train.py`:

```bash
python train.py
```

### Reference
+ [Parl Cluster Setup](https://parl.readthedocs.io/en/latest/parallel_training/setup.html)
+ [deepmind/scalable_agent](https://github.com/deepmind/scalable_agent)
+ [Ray](https://github.com/ray-project/ray)
...@@ -30,7 +30,7 @@ class Actor(object): ...@@ -30,7 +30,7 @@ class Actor(object):
self.config = config self.config = config
self.envs = [] self.envs = []
for _ in six.moves.range(config['env_num']): for _ in range(config['env_num']):
env = gym.make(config['env_name']) env = gym.make(config['env_name'])
env = wrap_deepmind(env, dim=config['env_dim'], obs_format='NCHW') env = wrap_deepmind(env, dim=config['env_dim'], obs_format='NCHW')
self.envs.append(env) self.envs.append(env)
...@@ -53,16 +53,16 @@ class Actor(object): ...@@ -53,16 +53,16 @@ class Actor(object):
def sample(self): def sample(self):
env_sample_data = {} env_sample_data = {}
for env_id in six.moves.range(self.config['env_num']): for env_id in range(self.config['env_num']):
env_sample_data[env_id] = defaultdict(list) env_sample_data[env_id] = defaultdict(list)
for i in six.moves.range(self.config['sample_batch_steps']): for i in range(self.config['sample_batch_steps']):
actions, behaviour_logits = self.agent.sample( actions, behaviour_logits = self.agent.sample(
np.stack(self.obs_batch)) np.stack(self.obs_batch))
next_obs_batch, reward_batch, done_batch, info_batch = \ next_obs_batch, reward_batch, done_batch, info_batch = \
self.vector_env.step(actions) self.vector_env.step(actions)
for env_id in six.moves.range(self.config['env_num']): for env_id in range(self.config['env_num']):
env_sample_data[env_id]['obs'].append(self.obs_batch[env_id]) env_sample_data[env_id]['obs'].append(self.obs_batch[env_id])
env_sample_data[env_id]['actions'].append(actions[env_id]) env_sample_data[env_id]['actions'].append(actions[env_id])
env_sample_data[env_id]['behaviour_logits'].append( env_sample_data[env_id]['behaviour_logits'].append(
...@@ -74,7 +74,7 @@ class Actor(object): ...@@ -74,7 +74,7 @@ class Actor(object):
# Merge data of envs # Merge data of envs
sample_data = defaultdict(list) sample_data = defaultdict(list)
for env_id in six.moves.range(self.config['env_num']): for env_id in range(self.config['env_num']):
for data_name in [ for data_name in [
'obs', 'actions', 'behaviour_logits', 'rewards', 'dones' 'obs', 'actions', 'behaviour_logits', 'rewards', 'dones'
]: ]:
...@@ -100,10 +100,3 @@ class Actor(object): ...@@ -100,10 +100,3 @@ class Actor(object):
def set_weights(self, weights): def set_weights(self, weights):
self.agent.set_weights(weights) self.agent.set_weights(weights)
if __name__ == '__main__':
from impala_config import config
actor = Actor(config)
actor.as_remote(config['server_ip'], config['server_port'])
...@@ -76,7 +76,8 @@ class AtariAgent(parl.Agent): ...@@ -76,7 +76,8 @@ class AtariAgent(parl.Agent):
vtrace_loss.total_loss, vtrace_loss.pi_loss, vtrace_loss.total_loss, vtrace_loss.pi_loss,
vtrace_loss.vf_loss, vtrace_loss.entropy, kl vtrace_loss.vf_loss, vtrace_loss.entropy, kl
] ]
self.learn_program = parl.compile(self.learn_program, total_loss) self.learn_program = parl.compile(self.learn_program,
vtrace_loss.total_loss)
def sample(self, obs_np): def sample(self, obs_np):
""" """
......
File mode changed from 100644 to 100755
...@@ -16,14 +16,14 @@ config = { ...@@ -16,14 +16,14 @@ config = {
'experiment_name': 'Pong', 'experiment_name': 'Pong',
#========== remote config ========== #========== remote config ==========
'server_ip': 'localhost', 'master_address': 'localhost:8010',
'server_port': 8037,
#========== env config ========== #========== env config ==========
'env_name': 'PongNoFrameskip-v4', 'env_name': 'PongNoFrameskip-v4',
'env_dim': 42, 'env_dim': 42,
#========== actor config ========== #========== actor config ==========
'actor_num': 32,
'env_num': 5, 'env_num': 5,
'sample_batch_steps': 50, 'sample_batch_steps': 50,
......
...@@ -21,13 +21,14 @@ import threading ...@@ -21,13 +21,14 @@ import threading
import parl import parl
from atari_model import AtariModel from atari_model import AtariModel
from atari_agent import AtariAgent from atari_agent import AtariAgent
from parl import RemoteManager
from parl.env.atari_wrappers import wrap_deepmind from parl.env.atari_wrappers import wrap_deepmind
from parl.utils import logger, CSVLogger from parl.utils import logger, tensorboard
from parl.utils.scheduler import PiecewiseScheduler from parl.utils.scheduler import PiecewiseScheduler
from parl.utils.time_stat import TimeStat from parl.utils.time_stat import TimeStat
from parl.utils.window_stat import WindowStat from parl.utils.window_stat import WindowStat
from actor import Actor
class Learner(object): class Learner(object):
def __init__(self, config): def __init__(self, config):
...@@ -85,13 +86,10 @@ class Learner(object): ...@@ -85,13 +86,10 @@ class Learner(object):
self.sample_total_steps = 0 self.sample_total_steps = 0
self.remote_manager_thread = threading.Thread( self.remote_manager_thread = threading.Thread(
target=self.run_remote_manager) target=self.create_actors)
self.remote_manager_thread.setDaemon(True) self.remote_manager_thread.setDaemon(True)
self.remote_manager_thread.start() self.remote_manager_thread.start()
self.csv_logger = CSVLogger(
os.path.join(logger.get_dir(), 'result.csv'))
def learn_data_provider(self): def learn_data_provider(self):
""" Data generator for fluid.layers.py_reader """ Data generator for fluid.layers.py_reader
""" """
...@@ -139,26 +137,29 @@ class Learner(object): ...@@ -139,26 +137,29 @@ class Learner(object):
self.entropy_stat.add(entropy) self.entropy_stat.add(entropy)
self.kl_stat.add(kl) self.kl_stat.add(kl)
def run_remote_manager(self): def create_actors(self):
""" Accept connection of new remote actor and start sampling of the remote actor. """ Connect to the cluster and start sampling of the remote actor.
""" """
remote_manager = RemoteManager(port=self.config['server_port']) parl.connect(self.config['master_address'])
logger.info('Waiting for remote actors connecting.')
while True: logger.info('Waiting for {} remote actors to connect.'.format(
remote_actor = remote_manager.get_remote() self.config['actor_num']))
for i in range(self.config['actor_num']):
self.remote_count += 1 self.remote_count += 1
logger.info('Remote actor count: {}'.format(self.remote_count)) logger.info('Remote actor count: {}'.format(self.remote_count))
if self.start_time is None: if self.start_time is None:
self.start_time = time.time() self.start_time = time.time()
remote_thread = threading.Thread( remote_thread = threading.Thread(target=self.run_remote_sample)
target=self.run_remote_sample, args=(remote_actor, ))
remote_thread.setDaemon(True) remote_thread.setDaemon(True)
remote_thread.start() remote_thread.start()
def run_remote_sample(self, remote_actor): def run_remote_sample(self):
""" Sample data from remote actor and update parameters of remote actor. """ Sample data from remote actor and update parameters of remote actor.
""" """
remote_actor = Actor(self.config)
cnt = 0 cnt = 0
remote_actor.set_weights(self.cache_params) remote_actor.set_weights(self.cache_params)
while True: while True:
...@@ -237,8 +238,8 @@ class Learner(object): ...@@ -237,8 +238,8 @@ class Learner(object):
'entropy_coeff': self.entropy_coeff, 'entropy_coeff': self.entropy_coeff,
} }
logger.info(metric) for key, value in metric.items():
self.csv_logger.log_dict(metric) if value is not None:
tensorboard.add_scalar(key, value, self.sample_total_steps)
def close(self): logger.info(metric)
self.csv_logger.close()
#!/bin/bash
export CUDA_VISIBLE_DEVICES=""
for i in $(seq 1 32); do
python actor.py &
done;
wait
...@@ -20,14 +20,10 @@ def main(config): ...@@ -20,14 +20,10 @@ def main(config):
learner = Learner(config) learner = Learner(config)
assert config['log_metrics_interval_s'] > 0 assert config['log_metrics_interval_s'] > 0
try: while True:
while True: time.sleep(config['log_metrics_interval_s'])
time.sleep(config['log_metrics_interval_s'])
learner.log_metrics() learner.log_metrics()
except KeyboardInterrupt:
learner.close()
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -16,19 +16,13 @@ __version__ = "1.1.1" ...@@ -16,19 +16,13 @@ __version__ = "1.1.1"
""" """
generates new PARL python API generates new PARL python API
""" """
import os
# trick to solve importing error
from tensorboardX import SummaryWriter from tensorboardX import SummaryWriter
from parl.utils.utils import _HAS_FLUID from parl.utils.utils import _HAS_FLUID
if _HAS_FLUID: if _HAS_FLUID:
from parl.core.fluid import * from parl.core.fluid import *
from parl.core.fluid.plutils.compiler import compile from parl.core.fluid.plutils.compiler import compile
else:
print(
"WARNING:PARL: Failed to import paddle. Only APIs for parallelization are available."
)
from parl.remote import remote_class, RemoteManager from parl.remote import remote_class, connect
from parl import algorithms from parl import algorithms
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -12,7 +12,8 @@ ...@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from parl.remote.master import *
from parl.remote.worker import *
from parl.remote.client import *
from parl.remote.exceptions import * from parl.remote.exceptions import *
from parl.remote.remote_decorator import * from parl.remote.remote_decorator import *
from parl.remote.remote_manager import *
from parl.remote.remote_object import *
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cloudpickle
import os
import threading
import zmq
from parl.utils import to_str, to_byte, get_ip_address, logger
from parl.remote import remote_constants
class Client(object):
"""Base class for the remote client.
For each training task, there is a global client in the cluster which
submits jobs to the master node. Different `@parl.remote_class` objects
connect to the same global client in a training task.
Attributes:
submit_job_socket (zmq.Context.socket): A socket which submits job to
the master node.
pyfiles (bytes): A serialized dictionary containing the code of python
files in local working directory.
"""
def __init__(self, master_address):
"""
Args:
master_address (str): address of the master node, in ip:port format.
"""
self.ctx = zmq.Context()
self.lock = threading.Lock()
self.heartbeat_socket_initialized = threading.Event()
self.master_is_alive = True
self.client_is_alive = True
self._create_sockets(master_address)
self.pyfiles = self.read_local_files()
def read_local_files(self):
"""Read local python code and store them in a dictionary, which will
then be sent to the job.
Returns:
A cloudpickled dictionary containing the python code in current
working directory.
"""
pyfiles = dict()
for file in os.listdir('./'):
if file.endswith('.py'):
with open(file, 'rb') as code_file:
code = code_file.read()
pyfiles[file] = code
return cloudpickle.dumps(pyfiles)
def _create_sockets(self, master_address):
""" Each client starts with one socket:
(1) submit_job_socket: submits jobs to the master node.
"""
# submit_job_socket: submits job to master
self.submit_job_socket = self.ctx.socket(zmq.REQ)
self.submit_job_socket.linger = 0
self.submit_job_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
self.submit_job_socket.connect("tcp://{}".format(master_address))
thread = threading.Thread(target=self._reply_heartbeat, daemon=True)
thread.start()
self.heartbeat_socket_initialized.wait()
# check if the master is connected properly
try:
self.submit_job_socket.send_multipart([
remote_constants.CLIENT_CONNECT_TAG,
to_byte(self.heartbeat_master_address)
])
_ = self.submit_job_socket.recv_multipart()
except zmq.error.Again as e:
logger.warning("[Client] Can not connect to the master, please "
"check if master is started and ensure the input "
"address {} is correct.".format(master_address))
self.master_is_alive = False
raise Exception("Client can not connect to the master, please "
"check if master is started and ensure the input "
"address {} is correct.".format(master_address))
def _reply_heartbeat(self):
"""Reply heartbeat signals to the Master node."""
socket = self.ctx.socket(zmq.REP)
socket.linger = 0
socket.setsockopt(zmq.RCVTIMEO,
remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
heartbeat_master_port =\
socket.bind_to_random_port(addr="tcp://*")
self.heartbeat_master_address = "{}:{}".format(get_ip_address(),
heartbeat_master_port)
self.heartbeat_socket_initialized.set()
while self.client_is_alive and self.master_is_alive:
try:
message = socket.recv_multipart()
socket.send_multipart([remote_constants.HEARTBEAT_TAG])
except zmq.error.Again as e:
logger.warning("[Client] Cannot connect to the master."
"Please check if it is still alive.")
self.master_is_alive = False
socket.close(0)
logger.warning("Client exit replying heartbeat for master.")
def submit_job(self):
"""Send a job to the Master node.
When a `@parl.remote_class` object is created, the global client
sends a job to the master node. Then the master node will allocate
a vacant job from its job pool to the remote object.
Returns:
IP address of the job.
"""
if self.master_is_alive:
# A lock to prevent multiple actors from submitting jobs at the same time.
self.lock.acquire()
self.submit_job_socket.send_multipart([
remote_constants.CLIENT_SUBMIT_TAG,
to_byte(self.heartbeat_master_address)
])
message = self.submit_job_socket.recv_multipart()
self.lock.release()
tag = message[0]
if tag == remote_constants.NORMAL_TAG:
job_address = to_str(message[1])
# no vacant CPU resources, can not submit a new job
elif tag == remote_constants.CPU_TAG:
job_address = None
else:
raise NotImplementedError
else:
raise Exception("Client can not submit job to the master, "
"please check if master is connected.")
return job_address
GLOBAL_CLIENT = None
def connect(master_address):
"""Create a global client which connects to the master node.
.. code-block:: python
parl.connect(master_address='localhost:1234')
Args:
master_address (str): The address of the Master node to connect to.
Raises:
Exception: An exception is raised if the master node is not started.
"""
assert len(master_address.split(":")) == 2, "please input address in " +\
"{ip}:{port} format"
global GLOBAL_CLIENT
if GLOBAL_CLIENT is None:
GLOBAL_CLIENT = Client(master_address)
def get_global_client():
"""Get the global client.
Returns:
The global client.
"""
global GLOBAL_CLIENT
assert GLOBAL_CLIENT is not None, "Cannot get the client to submit the" +\
" job, have you connected to the cluster by calling " +\
"parl.connect(master_address)?"
return GLOBAL_CLIENT
def disconnect():
"""Disconnect the global client from the master node."""
global GLOBAL_CLIENT
GLOBAL_CLIENT.client_is_alive = False
GLOBAL_CLIENT = None
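Putting the client pieces together: `parl.connect` builds the module-level `GLOBAL_CLIENT`, and every `@parl.remote_class` instantiation asks it for a job address through `submit_job`. The sketch below calls `submit_job` directly only for illustration; in normal use it stays internal to the remote wrapper.

```python
import parl
from parl.remote.client import get_global_client, disconnect

parl.connect('localhost:8010')      # creates the global Client
client = get_global_client()
job_address = client.submit_job()   # normally invoked by @parl.remote_class objects
print(job_address)                  # None means the cluster has no vacant CPU
disconnect()
```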
...@@ -13,14 +13,26 @@ ...@@ -13,14 +13,26 @@
# limitations under the License. # limitations under the License.
class ResourceError(Exception):
"""
No available cpu resources error.
"""
def __init__(self, error_info):
self.error_info = error_info
def __str__(self):
return self.error_info
class RemoteError(Exception): class RemoteError(Exception):
""" """
Super class of exceptions in remote module. Super class of exceptions in remote module.
""" """
def __init__(self, func_name, error_info): def __init__(self, func_name, error_info):
self.error_info = "[PARL remote error when calling function `{}`]:\n{}".format( self.error_info = "[PARL remote error when calling " +\
func_name, error_info) "function `{}`]:\n{}".format(func_name, error_info)
def __str__(self): def __str__(self):
return self.error_info return self.error_info
...@@ -52,7 +64,7 @@ class RemoteDeserializeError(RemoteError): ...@@ -52,7 +64,7 @@ class RemoteDeserializeError(RemoteError):
class RemoteAttributeError(RemoteError): class RemoteAttributeError(RemoteError):
""" """
Attribute error from remote Attribute error from remote
""" """
def __init__(self, func_name, error_info): def __init__(self, func_name, error_info):
......
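These exceptions surface on the caller's side when a remote call fails; the raising logic lives in the remote object wrapper, which this diff only shows in part, so treat the snippet below as an illustrative sketch.

```python
from parl.remote.exceptions import RemoteError, RemoteAttributeError

def safe_sample(remote_actor):
    """remote_actor: an instance of a @parl.remote_class class."""
    try:
        return remote_actor.sample()
    except RemoteAttributeError:
        raise                       # the remote instance lacks that attribute
    except RemoteError as err:
        print(err)                  # error_info carries the remote error text
        return None
```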
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import cloudpickle
import pickle
import sys
import tempfile
import threading
import time
import traceback
import zmq
from parl.utils import to_str, to_byte, get_ip_address, logger
from parl.utils.communication import loads_argument, loads_return,\
dumps_argument, dumps_return
from parl.remote import remote_constants
from parl.utils.exceptions import SerializeError, DeserializeError
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
class Job(object):
"""Base class for the job.
After establishing connection with the remote object, the job will
create a remote class instance locally and enter an infinite loop,
waiting for commands from the remote object.
"""
def __init__(self, worker_address):
self.job_is_alive = True
self.heartbeat_socket_initialized = threading.Event()
self.worker_address = worker_address
self._create_sockets()
def _create_sockets(self):
"""Create two sockets for each job.
(1) reply_socket: receives the command(i.e, the function name and
args) from the actual class instance, and returns the result of
the function.
(2) job_socket: sends job_address and heartbeat_address to worker.
"""
self.ctx = zmq.Context()
# reply_socket: receives class, parameters and call function from
# @remote.class and send computed results to the @remote.class.
self.reply_socket = self.ctx.socket(zmq.REP)
self.reply_socket.linger = 0
job_port = self.reply_socket.bind_to_random_port(addr="tcp://*")
self.job_ip = get_ip_address()
self.job_address = "{}:{}".format(self.job_ip, job_port)
reply_thread = threading.Thread(
target=self._reply_heartbeat,
args=("worker {}".format(self.worker_address), ),
daemon=True)
reply_thread.start()
self.heartbeat_socket_initialized.wait()
# job_socket: sends job_address and heartbeat_address to worker
self.job_socket = self.ctx.socket(zmq.REQ)
self.job_socket.connect("tcp://{}".format(self.worker_address))
self.job_socket.send_multipart([
remote_constants.NORMAL_TAG,
to_byte(self.job_address),
to_byte(self.heartbeat_worker_address)
])
_ = self.job_socket.recv_multipart()
def _reply_heartbeat(self, target):
"""reply heartbeat signals to the target"""
socket = self.ctx.socket(zmq.REP)
socket.setsockopt(zmq.RCVTIMEO,
remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
socket.linger = 0
heartbeat_worker_port = socket.bind_to_random_port(addr="tcp://*")
self.heartbeat_worker_address = "{}:{}".format(self.job_ip,
heartbeat_worker_port)
self.heartbeat_socket_initialized.set()
# a flag to decide when to exit heartbeat loop
self.worker_is_alive = True
while self.worker_is_alive and self.job_is_alive:
try:
message = socket.recv_multipart()
socket.send_multipart([remote_constants.HEARTBEAT_TAG])
except zmq.error.Again as e:
logger.warning("[Job] Cannot connect to {}. ".format(target) +
"Job will quit.")
self.worker_is_alive = False
self.job_is_alive = False
def wait_for_files(self):
"""Wait for python files from remote object.
When a remote object receives the allocated job address, it will send
the python files to the job. Later, the job will save these files to a
temporary directory and add the temporary directory to Python's working
directory.
Returns:
A temporary directory containing the python files.
"""
while True:
message = self.reply_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.SEND_FILE_TAG:
pyfiles = pickle.loads(message[1])
envdir = tempfile.mkdtemp()
for file in pyfiles:
code = pyfiles[file]
file = os.path.join(envdir, file)
with open(file, 'wb') as code_file:
code_file.write(code)
self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
return envdir
else:
logger.warning(message)
raise NotImplementedError
def wait_for_connection(self):
"""Wait for connection from the remote object.
The remote object will send its class information and initialization
arguments to the job; these parameters are then used to create a
local instance in the job process.
Returns:
A local instance of the remote class object.
"""
while True:
message = self.reply_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.INIT_OBJECT_TAG:
cls = cloudpickle.loads(message[1])
args, kwargs = cloudpickle.loads(message[2])
obj = cls(*args, **kwargs)
self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
return obj
else:
logger.error("Message from job {}".format(message))
raise NotImplementedError
def run(self):
"""An infinite loop waiting for commands from the remote object.
Each job will receive two kinds of message from the remote object:
1. When the remote object calls a function, job will run the
function on the local instance and return the results to the
remote object.
2. When the remote object is deleted, the job will quit and release
related computation resources.
"""
# receive files
envdir = self.wait_for_files()
sys.path.append(envdir)
obj = self.wait_for_connection()
while self.job_is_alive:
message = self.reply_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.CALL_TAG:
assert obj is not None
try:
function_name = to_str(message[1])
data = message[2]
args, kwargs = loads_argument(data)
ret = getattr(obj, function_name)(*args, **kwargs)
ret = dumps_return(ret)
self.reply_socket.send_multipart(
[remote_constants.NORMAL_TAG, ret])
except Exception as e:
error_str = str(e)
logger.error(error_str)
self.job_is_alive = False
if type(e) == AttributeError:
self.reply_socket.send_multipart([
remote_constants.ATTRIBUTE_EXCEPTION_TAG,
to_byte(error_str)
])
raise AttributeError
elif type(e) == SerializeError:
self.reply_socket.send_multipart([
remote_constants.SERIALIZE_EXCEPTION_TAG,
to_byte(error_str)
])
raise SerializeError
elif type(e) == DeserializeError:
self.reply_socket.send_multipart([
remote_constants.DESERIALIZE_EXCEPTION_TAG,
to_byte(error_str)
])
else:
traceback_str = str(traceback.format_exc())
logger.error("traceback:\n{}".format(traceback_str))
self.reply_socket.send_multipart([
remote_constants.EXCEPTION_TAG,
to_byte(error_str + "\ntraceback:\n" +
traceback_str)
])
# receive KILLJOB_TAG from the actor, and stop replying to the worker heartbeat
elif tag == remote_constants.KILLJOB_TAG:
self.reply_socket.send_multipart([remote_constants.NORMAL_TAG])
self.job_is_alive = False
logger.warning("An actor exits and will quit job {}.".format(
self.job_address))
else:
logger.error("Job message: {}".format(message))
raise NotImplementedError
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--worker_address", required=True, type=str, help="worker_address")
args = parser.parse_args()
job = Job(args.worker_address)
job.run()
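Jobs are not started by hand: a worker process (worker.py is not part of this diff) launches them and passes its own socket address on the command line, roughly as below; the address shown is a placeholder.

```python
import subprocess

# Matches the argparse interface above; '172.18.0.5:1234' stands in for the
# worker's reply-socket address.
proc = subprocess.Popen(
    ['python', 'job.py', '--worker_address', '172.18.0.5:1234'])
```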
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pickle
import threading
import time
import zmq
from collections import defaultdict
from parl.utils import to_str, to_byte, logger
from parl.remote import remote_constants
class Master(object):
"""Base class for a master node, the control center for our cluster, which provides connections to workers and clients.
There is only one master node in each cluster, and it is responsible for
receiving jobs from the clients and allocating computation resources to
run the jobs.
To start a master node, we use the following xparl command line api:
.. code-block:: bash
xparl start --port 1234
At the same time, a local worker will be started and connect to the
master node.
Attributes:
worker_pool (dict): A dict to store connected workers.
job_pool (list): A list of job addresses backed by vacant CPUs; when it
is empty, the master will refuse to create new remote objects.
client_job_dict (dict): A dict of lists recording the jobs submitted by
each client.
job_worker_dict (dict): A dict to record the job and related worker.
client_socket (zmq.Context.socket): A socket which receives submitted
job from the client, and later sends
job_address back to the client.
worker_socket (zmq.Context.socket): A socket which receives job
addresses from the worker node.
Args:
port: the ip port that the master node binds to.
"""
def __init__(self, port):
logger.set_dir(os.path.expanduser('~/.parl_data/master/'))
self.lock = threading.Lock()
self.ctx = zmq.Context()
self.client_socket = self.ctx.socket(zmq.REP)
self.client_socket.bind("tcp://*:{}".format(port))
self.client_socket.linger = 0
self.port = port
self.worker_pool = {}
self.worker_locks = {}
self.job_pool = []
self.client_job_dict = defaultdict(list)
self.worker_job_dict = defaultdict(list)
self.job_worker_dict = {}
self.master_is_alive = True
def _create_worker_monitor(self, worker_heartbeat_address, worker_address):
"""When a new worker connects to the master, a socket is created to
send heartbeat signals to the worker.
"""
worker_heartbeat_socket = self.ctx.socket(zmq.REQ)
worker_heartbeat_socket.linger = 0
worker_heartbeat_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
worker_heartbeat_socket.connect("tcp://" + worker_heartbeat_address)
connected = True
while connected and self.master_is_alive:
try:
worker_heartbeat_socket.send_multipart(
[remote_constants.HEARTBEAT_TAG])
_ = worker_heartbeat_socket.recv_multipart()
time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
except zmq.error.Again as e:
for job in self.worker_job_dict[worker_address]:
if job in self.job_pool:
self.job_pool.remove(job)
self.job_worker_dict.pop(job)
self.worker_job_dict.pop(worker_address)
self.worker_pool.pop(worker_address)
logger.warning("\n[Master] Cannot connect to the worker " +
"{}. ".format(worker_address) +
"Worker_pool will drop this worker.")
self._print_workers()
connected = False
except zmq.error.ZMQError as e:
break
worker_heartbeat_socket.close(0)
logger.warning("Exit worker monitor from master.")
def _create_client_monitor(self, client_heartbeat_address):
"""when a new client connects to the master, a socket is created to
send heartbeat signals to the client.
"""
client_heartbeat_socket = self.ctx.socket(zmq.REQ)
client_heartbeat_socket.linger = 0
client_heartbeat_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
client_heartbeat_socket.connect("tcp://" + client_heartbeat_address)
self.client_is_alive = True
while self.client_is_alive and self.master_is_alive:
try:
client_heartbeat_socket.send_multipart(
[remote_constants.HEARTBEAT_TAG])
_ = client_heartbeat_socket.recv_multipart()
except zmq.error.Again as e:
self.client_is_alive = False
logger.warning("[Master] cannot connect to the client " +
"{}. ".format(client_heartbeat_address) +
"Please check if it is still alive.")
self._kill_client_jobs(client_heartbeat_address)
time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
logger.warning("Master exits client monitor for {}.\n".format(
client_heartbeat_address))
logger.info(
"Master connects to {} workers and has {} vacant CPUs.\n".format(
len(self.worker_pool), len(self.job_pool)))
client_heartbeat_socket.close(0)
def _kill_client_jobs(self, client_address):
"""Kill all jobs submitted by this client; the receive timeout guards
against the worker and the client quitting at the same time.
"""
jobs = self.client_job_dict[client_address]
for job_address in jobs:
if job_address in self.job_worker_dict:
worker_address = self.job_worker_dict[job_address]
worker_socket = self.worker_pool[worker_address].worker_socket
self.worker_locks[worker_address].acquire()
worker_socket.send_multipart(
[remote_constants.KILLJOB_TAG,
to_byte(job_address)])
try:
_ = worker_socket.recv_multipart()
except zmq.error.Again as e:
logger.warning("Error in recv kill_client_job")
self.worker_locks[worker_address].release()
self.job_worker_dict.pop(job_address)
self.client_job_dict.pop(client_address)
def _print_workers(self):
"""Display `worker_pool` information."""
logger.info(
"Master connects to {} workers and has {} vacant CPUs.\n".format(
len(self.worker_pool), len(self.job_pool)))
def _receive_message(self):
"""master node will receive four types of messages: (1) worker
connection; (2) worker update; (3) client connection; (4) job
submission.
"""
message = self.client_socket.recv_multipart()
tag = message[0]
# a new worker connects to the master
if tag == remote_constants.WORKER_CONNECT_TAG:
self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
elif tag == remote_constants.WORKER_INITIALIZED_TAG:
worker = pickle.loads(message[1])
worker_heartbeat_address = to_str(message[2])
# maintain job & worker relations
for job_address in worker.job_pool:
self.job_worker_dict[job_address] = worker.address
self.worker_job_dict[worker.address] = worker.job_pool
self.job_pool.extend(worker.job_pool)
# a new socket for submitting job to the worker
worker_socket = self.ctx.socket(zmq.REQ)
worker_socket.linger = 0
worker_socket.setsockopt(zmq.RCVTIMEO, 10000)
worker_socket.connect("tcp://{}".format(worker.address))
worker.worker_socket = worker_socket
self.worker_pool[worker.address] = worker
self.worker_locks[worker.address] = threading.Lock()
logger.info("A new worker {} is added, ".format(worker.address) +
"cluster has {} CPUs.\n".format(len(self.job_pool)))
# a thread for sending heartbeat signals to `worker.address`
thread = threading.Thread(
target=self._create_worker_monitor,
args=(
worker_heartbeat_address,
worker.address,
),
daemon=True)
thread.start()
self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
# a client connects to the master
elif tag == remote_constants.CLIENT_CONNECT_TAG:
client_heartbeat_address = to_str(message[1])
logger.info(
"Client {} is connected.".format(client_heartbeat_address))
thread = threading.Thread(
target=self._create_client_monitor,
args=(client_heartbeat_address, ),
daemon=True)
thread.start()
self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
# a client submits a job to the master
elif tag == remote_constants.CLIENT_SUBMIT_TAG:
client_address = to_str(message[1])
done_flag = False
# check available CPU resources
if len(self.job_pool):
logger.info("Submitting job...")
job_address = self.job_pool.pop(0)
worker_address = self.job_worker_dict[job_address]
self.worker_job_dict[worker_address].remove(job_address)
self.client_socket.send_multipart(
[remote_constants.NORMAL_TAG,
to_byte(job_address)])
self.client_job_dict[client_address].append(job_address)
self._print_workers()
else:
self.client_socket.send_multipart([remote_constants.CPU_TAG])
# a worker updates
elif tag == remote_constants.NEW_JOB_TAG:
worker_address = to_str(message[1])
new_job_address = to_str(message[2])
killed_job_address = to_str(message[3])
self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
logger.info("A worker updated.")
if killed_job_address in self.job_worker_dict:
self.job_worker_dict.pop(killed_job_address)
if killed_job_address in self.worker_job_dict[worker_address]:
self.worker_job_dict[worker_address].remove(killed_job_address)
if killed_job_address in self.job_pool:
self.job_pool.remove(killed_job_address)
# add new job_address to job_pool
self.job_pool.append(new_job_address)
self.job_worker_dict[new_job_address] = worker_address
self.worker_job_dict[worker_address].append(new_job_address)
self._print_workers()
# check before start a worker
elif tag == remote_constants.NORMAL_TAG:
self.client_socket.send_multipart([remote_constants.NORMAL_TAG])
else:
raise NotImplementedError()
def exit(self):
self.master_is_alive = False
self.ctx.destroy()
def run(self):
"""An infinite loop waiting for messages from the workers and
clients.
Master node will receive four types of messages:
1. A new worker connects to the master node.
2. A connected worker sending new job address after it kills an old
job.
3. A new client connects to the master node.
4. A connected client submits a job after a remote object is created.
"""
while self.master_is_alive:
try:
self._receive_message()
except zmq.error.ContextTerminated as e:
pass
logger.warning("[Master] Exit master.")
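For reference, `xparl start --port 8010` essentially constructs a `Master` on that port and enters its receive loop (it also starts a local worker, which is outside this diff); a minimal sketch:

```python
master = Master(port=8010)
try:
    master.run()     # blocks, handling worker and client messages
finally:
    master.exit()
```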
...@@ -12,8 +12,21 @@ ...@@ -12,8 +12,21 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
CPU_TAG = b'[CPU]'
CONNECT_TAG = b'[CONNECT]' CONNECT_TAG = b'[CONNECT]'
HEARTBEAT_TAG = b'[HEARTBEAT]' HEARTBEAT_TAG = b'[HEARTBEAT]'
KILLJOB_TAG = b'[KILLJOB]'
WORKER_CONNECT_TAG = b'[WORKER_CONNECT]'
WORKER_INITIALIZED_TAG = b'[WORKER_INITIALIZED]'
CLIENT_CONNECT_TAG = b'[CLIENT_CONNECT]'
CLIENT_SUBMIT_TAG = b'[CLIENT_SUBMIT]'
SEND_FILE_TAG = b'[SEND_FILE]'
SUBMIT_JOB_TAG = b'[SUBMIT_JOB]'
NEW_JOB_TAG = b'[NEW_JOB]'
INIT_OBJECT_TAG = b'[INIT_OBJECT]'
CALL_TAG = b'[CALL]'
EXCEPTION_TAG = b'[EXCEPTION]' EXCEPTION_TAG = b'[EXCEPTION]'
ATTRIBUTE_EXCEPTION_TAG = b'[ATTRIBUTE_EXCEPTION]' ATTRIBUTE_EXCEPTION_TAG = b'[ATTRIBUTE_EXCEPTION]'
...@@ -24,3 +37,5 @@ NORMAL_TAG = b'[NORMAL]' ...@@ -24,3 +37,5 @@ NORMAL_TAG = b'[NORMAL]'
# interval of heartbeat mechanism in the unit of second # interval of heartbeat mechanism in the unit of second
HEARTBEAT_INTERVAL_S = 10 HEARTBEAT_INTERVAL_S = 10
HEARTBEAT_TIMEOUT_S = 10
HEARTBEAT_RCVTIMEO_S = HEARTBEAT_INTERVAL_S + HEARTBEAT_TIMEOUT_S * 2
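As a quick sanity check (not part of the commit), the receive timeout derived from the two constants above works out to 30 seconds:

```python
# HEARTBEAT_RCVTIMEO_S = interval + 2 * timeout = 10 + 2 * 10 = 30 seconds,
# i.e. a heartbeat socket waits roughly 30s for a reply before raising zmq.error.Again.
HEARTBEAT_INTERVAL_S = 10
HEARTBEAT_TIMEOUT_S = 10
assert HEARTBEAT_INTERVAL_S + HEARTBEAT_TIMEOUT_S * 2 == 30
```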
...@@ -12,235 +12,174 @@ ...@@ -12,235 +12,174 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np import cloudpickle
import pyarrow import os
import threading import threading
import time import time
import traceback
import zmq import zmq
from parl.remote import remote_constants import numpy as np
from parl.utils import get_ip_address, logger, to_str, to_byte from parl.utils import get_ip_address, logger, to_str, to_byte
from parl.utils.exceptions import SerializeError, DeserializeError from parl.utils.communication import loads_argument, loads_return,\
from parl.utils.communication import loads_argument, dumps_return dumps_argument, dumps_return
""" from parl.remote import remote_constants
Three steps to create a remote class: from parl.remote.exceptions import RemoteError, RemoteAttributeError,\
1. add a decorator(@parl.remote_class) before the definition of the class; remotely.
2. create an instance of remote class; from parl.remote.client import get_global_client
3. call function `as_remote` with server_ip and server_port.
@parl.remote_class
class Simulator(object): def __init__(self, x):
...
sim = Simulator() def remote_class(cls):
sim.as_remote(server_ip='172.18.202.45', server_port=8001) """A Python decorator that enables a class to run all its functions
remotely.
""" Each instance of the remote class can be seemed as a task submitted
to the cluster by the global client, which is created automatically
when we call parl.connect(master_address). After the global client
submits the task, the master node will send an available job address
to this remote instance. Then the remote object will send local python
files, class definition and initialization arguments to the related job.
In this way, we can run distributed applications easily and efficiently.
def remote_class(cls): .. code-block:: python
class ClientWrapper(object):
"""
Wrapper for remote class in client side.
when as_remote function called, the object initialized in the client can
handle function call from server.
"""
def __init__(self, *args, **kwargs): @remote_class
""" class Actor(object):
Args: def __init__(self, x):
args, kwargs: arguments for the initialisation of the unwrapped class. self.x = x
"""
self.unwrapped = cls(*args, **kwargs)
self.zmq_context = None def step(self):
self.poller = None self.x += 1
return self.x
# socket for connecting server and telling ip and port of client to server actor = Actor()
self.connect_socket = None actor.step()
# socket for handle function call from server side
self.reply_socket = None
def _create_reply_socket(self, remote_ip, remote_port): Returns:
""" A remote wrapper for the remote class.
In fact, we also have a socket server in client side. This server keeps running
and waits for requests (e.g. call a function) from server side.
"""
if remote_ip is None:
remote_ip = get_ip_address()
self.zmq_context = zmq.Context()
socket = self.zmq_context.socket(zmq.REP)
if remote_port is None:
try:
remote_port = socket.bind_to_random_port(addr="tcp://*")
except zmq.ZMQBindError:
logger.error(
'Can not bind to a random port, please set remote_port manually.'
)
sys.exit(1)
else:
socket.bind("tcp://*:{}".format(remote_port))
return socket, remote_ip, remote_port Raises:
Exception: An exception is raised if the client is not created
by `parl.connect(master_address)` beforehand.
"""
def _connect_server(self, server_ip, server_port, remote_ip, class RemoteWrapper(object):
remote_port): """
""" Wrapper for remote class in client side.
Create the connection between client side and server side. """
def __init__(self, *args, **kwargs):
"""
Args: Args:
server_ip(str): the ip of the server. args, kwargs: arguments for the initialization of the unwrapped
server_port(int): the connection port of the server. class.
remote_ip: the ip of the client itself.
remote_port: the port of the client itself,
which used to create reply socket.
""" """
self.reply_socket, local_ip, local_port = self._create_reply_socket( self.GLOBAL_CLIENT = get_global_client()
remote_ip, remote_port)
self.reply_socket.linger = 0
socket = self.zmq_context.socket(zmq.REQ) self.ctx = self.GLOBAL_CLIENT.ctx
socket.connect("tcp://{}:{}".format(server_ip, server_port))
logger.info("connecting {}:{}".format(server_ip, server_port)) # GLOBAL_CLIENT will set `master_is_alive` to False when hearbeat
# finds the master is dead.
if self.GLOBAL_CLIENT.master_is_alive:
job_address = self.request_cpu_resource(self.GLOBAL_CLIENT)
else:
raise Exception("Can not submit job to the master. "
"Please check if master is still alive.")
if job_address is None:
raise ResourceError("Cannot submit the job to the master. "
"Please add more CPU resources to the "
"master or try again later.")
self.internal_lock = threading.Lock()
# Send actor commands like `init` and `call` to the job.
self.job_socket = self.ctx.socket(zmq.REQ)
self.job_socket.linger = 0
self.job_socket.connect("tcp://{}".format(job_address))
self.job_address = job_address
self.send_file(self.job_socket)
try:
self.job_socket.send_multipart([
remote_constants.INIT_OBJECT_TAG,
cloudpickle.dumps(cls),
cloudpickle.dumps([args, kwargs])
])
_ = self.job_socket.recv_multipart()
except zmq.error.Again as e:
logger.error("Job socket failed.")
logger.info("[connect_job] job_address:{}".format(job_address))
def __del__(self):
"""Delete the remote class object and release remote resources."""
self.job_socket.send_multipart([remote_constants.KILLJOB_TAG])
_ = self.job_socket.recv_multipart()
self.job_socket.close(0)
def send_file(self, socket):
try:
socket.send_multipart([
remote_constants.SEND_FILE_TAG, self.GLOBAL_CLIENT.pyfiles
])
_ = socket.recv_multipart()
except zmq.error.Again as e:
logger.error("Send python files failed.")
def request_cpu_resource(self, global_client):
"""Try to request cpu resource for 1 second/time for 300 times."""
cnt = 300
while cnt > 0:
job_address = global_client.submit_job()
if job_address is not None:
return job_address
if cnt % 30 == 0:
logger.warning("No vacant cpu resources at present, "
"will try {} times later.".format(cnt))
cnt -= 1
time.sleep(1)
return None
client_addr = '{}:{}'.format(local_ip, local_port) def __getattr__(self, attr):
socket.send_multipart( """Call the function of the unwrapped class."""
[remote_constants.CONNECT_TAG,
to_byte(client_addr)])
message = socket.recv_multipart() def wrapper(*args, **kwargs):
self.client_id = message[1] self.internal_lock.acquire()
logger.info("connect server done, client_id: {}".format( data = dumps_argument(*args, **kwargs)
self.client_id))
self.connect_socket = socket
self.connect_socket.linger = 0
def _exit_remote(self): self.job_socket.send_multipart(
self.poller.unregister(self.connect_socket) [remote_constants.CALL_TAG,
to_byte(attr), data])
self.connect_socket.close() message = self.job_socket.recv_multipart()
self.reply_socket.close() tag = message[0]
# The program may hang when destroying zmq context manually. if tag == remote_constants.NORMAL_TAG:
# It will be destroyed automatically by the garbage collection mechanism of python, ret = loads_return(message[1])
# though it may raise some exceptions in C++.
#self.zmq_context.destroy() elif tag == remote_constants.EXCEPTION_TAG:
error_str = to_str(message[1])
raise RemoteError(attr, error_str)
def _heartbeat_loop(self): elif tag == remote_constants.ATTRIBUTE_EXCEPTION_TAG:
""" error_str = to_str(message[1])
Periodically detect whether the server is alive or not raise RemoteAttributeError(attr, error_str)
"""
self.poller = zmq.Poller()
self.poller.register(self.connect_socket, zmq.POLLIN)
while True: elif tag == remote_constants.SERIALIZE_EXCEPTION_TAG:
self.connect_socket.send_multipart( error_str = to_str(message[1])
[remote_constants.HEARTBEAT_TAG, self.client_id]) raise RemoteSerializeError(attr, error_str)
# wait for at most 10s to receive response elif tag == remote_constants.DESERIALIZE_EXCEPTION_TAG:
socks = dict(self.poller.poll(10000)) error_str = to_str(message[1])
raise RemoteDeserializeError(attr, error_str)
if socks.get(self.connect_socket) == zmq.POLLIN:
_ = self.connect_socket.recv_multipart()
else: else:
logger.warning( raise NotImplementedError()
'[HeartBeat] Server no response, will exit now!')
self._exit_remote()
break
# HeartBeat interval 10s
time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
def __getattr__(self, attr):
"""
Call the function of the unwrapped class.
"""
def wrapper(*args, **kwargs): self.internal_lock.release()
return getattr(self.unwrapped, attr)(*args, **kwargs) return ret
return wrapper return wrapper
def _reply_loop(self): return RemoteWrapper
while True:
message = self.reply_socket.recv_multipart()
try:
function_name = to_str(message[1])
data = message[2]
args, kwargs = loads_argument(data)
ret = getattr(self.unwrapped, function_name)(*args,
**kwargs)
ret = dumps_return(ret)
except Exception as e:
error_str = str(e)
logger.error(error_str)
if type(e) == AttributeError:
self.reply_socket.send_multipart([
remote_constants.ATTRIBUTE_EXCEPTION_TAG,
to_byte(error_str)
])
elif type(e) == SerializeError:
self.reply_socket.send_multipart([
remote_constants.SERIALIZE_EXCEPTION_TAG,
to_byte(error_str)
])
elif type(e) == DeserializeError:
self.reply_socket.send_multipart([
remote_constants.DESERIALIZE_EXCEPTION_TAG,
to_byte(error_str)
])
else:
traceback_str = str(traceback.format_exc())
logger.error('traceback:\n{}'.format(traceback_str))
self.reply_socket.send_multipart([
remote_constants.EXCEPTION_TAG,
to_byte(error_str + '\ntraceback:\n' +
traceback_str)
])
continue
self.reply_socket.send_multipart(
[remote_constants.NORMAL_TAG, ret])
def as_remote(self,
server_ip,
server_port,
remote_ip=None,
remote_port=None):
"""
Client will connect server and wait for function calls from server side.
Args:
server_ip(str): server's ip
server_port(int): server's port
remote_ip: the ip of the client itself.
remote_port: the port of the client itself,
which used to create reply socket.
"""
self._connect_server(server_ip, server_port, remote_ip,
remote_port)
reply_thread = threading.Thread(target=self._reply_loop)
reply_thread.setDaemon(True)
reply_thread.start()
self._heartbeat_loop()
def remote_closed(self):
"""
Check whether as_remote mode is closed
"""
assert self.reply_socket is not None, 'as_remote function should be called first!'
assert self.connect_socket is not None, 'as_remote function should be called first!'
return self.reply_socket.closed and self.connect_socket.closed
return ClientWrapper
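Putting the new decorator together with the cluster workflow, a minimal end-to-end usage might look like the sketch below. This is illustrative only: `Counter` is a made-up class, and the sketch assumes a master is already listening on `localhost:1234`, as in the `xparl` examples in `start.py` further down.

```python
import parl

@parl.remote_class
class Counter(object):
    def __init__(self, start=0):
        self.value = start

    def add(self, x):
        self.value += x
        return self.value

parl.connect('localhost:1234')  # creates the global client used by RemoteWrapper
counter = Counter(10)           # submits a job; __init__ runs inside a remote job process
print(counter.add(5))           # forwarded through the job socket, prints 15
```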
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from six.moves import queue
import threading
import time
import zmq
from parl.utils import logger, to_byte, to_str
from parl.remote import remote_constants
from parl.remote.remote_object import RemoteObject
"""
Two steps to build the communication with remote clients:
1. Create a RemoteManager;
2. Get remote objects by calling the function get_remote.
```python
remote_manager = RemoteManager(port=[port])
remote_obj = remote_manager.get_remote()
```
"""
class RemoteManager(object):
"""
Base class for network communication.
"""
def __init__(self, port):
"""
Args:
port(int): a local port used for connections from remote clients.
"""
self.zmq_context = zmq.Context()
socket = self.zmq_context.socket(zmq.REP)
socket.bind("tcp://*:{}".format(port))
self.socket = socket
self.socket.linger = 0
self.remote_pool = queue.Queue()
self.remote_latest_timestamp = {}
t = threading.Thread(target=self._wait_for_connection)
t.setDaemon(True) # The thread will exit when main thread exited
t.start()
t = threading.Thread(target=self._check_remote_status)
t.setDaemon(True) # The thread will exit when main thread exited
t.start()
def _wait_for_connection(self):
"""
A never-ending function that keeps waiting for connections from remote clients.
It will put an available remote object into an internal pool; a remote object
can be obtained by calling `get_remote`.
Note that this function is called inside the `__init__` function.
"""
remote_id = 0
while True:
try:
message = self.socket.recv_multipart()
tag = message[0]
if tag == remote_constants.CONNECT_TAG:
self.socket.send_multipart([
remote_constants.NORMAL_TAG,
to_byte('{}'.format(remote_id))
])
remote_client_address = to_str(message[1])
remote_obj = RemoteObject(remote_client_address,
self.zmq_context)
self.remote_latest_timestamp[to_byte(
str(remote_id))] = time.time()
logger.info('[RemoteManager] Added a new remote object.')
self.remote_pool.put(remote_obj)
remote_id += 1
elif tag == remote_constants.HEARTBEAT_TAG:
self.remote_latest_timestamp[message[1]] = time.time()
self.socket.send_multipart(
[remote_constants.NORMAL_TAG, b'Server is alive.'])
else:
raise NotImplementedError()
except zmq.ZMQError:
logger.warning('Zmq error, exiting server.')
break
def _check_remote_status(self):
while True:
for remote_id in list(self.remote_latest_timestamp.keys()):
if time.time() - self.remote_latest_timestamp[
remote_id] > 3 * remote_constants.HEARTBEAT_INTERVAL_S:
logger.error(
'Remote object {} is lost, please check if anything wrong happens in the remote client'
.format(remote_id))
self.remote_latest_timestamp.pop(remote_id)
time.sleep(3 * remote_constants.HEARTBEAT_INTERVAL_S)
def get_remote(self):
"""
A blocking function to obtain a remote object.
Returns:
RemoteObject
"""
return self.remote_pool.get()
def close(self):
"""
Close RemoteManager.
"""
self.zmq_context.destroy()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import zmq
from parl.remote import remote_constants
from parl.remote.exceptions import *
from parl.utils import logger, to_str, to_byte
from parl.utils.communication import dumps_argument, loads_return
class RemoteObject(object):
"""
Provides an interface to call functions of an object in a remote client.
"""
def __init__(self, remote_client_address, zmq_context=None):
"""
Args:
remote_client_address: address(ip:port) of remote client
zmq_context: zmq.Context()
"""
if zmq_context is None:
self.zmq_context = zmq.Context()
else:
self.zmq_context = zmq_context
# socket for sending function call to remote object and receiving result
self.command_socket = None
# lock for thread safety
self.internal_lock = threading.Lock()
self._connect_remote_client(remote_client_address)
def _connect_remote_client(self, remote_client_address):
"""
Build connection with the remote client to send function call.
"""
socket = self.zmq_context.socket(zmq.REQ)
logger.info("[connect_remote_client] client_address:{}".format(
remote_client_address))
socket.connect("tcp://{}".format(remote_client_address))
self.command_socket = socket
self.command_socket.linger = 0
def __getattr__(self, attr):
"""
Provides an interface to call functions of the object in the remote client.
1. send the function name and packed arguments to the remote client;
2. the remote client actually executes the function of the object;
3. receive the function's return value from the remote client.
Args:
attr(str): a function name specifying which function to run.
"""
def wrapper(*args, **kwargs):
self.internal_lock.acquire()
data = dumps_argument(*args, **kwargs)
self.command_socket.send_multipart(
[remote_constants.NORMAL_TAG,
to_byte(attr), data])
message = self.command_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.NORMAL_TAG:
ret = loads_return(message[1])
elif tag == remote_constants.EXCEPTION_TAG:
error_str = to_str(message[1])
raise RemoteError(attr, error_str)
elif tag == remote_constants.ATTRIBUTE_EXCEPTION_TAG:
error_str = to_str(message[1])
raise RemoteAttributeError(attr, error_str)
elif tag == remote_constants.SERIALIZE_EXCEPTION_TAG:
error_str = to_str(message[1])
raise RemoteSerializeError(attr, error_str)
elif tag == remote_constants.DESERIALIZE_EXCEPTION_TAG:
error_str = to_str(message[1])
raise RemoteDeserializeError(attr, error_str)
else:
raise NotImplementedError()
self.internal_lock.release()
return ret
return wrapper
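Taken together with the `RemoteManager` above, the pre-cluster flow implemented by these files can be sketched as follows. This is illustrative only, reconstructed from the module docstring and the unit tests further down; `Simulator` refers to the `@parl.remote_class` defined in those tests, and the import is assumed (the tests use `from parl.remote import *`).

```python
# Legacy flow sketch: the server side hands out RemoteObject proxies via RemoteManager,
# while the client side exposes an object through the old as_remote() API.
import threading
from parl.remote import RemoteManager  # assumed import; the tests below use a wildcard import

manager = RemoteManager(port=17770)

sim = Simulator(1, arg2=2)              # Simulator: the @parl.remote_class from the tests below
threading.Thread(
    target=sim.as_remote, args=('localhost', 17770), daemon=True).start()

remote_sim = manager.get_remote()       # blocks until the client above connects
print(remote_sim.get_arg1())            # prints 1, executed in the client process
```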
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import click
import locale
import os
import subprocess
import threading
import warnings
from multiprocessing import Process
# A flag to mark if parl is started from a command line
os.environ['XPARL'] = 'True'
# Solve `Click will abort further execution because Python 3 was configured
# to use ASCII as encoding for the environment` error.
locale.setlocale(locale.LC_ALL, "en_US.UTF-8")
warnings.simplefilter("ignore", ResourceWarning)
def is_port_in_use(port):
""" Check if a port is used.
True if the port is not available. Otherwise, this port can be used for
connection.
"""
import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', int(port))) == 0
def is_master_started(address):
import zmq
ctx = zmq.Context()
socket = ctx.socket(zmq.REQ)
socket.linger = 0
socket.setsockopt(zmq.RCVTIMEO, 500)
socket.connect("tcp://{}".format(address))
socket.send_multipart([b'[NORMAL]'])
try:
_ = socket.recv_multipart()
socket.close(0)
return True
except zmq.error.Again as e:
socket.close(0)
return False
@click.group()
def cli():
pass
@click.command("start", short_help="Start a master node.")
@click.option("--port", help="The port to bind to.", type=str, required=True)
@click.option(
"--cpu_num",
type=int,
help="Set number of cpu manually. If not set, it will use all "
"cpus of this machine.")
def start_master(port, cpu_num):
if is_port_in_use(port):
raise Exception(
"The master address localhost:{} already in use.".format(port))
cpu_num = str(cpu_num) if cpu_num else ''
command = [
"python", "{}/start.py".format(__file__[:-11]), "--name", "master",
"--port", port
]
p = subprocess.Popen(command)
command = [
"python", "{}/start.py".format(__file__[:-11]), "--name", "worker",
"--address", "localhost:" + str(port), "--cpu_num",
str(cpu_num)
]
p = subprocess.Popen(command)
@click.command("connect", short_help="Start a worker node.")
@click.option(
"--address", help="IP address of the master node.", required=True)
@click.option(
"--cpu_num",
type=int,
help="Set number of cpu manually. If not set, it will use all "
"cpus of this machine.")
def start_worker(address, cpu_num):
if not is_master_started(address):
raise Exception("Worker can not connect to the master node, " +
"please check if the input address {} ".format(
address) + "is correct.")
cpu_num = str(cpu_num) if cpu_num else ''
command = [
"python", "{}/start.py".format(__file__[:-11]), "--name", "worker",
"--address", address, "--cpu_num",
str(cpu_num)
]
p = subprocess.Popen(command)
@click.command("stop", help="Exit the cluster.")
def stop():
command = ("pkill -f remote/start.py")
subprocess.call([command], shell=True)
command = ("pkill -f job.py")
p = subprocess.call([command], shell=True)
cli.add_command(start_worker)
cli.add_command(start_master)
cli.add_command(stop)
def main():
return cli()
if __name__ == "__main__":
main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import threading
from parl.remote import Master, Worker
def main(args):
"""Start a master or a worker through:
1. xparl start --port 1234
2. xparl connect --address localhost:1234 --cpu_num 8
"""
if args.name == 'master':
port = args.port
master = Master(port)
master.run()
elif args.name == 'worker':
address = args.address
cpu_num = int(args.cpu_num) if args.cpu_num else None
worker = Worker(address, cpu_num)
worker.run()
else:
raise NotImplementedError
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
'--name', default='master', type=str, help='master/worker')
parser.add_argument('--port', default='1234', type=str)
parser.add_argument('--address', default='localhost:1234', type=str)
parser.add_argument('--cpu_num', default='', type=str)
args = parser.parse_args()
main(args)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
import unittest
@parl.remote_class
class Simulator:
def __init__(self, arg1, arg2=None):
self.arg1 = arg1
self.arg2 = arg2
def get_arg1(self):
return self.arg1
def get_arg2(self):
return self.arg2
def set_arg1(self, value):
self.arg1 = value
def set_arg2(self, value):
self.arg2 = value
class TestRemoteDecorator(unittest.TestCase):
def test_instance_in_local(self):
local_sim = Simulator(1, 2)
self.assertEqual(local_sim.get_arg1(), 1)
self.assertEqual(local_sim.get_arg2(), 2)
local_sim.set_arg1(3)
local_sim.set_arg2(4)
self.assertEqual(local_sim.get_arg1(), 3)
self.assertEqual(local_sim.get_arg2(), 4)
def test_instance_in_local_with_wrong_getattr_get_variable(self):
local_sim = Simulator(1, 2)
try:
local_sim.get_arg3()
except AttributeError:
return
assert False # This line should not be executed.
def test_instance_in_local_with_wrong_getattr_set_variable(self):
local_sim = Simulator(1, 2)
try:
local_sim.set_arg3(3)
except AttributeError:
return
assert False # This line should not be executed.
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import parl
import threading
import unittest
from parl.remote import *
class UnableSerializeObject(object):
def __init__(self):
# threading.Lock() can not be serialized
self.lock = threading.Lock()
@parl.remote_class
class Simulator:
def __init__(self, arg1, arg2=None):
self.arg1 = arg1
self.arg2 = arg2
def get_arg1(self):
return self.arg1
def get_arg2(self):
return self.arg2
def set_arg1(self, value):
self.arg1 = value
def set_arg2(self, value):
self.arg2 = value
def get_unable_serialize_object(self):
return UnableSerializeObject()
def add_one(self, value):
value += 1
return value
def will_raise_exeception_func(self):
x = 1 / 0
class TestRemote(unittest.TestCase):
def _setUp(self, server_port):
self.sim = Simulator(1, arg2=2)
# run client in a new thread to fake a remote client
self.client_thread = threading.Thread(
target=self.sim.as_remote, args=(
'localhost',
server_port,
))
self.client_thread.setDaemon(True)
self.client_thread.start()
self.remote_manager = RemoteManager(port=server_port)
def test_remote_object(self):
server_port = 17770
self._setUp(server_port)
remote_sim = self.remote_manager.get_remote()
self.assertEqual(remote_sim.get_arg1(), 1)
self.assertEqual(remote_sim.get_arg2(), 2)
ret = remote_sim.set_arg1(3)
self.assertIsNone(ret)
ret = remote_sim.set_arg2(4)
self.assertIsNone(ret)
self.assertEqual(remote_sim.get_arg1(), 3)
self.assertEqual(remote_sim.get_arg2(), 4)
def test_remote_object_with_wrong_getattr_get_variable(self):
server_port = 17771
self._setUp(server_port)
remote_sim = self.remote_manager.get_remote()
try:
remote_sim.get_arg3()
except RemoteAttributeError as e:
logger.info('Expected exception: {}'.format(e))
# expected
return
assert False
def test_remote_object_with_wrong_getattr_set_variable(self):
server_port = 17772
self._setUp(server_port)
remote_sim = self.remote_manager.get_remote()
try:
remote_sim.set_arg3(3)
except RemoteAttributeError as e:
logger.info('Expected exception: {}'.format(e))
# expected
return
assert False
def test_remote_object_with_wrong_argument(self):
server_port = 17773
self._setUp(server_port)
remote_sim = self.remote_manager.get_remote()
try:
remote_sim.set_arg1(wrong_arg=1)
except RemoteError as e:
logger.info('Expected exception: {}'.format(e))
# expected
return
assert False
def test_remote_object_with_unable_serialize_argument(self):
server_port = 17774
self._setUp(server_port)
remote_sim = self.remote_manager.get_remote()
try:
remote_sim.set_arg1(UnableSerializeObject())
except SerializeError as e:
logger.info('Expected exception: {}'.format(e))
# expected
return
assert False
def test_remote_object_with_unable_serialize_return(self):
server_port = 17775
self._setUp(server_port)
remote_sim = self.remote_manager.get_remote()
try:
remote_sim.get_unable_serialize_object()
except RemoteSerializeError as e:
# expected
logger.info('Expected exception: {}'.format(e))
return
assert False
def test_multi_remote_object(self):
server_port = 17776
self._setUp(server_port)
time.sleep(1)
# run second client
sim2 = Simulator(11, arg2=22)
client_thread2 = threading.Thread(
target=sim2.as_remote, args=(
'localhost',
server_port,
))
client_thread2.setDaemon(True)
client_thread2.start()
time.sleep(1)
remote_sim1 = self.remote_manager.get_remote()
remote_sim2 = self.remote_manager.get_remote()
self.assertEqual(remote_sim1.get_arg1(), 1)
self.assertEqual(remote_sim2.get_arg1(), 11)
def test_multi_remote_object_with_one_failed(self):
server_port = 17777
self._setUp(server_port)
time.sleep(1)
# run second client
sim2 = Simulator(11, arg2=22)
client_thread2 = threading.Thread(
target=sim2.as_remote, args=(
'localhost',
server_port,
))
client_thread2.setDaemon(True)
client_thread2.start()
time.sleep(1)
remote_sim1 = self.remote_manager.get_remote()
remote_sim2 = self.remote_manager.get_remote()
try:
# make remote sim1 failed
remote_sim1.get_arg3()
except:
pass
self.assertEqual(remote_sim2.get_arg1(), 11)
# Todo(@zenghongsheng):
# zmq will raise unexpected C++ exception when closing context,
# remove this unittest for now.
#def test_heartbeat_after_server_closed(self):
# server_port = 17778
# self._setUp(server_port)
# remote_sim = self.remote_manager.get_remote()
# time.sleep(1)
# self.remote_manager.close()
# # heartbeat interval (10s) + max waiting reply (10s)
# time.sleep(20)
# logger.info('check self.sim.remote_closed')
# self.assertTrue(self.sim.remote_closed())
def test_set_client_ip_port_manually(self):
server_port = 17779
self._setUp(server_port)
time.sleep(1)
# run second client
sim2 = Simulator(11, arg2=22)
client_thread2 = threading.Thread(
target=sim2.as_remote,
args=(
'localhost',
server_port,
'localhost',
6666,
))
client_thread2.setDaemon(True)
client_thread2.start()
time.sleep(1)
remote_sim1 = self.remote_manager.get_remote()
remote_sim2 = self.remote_manager.get_remote()
self.assertEqual(remote_sim1.get_arg1(), 1)
self.assertEqual(remote_sim2.get_arg1(), 11)
def test_thread_safe_of_remote_module(self):
server_port = 17780
self._setUp(server_port)
time.sleep(1)
thread_num = 10
for _ in range(thread_num):
# run clients in backend
sim = Simulator(11, arg2=22)
client_thread = threading.Thread(
target=sim.as_remote, args=(
'localhost',
server_port,
))
client_thread.setDaemon(True)
client_thread.start()
time.sleep(1)
threads = []
for _ in range(thread_num):
remote_sim = self.remote_manager.get_remote()
t = threading.Thread(
target=self._run_remote_add, args=(remote_sim, ))
t.start()
threads.append(t)
for t in threads:
t.join()
def test_remote_object_with_call_raise_exception_function(self):
server_port = 17781
self._setUp(server_port)
remote_sim = self.remote_manager.get_remote()
try:
remote_sim.will_raise_exeception_func()
except RemoteError as e:
assert 'Traceback (most recent call last)' in str(e)
logger.info('Expected exception: {}'.format(e))
# expected
return
assert False
def _run_remote_add(self, remote_sim):
value = 0
for i in range(1000):
value = remote_sim.add_one(value)
assert value == i + 1
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cloudpickle
import multiprocessing
import os
import subprocess
import sys
import time
import threading
import zmq
from parl.utils import get_ip_address, to_byte, to_str, logger
from parl.remote import remote_constants
class WorkerInfo(object):
"""A WorkerInfo object records the computation resources of a worker.
"""
def __init__(self, address, cpu_num, job_pool):
self.address = address
self.cpu_num = cpu_num
self.job_pool = job_pool
self.worker_socket = None
class Worker(object):
"""Worker provides the cpu computation resources for the cluster.
A worker node is connected to the master node and will send its
computation resource information to the master node. When a worker
node is created, it will start `cpu_num` empty jobs, and these jobs'
ip addresses will be sent to the master node. Further, when an old
job is killed, the worker will start a new job and send the new job's
ip address to the master node.
To start a worker, we use the following xparl command line api:
.. code-block:: python
xparl connect --address localhost:1234 --cpu_num 8
Attributes:
job_pid (dict): A dict of subprocess id and its address.
master_address (str): Master's ip address.
request_master_socket (zmq.Context.socket): A socket which sends job
address to the master node.
reply_master_socket (zmq.Context.socket): A socket which accepts
submitted job from master
node.
reply_job_socket (zmq.Context.socket): A socket which receives
job_address from the job.
Args:
master_address (str): IP address of the master node.
cpu_num (int): Number of cpu to be used on the worker.
"""
def __init__(self, master_address, cpu_num=None):
self.lock = threading.Lock()
self.heartbeat_socket_initialized = threading.Event()
self.ctx = zmq.Context.instance()
self.job_pid = {}
self.master_address = master_address
self.master_is_alive = True
self.worker_is_alive = True
self._set_cpu_num(cpu_num)
self._create_sockets()
self._create_worker()
def _set_cpu_num(self, cpu_num=None):
"""set useable cpu number for worker"""
if cpu_num is not None:
assert isinstance(
cpu_num, int
), "cpu_num should be INT type, please check the input type."
self.cpu_num = cpu_num
else:
self.cpu_num = multiprocessing.cpu_count()
def _create_sockets(self):
""" Each worker has three sockets at start:
(1) request_master_socket: sends job address to master node.
(2) reply_master_socket: accepts submitted job from master node.
(3) reply_job_socket: receives job_address from subprocess.
When a job is started, a new heartbeat socket is created to receive
heartbeat signal from the job.
"""
# request_master_socket: sends job address to master
self.request_master_socket = self.ctx.socket(zmq.REQ)
self.request_master_socket.linger = 0
# wait for 0.5 second to check whether master is started
self.request_master_socket.setsockopt(zmq.RCVTIMEO, 500)
self.request_master_socket.connect("tcp://" + self.master_address)
# reply_master_socket: receives submitted job from master
self.reply_master_socket = self.ctx.socket(zmq.REP)
self.reply_master_socket.linger = 0
self.worker_ip = get_ip_address()
reply_master_port = self.reply_master_socket.bind_to_random_port(
"tcp://*")
self.reply_master_address = "{}:{}".format(self.worker_ip,
reply_master_port)
logger.set_dir(
os.path.expanduser('~/.parl_data/worker/{}'.format(
self.reply_master_address)))
# reply_job_socket: receives job_address from subprocess
self.reply_job_socket = self.ctx.socket(zmq.REP)
self.reply_job_socket.linger = 0
reply_job_port = self.reply_job_socket.bind_to_random_port("tcp://*")
self.reply_job_address = "{}:{}".format(self.worker_ip, reply_job_port)
def _create_worker(self):
"""create a WorkerInfo instance and send it to the master."""
try:
self.request_master_socket.send_multipart(
[remote_constants.WORKER_CONNECT_TAG])
_ = self.request_master_socket.recv_multipart()
except zmq.error.Again as e:
logger.error("Can not connect to the master, "
"please check if master is started.")
self.master_is_alive = False
return
self._init_jobs()
self.request_master_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
self.worker = WorkerInfo(self.reply_master_address, self.cpu_num,
list(self.job_pid.keys()))
reply_thread = threading.Thread(
target=self._reply_heartbeat,
args=("master {}".format(self.master_address), ),
daemon=True)
reply_thread.start()
self.heartbeat_socket_initialized.wait()
self.request_master_socket.send_multipart([
remote_constants.WORKER_INITIALIZED_TAG,
cloudpickle.dumps(self.worker),
to_byte(self.heartbeat_master_address)
])
_ = self.request_master_socket.recv_multipart()
def _init_job(self):
"""Create one job."""
command = [
"python", "{}/job.py".format(__file__[:-10]), "--worker_address",
self.reply_job_address
]
with open(os.devnull, "w") as null:
pid = subprocess.Popen(command, stdout=null, stderr=null)
self.lock.acquire()
job_message = self.reply_job_socket.recv_multipart()
self.reply_job_socket.send_multipart([remote_constants.NORMAL_TAG])
job_address = to_str(job_message[1])
heartbeat_job_address = to_str(job_message[2])
self.job_pid[job_address] = pid
self.lock.release()
# a thread for sending heartbeat signals to job
thread = threading.Thread(
target=self._create_job_monitor,
args=(
job_address,
heartbeat_job_address,
),
daemon=True)
thread.start()
return job_address
def _init_jobs(self):
"""Create cpu_num jobs when the worker is created."""
job_threads = []
for _ in range(self.cpu_num):
t = threading.Thread(target=self._init_job, daemon=True)
t.start()
job_threads.append(t)
for th in job_threads:
th.join()
def _kill_job(self, job_address):
"""kill problematic job process and update worker information"""
if job_address in self.job_pid:
self.job_pid[job_address].kill()
self.job_pid.pop(job_address)
logger.warning("Worker kills job process {},".format(job_address))
# When a old job is killed, the worker will create a new job.
if self.master_is_alive:
new_job_address = self._init_job()
self.lock.acquire()
self.request_master_socket.send_multipart([
remote_constants.NEW_JOB_TAG,
to_byte(self.reply_master_address),
to_byte(new_job_address),
to_byte(job_address)
])
_ = self.request_master_socket.recv_multipart()
self.lock.release()
def _create_job_monitor(self, job_address, heartbeat_job_address):
"""Sending heartbeat signals to check target's status"""
# job_heartbeat_socket: sends heartbeat signal to job
job_heartbeat_socket = self.ctx.socket(zmq.REQ)
job_heartbeat_socket.linger = 0
job_heartbeat_socket.setsockopt(
zmq.RCVTIMEO, remote_constants.HEARTBEAT_TIMEOUT_S * 1000)
job_heartbeat_socket.connect("tcp://" + heartbeat_job_address)
job_is_alive = True
while job_is_alive and self.master_is_alive:
try:
job_heartbeat_socket.send_multipart(
[remote_constants.HEARTBEAT_TAG])
_ = job_heartbeat_socket.recv_multipart()
time.sleep(remote_constants.HEARTBEAT_INTERVAL_S)
except zmq.error.Again as e:
job_is_alive = False
if job_address in self.job_pid:
logger.warning("[Worker] No heartbeat reply from the job, "
"will kill {}.".format(job_address))
self._kill_job(job_address)
except zmq.error.ZMQError as e:
break
job_heartbeat_socket.close(0)
def _reply_heartbeat(self, target):
"""Worker will kill its jobs when it lost connection with the master.
"""
socket = self.ctx.socket(zmq.REP)
socket.linger = 0
socket.setsockopt(zmq.RCVTIMEO,
remote_constants.HEARTBEAT_RCVTIMEO_S * 1000)
heartbeat_master_port =\
socket.bind_to_random_port("tcp://*")
self.heartbeat_master_address = "{}:{}".format(self.worker_ip,
heartbeat_master_port)
self.heartbeat_socket_initialized.set()
logger.info("[Worker] Connect to the master node successfully. "
"({} CPUs)".format(self.cpu_num))
while self.master_is_alive:
try:
message = socket.recv_multipart()
socket.send_multipart([remote_constants.HEARTBEAT_TAG])
except zmq.error.Again as e:
self.master_is_alive = False
for job_address in list(self.job_pid.keys()):
self._kill_job(job_address)
except zmq.error.ContextTerminated as e:
break
socket.close(0)
logger.warning("Worker exit replying heartbeat for master.")
if self.worker_is_alive:
self.exit()
def exit(self):
"""Exit all zmq sockets related to the worker."""
self.worker_is_alive = False
self.ctx.destroy()
def run(self):
"""An infinite loop waiting for killing job commands from
the mater node.
After creating `cpu_num` jobs and sending job addresses to the master
node, a worker will keep waiting for killing job commands from master
node to release computation resources occupied by a dead client. Then
the worker will kill the jobs related to the dead client and create
new jobs and update job addresses to the master node.
"""
while self.master_is_alive and self.worker_is_alive:
try:
message = self.reply_master_socket.recv_multipart()
tag = message[0]
if tag == remote_constants.KILLJOB_TAG:
job_address = to_str(message[1])
self.reply_master_socket.send_multipart(
[remote_constants.NORMAL_TAG])
self._kill_job(job_address)
else:
raise NotImplementedError
except zmq.error.ZMQError as e:
self.worker_is_alive = False
logger.warning("[Worker] Exit Worker {}.".format(
self.reply_master_address))
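Like the master, the worker is normally launched through `start.py`. A minimal manual launch might look like the following sketch (not part of the commit; it assumes a master is already listening on `localhost:1234` and imports `Worker` the same way `start.py` does).

```python
# Minimal sketch (assumption: Worker is exported by parl.remote, as in start.py below).
from parl.remote import Worker

worker = Worker('localhost:1234', cpu_num=4)  # connect to the master and offer 4 CPUs
worker.run()                                  # blocks, replying to KILLJOB commands from the master
```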
...@@ -76,12 +76,34 @@ def _getlogger(): ...@@ -76,12 +76,34 @@ def _getlogger():
logger = logging.getLogger('PARL') logger = logging.getLogger('PARL')
logger.propagate = False logger.propagate = False
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)
handler = logging.StreamHandler(sys.stdout) if 'XPARL' not in os.environ:
handler.setFormatter(_Formatter(datefmt='%m-%d %H:%M:%S')) handler = logging.StreamHandler(sys.stdout)
logger.addHandler(handler) handler.setFormatter(_Formatter(datefmt='%m-%d %H:%M:%S'))
logger.addHandler(handler)
return logger return logger
def create_file_after_first_call(func_name):
def call(*args, **kwargs):
global _logger
if LOG_DIR is None:
basename = os.path.basename(mod.__file__)
if basename.rfind('.') == -1:
basename = basename
else:
basename = basename[:basename.rfind('.')]
auto_dirname = os.path.join('log_dir', basename)
shutil.rmtree(auto_dirname, ignore_errors=True)
set_dir(auto_dirname)
func = getattr(_logger, func_name)
func(*args, **kwargs)
return call
_logger = _getlogger() _logger = _getlogger()
_LOGGING_METHOD = [ _LOGGING_METHOD = [
'info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug', 'info', 'warning', 'error', 'critical', 'warn', 'exception', 'debug',
...@@ -90,8 +112,9 @@ _LOGGING_METHOD = [ ...@@ -90,8 +112,9 @@ _LOGGING_METHOD = [
# export logger functions # export logger functions
for func in _LOGGING_METHOD: for func in _LOGGING_METHOD:
locals()[func] = getattr(_logger, func) locals()[func] = create_file_after_first_call(func)
__all__.append(func) __all__.append(func)
# export Level information # export Level information
_LOGGING_LEVEL = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'] _LOGGING_LEVEL = ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
for level in _LOGGING_LEVEL: for level in _LOGGING_LEVEL:
...@@ -100,7 +123,7 @@ for level in _LOGGING_LEVEL: ...@@ -100,7 +123,7 @@ for level in _LOGGING_LEVEL:
def _set_file(path): def _set_file(path):
global _FILE_HANDLER global _FILE_HANDLER, _logger
if os.path.isfile(path): if os.path.isfile(path):
try: try:
os.remove(path) os.remove(path)
...@@ -114,16 +137,19 @@ def _set_file(path): ...@@ -114,16 +137,19 @@ def _set_file(path):
def set_level(level): def set_level(level):
global _logger, LOG_DIR
# To set level, need create new handler # To set level, need create new handler
set_dir(get_dir()) if LOG_DIR is not None:
set_dir(get_dir())
_logger.setLevel(level) _logger.setLevel(level)
def set_dir(dirname): def set_dir(dirname):
global LOG_DIR, _FILE_HANDLER global LOG_DIR, _FILE_HANDLER, _logger
if _FILE_HANDLER: if _FILE_HANDLER:
# unload and close the old file handler, so that we may safely delete the logger directory # unload and close the old file handler, so that we may safely delete the logger directory
_logger.removeHandler(_FILE_HANDLER) _logger.removeHandler(_FILE_HANDLER)
_FILE_HANDLER.close()
del _FILE_HANDLER del _FILE_HANDLER
if not os.path.isdir(dirname): if not os.path.isdir(dirname):
...@@ -137,10 +163,6 @@ def get_dir(): ...@@ -137,10 +163,6 @@ def get_dir():
# Will save log to log_dir/main_file_name/log.log by default # Will save log to log_dir/main_file_name/log.log by default
mod = sys.modules['__main__'] mod = sys.modules['__main__']
if hasattr(mod, '__file__'): _logger.info("Argv: " + ' '.join(sys.argv))
basename = os.path.basename(mod.__file__)
auto_dirname = os.path.join('log_dir', basename[:basename.rfind('.')])
shutil.rmtree(auto_dirname, ignore_errors=True)
set_dir(auto_dirname)
_logger.info("Argv: " + ' '.join(sys.argv))
...@@ -61,7 +61,7 @@ def get_gpu_count(): ...@@ -61,7 +61,7 @@ def get_gpu_count():
"""get avaliable gpu count """get avaliable gpu count
Returns: Returns:
gpu_count: int gpu_count: int
""" """
gpu_count = 0 gpu_count = 0
...@@ -77,7 +77,7 @@ def get_gpu_count(): ...@@ -77,7 +77,7 @@ def get_gpu_count():
logger.info( logger.info(
'CUDA_VISIBLE_DEVICES found gpu count: {}'.format(gpu_count)) 'CUDA_VISIBLE_DEVICES found gpu count: {}'.format(gpu_count))
except: except:
logger.warn('Cannot find available GPU devices, using CPU now.') logger.warning('Cannot find available GPU devices, using CPU now.')
gpu_count = 0 gpu_count = 0
else: else:
try: try:
...@@ -85,7 +85,7 @@ def get_gpu_count(): ...@@ -85,7 +85,7 @@ def get_gpu_count():
"-L"])).count('UUID') "-L"])).count('UUID')
logger.info('nvidia-smi -L found gpu count: {}'.format(gpu_count)) logger.info('nvidia-smi -L found gpu count: {}'.format(gpu_count))
except: except:
logger.warn('Cannot find available GPU devices, using CPU now.') logger.warning('Cannot find available GPU devices, using CPU now.')
gpu_count = 0 gpu_count = 0
return gpu_count return gpu_count
...@@ -100,7 +100,7 @@ def is_gpu_available(): ...@@ -100,7 +100,7 @@ def is_gpu_available():
if utils._HAS_FLUID: if utils._HAS_FLUID:
from paddle import fluid from paddle import fluid
if ret is True and not fluid.is_compiled_with_cuda(): if ret is True and not fluid.is_compiled_with_cuda():
logger.warn("Found non-empty CUDA_VISIBLE_DEVICES. \ logger.warning("Found non-empty CUDA_VISIBLE_DEVICES. \
But PARL found that Paddle was not compiled with CUDA, which may cause issues." But PARL found that Paddle was not compiled with CUDA, which may cause issues."
) )
return ret return ret
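For completeness, a small sketch of how these helpers might be used when choosing a device. The import path `parl.utils.machine_info` is an assumption; only the function names come from the diff above.

```python
# Assumed import path; the functions themselves (get_gpu_count, is_gpu_available)
# are the ones patched in this commit.
from parl.utils.machine_info import get_gpu_count, is_gpu_available

if is_gpu_available():
    print('Training with {} GPU(s).'.format(get_gpu_count()))
else:
    print('No usable GPU found, falling back to CPU.')
```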