未验证 提交 d683f331 编写于 作者: B Bo Zhou 提交者: GitHub

ping the master before connection (#278)

* ping the master before connection

* yapf

* fix comments

* remove the useless library

* install ping for the docker environment

* remove protobuf intallation

* remove evokit test
上级 c24e283c
......@@ -23,6 +23,5 @@ RUN apt-get install -y libgflags-dev libgoogle-glog-dev libomp-dev unzip
RUN apt-get install -y libgtest-dev && cd /usr/src/gtest && mkdir build \
&& cd build && cmake .. && make && cp libgtest*.a /usr/local/lib
RUN wget https://github.com/google/protobuf/releases/download/v2.4.1/protobuf-2.4.1.tar.gz \
&& tar -zxvf protobuf-2.4.1.tar.gz \
&& cd protobuf-2.4.1 && ./configure && make && make install
RUN apt-get update
RUN apt-get install -y iputils-ping
......@@ -134,19 +134,6 @@ EOF
rm -rf ${REPO_ROOT}/build
}
function run_evo_kit_test {
cd ${REPO_ROOT}/evo_kit
cat <<EOF
========================================
Running evo_kit test...
========================================
EOF
sh test/run_test.sh
rm -rf ${REPO_ROOT}/evo_kit/build
rm -rf ${REPO_ROOT}/evo_kit/libtorch
}
function main() {
set -e
local CMD=$1
......@@ -189,7 +176,6 @@ function main() {
/root/miniconda3/envs/empty_env/bin/pip install .
run_import_test
run_docs_test
run_evo_kit_test
;;
*)
print_usage
......
......@@ -146,7 +146,7 @@ def from_importance_weights(behaviour_actions_log_probs,
def recursively_scan(discounts, cs, deltas):
""" Recursively calculate vs_minus_v_xs according to following equation:
r""" Recursively calculate vs_minus_v_xs according to following equation:
vs_minus_v_xs(t) = deltas(t) + discounts(t) * cs(t) * vs_minus_v_xs(t + 1)
Args:
......
......@@ -20,6 +20,7 @@ import sys
import threading
import zmq
from parl.utils import to_str, to_byte, get_ip_address, logger
from parl.utils.communication import ping
from parl.remote import remote_constants
import time
......@@ -326,9 +327,14 @@ def connect(master_address, distributed_files=[]):
Exception: An exception is raised if the master node is not started.
"""
assert len(master_address.split(":")) == 2, "please input address in " +\
assert len(master_address.split(":")) == 2, "Please input address in " +\
"{ip}:{port} format"
global GLOBAL_CLIENT
addr = master_address.split(":")[0]
assert ping(
addr
) == 0, "Error occurs in connection with {}. PARL failed to ping this IP.".format(
master_address)
cur_process_id = os.getpid()
if GLOBAL_CLIENT is None:
GLOBAL_CLIENT = Client(master_address, cur_process_id,
......@@ -366,5 +372,5 @@ def disconnect():
GLOBAL_CLIENT = None
else:
logger.info(
"No client to be released. Please make sure that you have call `parl.connect`"
"No client to be released. Please make sure that you have called `parl.connect`"
)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import parl
import unittest
from parl.remote.client import disconnect
class TestPingMaster(unittest.TestCase):
def tearDown(self):
disconnect()
def test_throw_exception(self):
with self.assertRaises(AssertionError):
parl.connect("176.2.3.4:8080")
if __name__ == '__main__':
unittest.main()
......@@ -14,9 +14,14 @@
import cloudpickle
import pyarrow
import subprocess
import os
from parl.utils import _IS_WINDOWS
from parl.utils import SerializeError, DeserializeError
__all__ = ['dumps_argument', 'loads_argument', 'dumps_return', 'loads_return']
__all__ = [
'dumps_argument', 'loads_argument', 'dumps_return', 'loads_return', 'ping'
]
# Reference: https://github.com/apache/arrow/blob/f88474c84e7f02e226eb4cc32afef5e2bbc6e5b4/python/pyarrow/tests/test_serialization.py#L658-L682
......@@ -115,3 +120,23 @@ def loads_return(data):
raise DeserializeError(e)
return ret
#Reference: https://stackoverflow.com/questions/2953462/pinging-servers-in-python
def ping(host):
"""
Returns True if host (str) responds to a ping request.
Remember that a host may not respond to a ping (ICMP) request even if the host name is valid.
"""
# Option for the number of packets as a function of
param = '-n' if _IS_WINDOWS else '-c'
# Building the command. Ex: "ping -c 1 google.com"
command = ['ping', param, '1', host]
FNULL = open(os.devnull, 'w')
child = subprocess.Popen(command, stdout=FNULL, stderr=subprocess.STDOUT)
FNULL.close()
child.communicate()[0]
return child.returncode
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册