Unverified commit f36b9a7f, authored by 123malin, committed by GitHub

【Fleet2.0 Util】 add documents (#26698)

* test=develop, util documents
Parent e9a0fbff
@@ -637,7 +637,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
        return "lo"

    def __start_kv_server(self, http_server_d, size_d):
-        from paddle.distributed.fleet.utils import KVServer
+        from paddle.distributed.fleet.utils.http_server import KVServer
        http_server = KVServer(int(self._http_ip_port[1]), size_d)
        http_server.start()
        wait_seconds = 5
@@ -651,6 +651,7 @@ class UserDefinedRoleMaker(PaddleCloudRoleMaker):
    def __init__(self, is_collective=False, init_gloo=False, **kwargs):
        super(UserDefinedRoleMaker, self).__init__(
            is_collective=is_collective, init_gloo=init_gloo, **kwargs)
+        self._init_gloo = init_gloo

    def _user_defined_ps_env(self):
        self._server_endpoints = self._kwargs.get("server_endpoints")
...
@@ -16,20 +16,18 @@
"""basic collective operations in python"""
"""remote file system"""

-__all__ = ['UtilBase']
-import numpy as np
-import os
-import subprocess
-from paddle.fluid import core
-from collections import OrderedDict
-import paddle.fluid as fluid
-from google.protobuf import text_format
-from paddle.fluid import debugger
-from paddle.fluid.framework import Program
-from paddle.fluid.proto import framework_pb2
from ..utils.fs import FS, LocalFS, HDFSClient
+from paddle.fluid.proto import framework_pb2
+from paddle.fluid.framework import Program
+from paddle.fluid import debugger
+from google.protobuf import text_format
+import paddle.fluid as fluid
+from collections import OrderedDict
+from paddle.fluid import core
+import subprocess
+import os
+import numpy as np
+
+__all__ = ['UtilBase']
class UtilFactory(object):
@@ -53,7 +51,7 @@ class UtilBase(object):
    def _set_role_maker(self, role_maker):
        self.role_maker = role_maker

-    def set_file_system(self, fs_client):
+    def _set_file_system(self, fs_client):
        assert isinstance(
            fs_client, FS
        ), "fs_client must be the instance of paddle.distributed.fleet.utils.FS"
@@ -87,36 +85,183 @@ class UtilBase(object):
        return _comm_world
    def all_reduce(self, input, mode, comm_world="worker"):
+        """
+        All reduce `input` between the specified collection. This is a distributed API.
+
+        Args:
+            input (list|numpy.array): The input variable to do all_reduce between the specified collection.
+            mode (str): "sum", "min" or "max".
+            comm_world (str, optional): Collection used to execute the all_reduce operation. Supported collections include `worker` , `server` and `all` . The default is `worker` .
+
+        Returns:
+            output(Numpy.array|None): A numpy array with the same shape as the `input` .
+
+        Examples:
+            .. code-block:: python
+
+                # Save the following code in `train.py` , and then execute the command `fleetrun --server_num 2 --worker_num 2 train.py` .
+
+                from paddle.distributed.fleet.base.util_factory import fleet_util
+                import paddle.distributed.fleet as fleet
+                from paddle.distributed.fleet import PaddleCloudRoleMaker
+                import sys
+                import numpy as np
+
+                def train():
+                    role = PaddleCloudRoleMaker(
+                        is_collective=False,
+                        init_gloo=True,
+                        path="./tmp_gloo")
+                    fleet.init(role)
+                    fleet_util._set_role_maker(role)
+
+                    if fleet.is_server():
+                        input = [1, 2]
+                        output = fleet_util.all_reduce(input, "sum", "server")
+                        print(output)
+                        # [2, 4]
+                    elif fleet.is_worker():
+                        input = np.array([3, 4])
+                        output = fleet_util.all_reduce(input, "sum", "worker")
+                        print(output)
+                        # [6, 8]
+                    output = fleet_util.all_reduce(input, "sum", "all")
+                    print(output)
+                    # [8, 12]
+
+                if __name__ == "__main__":
+                    train()
+        """
        _comm_world = self.__check_comm_world(comm_world)
        return self.role_maker._all_reduce(_comm_world, input, mode)
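The reduction documented above is just an elementwise sum, min, or max over the inputs contributed by every member of the collection. A minimal local sketch of those semantics (the `simulate_all_reduce` helper below is illustrative only, not part of the Fleet API):

.. code-block:: python

    import numpy as np

    def simulate_all_reduce(node_inputs, mode="sum"):
        # node_inputs holds one array-like per participant in the collection.
        stacked = np.stack([np.asarray(x) for x in node_inputs])
        if mode == "sum":
            return stacked.sum(axis=0)
        if mode == "min":
            return stacked.min(axis=0)
        if mode == "max":
            return stacked.max(axis=0)
        raise ValueError("mode must be 'sum', 'min' or 'max'")

    # Two workers each contribute [3, 4], giving [6, 8] as in the docstring example.
    print(simulate_all_reduce([[3, 4], [3, 4]], "sum"))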
    def barrier(self, comm_world="worker"):
+        """
+        Barrier between the specified collection.
+
+        Args:
+            comm_world (str, optional): Collection used to execute the barrier operation. Supported collections include `worker` , `server` and `all` . The default is `worker` .
+
+        Examples:
+            .. code-block:: python
+
+                # Save the following code in `train.py` , and then execute the command `fleetrun --server_num 2 --worker_num 2 train.py` .
+
+                from paddle.distributed.fleet.base.util_factory import fleet_util
+                import paddle.distributed.fleet as fleet
+                from paddle.distributed.fleet import PaddleCloudRoleMaker
+                import sys
+
+                def train():
+                    role = PaddleCloudRoleMaker(
+                        is_collective=False,
+                        init_gloo=True,
+                        path="./tmp_gloo")
+                    fleet.init(role)
+                    fleet_util._set_role_maker(role)
+
+                    if fleet.is_server():
+                        fleet_util.barrier("server")
+                        print("all servers arrive here")
+                    elif fleet.is_worker():
+                        fleet_util.barrier("worker")
+                        print("all workers arrive here")
+                    fleet_util.barrier("all")
+                    print("all servers and workers arrive here")
+
+                if __name__ == "__main__":
+                    train()
+        """
        _comm_world = self.__check_comm_world(comm_world)
        self.role_maker._barrier(_comm_world)
    def all_gather(self, input, comm_world="worker"):
+        """
+        All gather `input` between the specified collection.
+
+        Args:
+            input (Int|Float): The input variable to do all_gather between the specified collection.
+            comm_world (str, optional): Collection used to execute the all_gather operation. Supported collections include `worker` , `server` and `all` . The default is `worker` .
+
+        Returns:
+            output (List): A list of gathered values.
+
+        Examples:
+            .. code-block:: python
+
+                # Save the following code in `train.py` , and then execute the command `fleetrun --server_num 2 --worker_num 2 train.py` .
+
+                from paddle.distributed.fleet.base.util_factory import fleet_util
+                import paddle.distributed.fleet as fleet
+                from paddle.distributed.fleet import PaddleCloudRoleMaker
+                import sys
+
+                def train():
+                    role = PaddleCloudRoleMaker(
+                        is_collective=False,
+                        init_gloo=True,
+                        path="./tmp_gloo")
+                    fleet.init(role)
+                    fleet_util._set_role_maker(role)
+
+                    if fleet.is_server():
+                        input = fleet.server_index()
+                        output = fleet_util.all_gather(input, "server")
+                        print(output)
+                        # output = [0, 1]
+                    elif fleet.is_worker():
+                        input = fleet.worker_index()
+                        output = fleet_util.all_gather(input, "worker")
+                        # output = [0, 1]
+                        print(output)
+                    output = fleet_util.all_gather(input, "all")
+                    print(output)
+                    # output = [0, 1, 0, 1]
+
+                if __name__ == "__main__":
+                    train()
+        """
        _comm_world = self.__check_comm_world(comm_world)
        return self.role_maker._all_gather(_comm_world, input)
-    def broadcast(self):
+    def _broadcast(self):
        pass

-    def scatter(self):
+    def _scatter(self):
        pass
    def get_file_shard(self, files):
        """
-        split files before distributed training,
-        example 1: files is [a, b, c ,d, e] and trainer_num = 2, then trainer
-        0 gets [a, b, c] and trainer 1 gets [d, e].
-        example 2: files is [a, b], and trainer_num = 3, then trainer 0 gets
-        [a], trainer 1 gets [b], trainer 2 gets []
+        Split files before distributed training, and return the file list assigned to the current trainer.
+
+        .. code-block:: text
+
+            example 1: files is [a, b, c, d, e] and trainer_num = 2, then trainer
+            0 gets [a, b, c] and trainer 1 gets [d, e].
+            example 2: files is [a, b], and trainer_num = 3, then trainer 0 gets
+            [a], trainer 1 gets [b], trainer 2 gets [].

        Args:
-            files(list): file list need to be read.
+            files(list): File list to be read.

        Returns:
-            list: files belongs to this worker.
+            List: Files belonging to this worker.
+
+        Examples:
+            .. code-block:: python
+
+                from paddle.distributed.fleet.base.util_factory import fleet_util
+                import paddle.distributed.fleet.base.role_maker as role_maker
+
+                role = role_maker.UserDefinedRoleMaker(
+                    is_collective=False,
+                    init_gloo=False,
+                    current_id=0,
+                    role=role_maker.Role.WORKER,
+                    worker_endpoints=["127.0.0.1:6003", "127.0.0.1:6004"],
+                    server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
+                fleet_util._set_role_maker(role)
+                files = fleet_util.get_file_shard(["file1", "file2", "file3"])
+                # files = ["file1", "file2"]
        """
        if not isinstance(files, list):
            raise TypeError("files should be a list of file need to be read.")
@@ -140,6 +285,30 @@ class UtilBase(object):
        return trainer_files[trainer_id]
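The contiguous split described in the docstring examples can be reproduced locally. A minimal sketch of that sharding rule (the `shard_files` helper is hypothetical, assuming the leftover files go to the lowest-ranked trainers, which matches both examples above):

.. code-block:: python

    def shard_files(files, trainer_id, trainer_num):
        # Each trainer gets len(files) // trainer_num files; the first
        # len(files) % trainer_num trainers each take one extra file.
        base, extra = divmod(len(files), trainer_num)
        start = trainer_id * base + min(trainer_id, extra)
        end = start + base + (1 if trainer_id < extra else 0)
        return files[start:end]

    print(shard_files(["a", "b", "c", "d", "e"], 0, 2))  # ['a', 'b', 'c']
    print(shard_files(["a", "b", "c", "d", "e"], 1, 2))  # ['d', 'e']
    print(shard_files(["a", "b"], 2, 3))                 # []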
    def print_on_rank(self, message, rank_id):
+        """
+        Worker of rank `rank_id` prints some message.
+
+        Args:
+            message(str): Log to be printed.
+            rank_id(int): The trainer id.
+
+        Examples:
+            .. code-block:: python
+
+                from paddle.distributed.fleet.base.util_factory import fleet_util
+                import paddle.distributed.fleet.base.role_maker as role_maker
+
+                role = role_maker.UserDefinedRoleMaker(
+                    is_collective=False,
+                    init_gloo=False,
+                    current_id=0,
+                    role=role_maker.Role.WORKER,
+                    worker_endpoints=["127.0.0.1:6003", "127.0.0.1:6004"],
+                    server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
+                fleet_util._set_role_maker(role)
+                fleet_util.print_on_rank("I'm worker 0", 0)
+        """
        if self.role_maker.worker_index() != rank_id:
            return
        print(message)
@@ -297,7 +466,7 @@ class UtilBase(object):
        with fluid.scope_guard(scope):
            inference_program, feed_target_names, fetch_targets = \
                fluid.io.load_inference_model(config.dump_model_dir, exe, model_filename=model_filename,
                                              params_filename=config.save_params_filename)
            # check program vars and saved vars shape
            orig_para_shape = {
...
@@ -87,7 +87,7 @@ def _parse_args():
see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/training/cluster_howto.html#permalink-8--nccl2-
''')

-    #Optional arguments for the launch helper
+    # Optional arguments for the launch helper
    parser.add_argument(
        "--ips",
        type=str,
@@ -115,7 +115,7 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
        default="log",
        help="The path for each process's log.If it's not set, the log will printed to default pipe."
    )
-    #positional
+    # positional
    parser.add_argument(
        "training_script",
        type=str,
@@ -124,7 +124,7 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
        "followed by all the arguments for the "
        "training script")

-    #rest from the training program
+    # rest from the training program
    parser.add_argument('training_script_args', nargs=REMAINDER)
    return parser.parse_args()
@@ -138,7 +138,7 @@ def get_cluster_from_args(args, gpus):
    # node_ip = args.node_ip
    assert node_ip in node_ips, "Can't find your local ip {%s} in node_ips: {%s}" \
        % (node_ip, node_ips)

    node_rank = node_ips.index(node_ip)
    logger.debug("parsed from args: node_ips:{} node_ip:{} node_rank:{}".format(
@@ -280,7 +280,7 @@ def launch_ps(args):
    _, current_node_ip = get_host_name_ip()
    assert current_node_ip in node_ips, "Can't find your local ip {%s} in args.servers and args.workers ips: {%s}" \
        % (current_node_ip, node_ips)

    node_rank = node_ips.index(current_node_ip)
    logger.debug(
        "parsed from args: node_ips:{} current_node_ip:{} node_rank:{}, server_ports:{}".
@@ -323,10 +323,12 @@ def launch_ps(args):
    for idx, cur_server in enumerate(pod.servers):
        proc_env = {
            "PADDLE_PSERVERS_IP_PORT_LIST": server_endpoints,
+            "PADDLE_TRAINER_ENDPOINTS": worker_endpoints,
            "PADDLE_PORT": cur_server.endpoint.split(":")[1],
            "TRAINING_ROLE": "PSERVER",
            "PADDLE_TRAINERS_NUM": str(worker_num),
-            "POD_IP": cur_server.endpoint.split(":")[0]
+            "POD_IP": cur_server.endpoint.split(":")[0],
+            "PADDLE_WITH_GLOO": "1"
        }

        current_env.update(proc_env)
@@ -365,7 +367,8 @@ def launch_ps(args):
            "PADDLE_TRAINER_ENDPOINTS": worker_endpoints,
            "PADDLE_TRAINERS_NUM": str(worker_num),
            "TRAINING_ROLE": "TRAINER",
-            "PADDLE_TRAINER_ID": str(cur_worker.rank)
+            "PADDLE_TRAINER_ID": str(cur_worker.rank),
+            "PADDLE_WITH_GLOO": "1"
        }

        current_env.update(proc_env)
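A process started by `launch_ps` discovers its role entirely through this environment. A short sketch of how a launched script might read the variables back (illustrative only; it uses just the keys set in the dictionaries above):

.. code-block:: python

    import os

    training_role = os.getenv("TRAINING_ROLE")        # "PSERVER" or "TRAINER"
    trainers_num = int(os.getenv("PADDLE_TRAINERS_NUM", "0"))
    with_gloo = os.getenv("PADDLE_WITH_GLOO") == "1"  # newly exported above

    if training_role == "TRAINER":
        trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "-1"))
        print("trainer {} of {}, gloo enabled: {}".format(
            trainer_id, trainers_num, with_gloo))
    elif training_role == "PSERVER":
        print("pserver at {}:{}".format(
            os.getenv("POD_IP"), os.getenv("PADDLE_PORT")))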
@@ -430,7 +433,11 @@ def launch():
        co_arg for co_arg in collective_args
        if co_arg in " ".join(sys.argv[1:-1])
    ]
-    cuda_device_num = fluid.core.get_cuda_device_count()
+    if fluid.core.is_compiled_with_cuda():
+        cuda_device_num = fluid.core.get_cuda_device_count()
+    else:
+        cuda_device_num = 0
+
    if len(has_ps_args) > 0 or cuda_device_num == 0:
        logger.info(
            "Run parameter-sever cpu mode. pserver arguments:{}, cuda count:{}".
...
@@ -11,8 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-from .fs import *
-from .http_server import KVHandler, KVHTTPServer, KVServer
-
-#__all__ = ['KVHandler', 'KVHTTPServer', 'KVServer'] + fs.__all__
@@ -32,10 +32,7 @@ import functools
from pathlib import PurePosixPath, Path
import shutil

-__all__ = [
-    'FS', 'LocalFS', 'HDFSClient', 'ExecuteError', 'FSTimeOut',
-    'FSFileExistsError', 'FSFileNotExistsError', 'FSShellCmdAborted'
-]
+__all__ = ['LocalFS', 'HDFSClient']

class ExecuteError(Exception):
@@ -117,7 +114,37 @@ class FS(object):

class LocalFS(FS):
"""
A tool of local file system.
Examples:
.. code-block:: python
from paddle.distributed.fleet.utils.fs import LocalFS
client = LocalFS()
subdirs, files = client.ls_dir("./")
"""
    def ls_dir(self, fs_path):
+        """
+        List directories and files under `fs_path` .
+
+        Args:
+            fs_path(str): The local file path.
+
+        Returns:
+            Tuple: Return a 2-tuple, the first element is a list of all its subdirectories,
+            and the second is a list of all its files, e.g. ([subdirname1, subdirname2, ...], [filename1, filename2, ...]).
+
+        Examples:
+            .. code-block:: python
+
+                from paddle.distributed.fleet.utils.fs import LocalFS
+
+                client = LocalFS()
+                subdirs, files = client.ls_dir("./")
+        """
        if not self.is_exist(fs_path):
            return [], []
@@ -132,11 +159,46 @@ class LocalFS(FS):
        return dirs, files
    def mkdirs(self, fs_path):
+        """
+        Create a local directory.
+
+        Args:
+            fs_path(str): The local directory path.
+
+        Examples:
+            .. code-block:: python
+
+                from paddle.distributed.fleet.utils.fs import LocalFS
+
+                client = LocalFS()
+                client.mkdirs("test_mkdirs")
+                client.delete("test_mkdirs")
+        """
        assert not os.path.isfile(fs_path), "{} is already a file".format(
            fs_path)
        os.system("mkdir -p {}".format(fs_path))
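The body shells out to `mkdir -p`, so the call is a no-op when the directory already exists. A pure-Python equivalent of the same behavior, shown only as a sketch (assuming Python 3):

.. code-block:: python

    import os

    def mkdirs(fs_path):
        assert not os.path.isfile(fs_path), "{} is already a file".format(fs_path)
        # exist_ok=True mirrors `mkdir -p`: no error if the directory exists.
        os.makedirs(fs_path, exist_ok=True)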
    def rename(self, fs_src_path, fs_dst_path):
+        """
+        Rename the file.
+
+        Args:
+            fs_src_path(str): The actual name of the file or directory.
+            fs_dst_path(str): The new name of the file or directory.
+
+        Examples:
+            .. code-block:: python
+
+                from paddle.distributed.fleet.utils.fs import LocalFS
+
+                client = LocalFS()
+                client.touch("test_rename_src")
+                print(client.is_exist("test_rename_src"))  # True
+                client.rename("test_rename_src", "test_rename_dst")
+                print(client.is_exist("test_rename_src"))  # False
+                print(client.is_exist("test_rename_dst"))  # True
+                client.delete("test_rename_dst")
+        """
        os.rename(fs_src_path, fs_dst_path)
    def _rmr(self, fs_path):
@@ -146,6 +208,21 @@ class LocalFS(FS):
        os.remove(fs_path)
    def delete(self, fs_path):
+        """
+        Delete the local file path, whether it's a file or directory.
+
+        Args:
+            fs_path(str): The local file path.
+
+        Examples:
+            .. code-block:: python
+
+                from paddle.distributed.fleet.utils.fs import LocalFS
+
+                client = LocalFS()
+                client.mkdirs("test_localFS_mkdirs")
+                client.delete("test_localFS_mkdirs")
+        """
        if not self.is_exist(fs_path):
            return
@@ -158,15 +235,88 @@ class LocalFS(FS):
        return False
    def is_file(self, fs_path):
+        """
+        Whether the local file path is a file.
+
+        Args:
+            fs_path(str): The local file path.
+
+        Returns:
+            Bool: Return true if the path exists and it's a file, otherwise return false.
+
+        Examples:
+            .. code-block:: python
+
+                from paddle.distributed.fleet.utils.fs import LocalFS
+
+                client = LocalFS()
+                client.touch("test_is_file")
+                print(client.is_file("test_is_file"))  # True
+                client.delete("test_is_file")
+        """
        return os.path.isfile(fs_path)
    def is_dir(self, fs_path):
+        """
+        Whether the local file path is a directory.
+
+        Args:
+            fs_path(str): The local file path.
+
+        Returns:
+            Bool: Return true if the path exists and it's a directory, otherwise return false.
+
+        Examples:
+            .. code-block:: python
+
+                from paddle.distributed.fleet.utils.fs import LocalFS
+
+                client = LocalFS()
+                client.mkdirs("test_is_dir")
+                print(client.is_dir("test_is_dir"))  # True
+                client.delete("test_is_dir")
+        """
        return os.path.isdir(fs_path)
    def is_exist(self, fs_path):
+        """
+        Whether the local file path exists.
+
+        Args:
+            fs_path(str): The local file path.
+
+        Returns:
+            Bool: Whether it's a file or directory, return true if the path exists,
+            otherwise return false.
+
+        Examples:
+            .. code-block:: python
+
+                from paddle.distributed.fleet.utils.fs import LocalFS
+
+                client = LocalFS()
+                ret = client.is_exist("test_is_exist")
+        """
        return os.path.exists(fs_path)
    def touch(self, fs_path, exist_ok=True):
+        """
+        Create a local file.
+
+        Args:
+            fs_path(str): The local file path.
+            exist_ok(bool): When `fs_path` exists, if `exist_ok` is set to false,
+            the program will throw an exception. Default is true.
+
+        Examples:
+            .. code-block:: python
+
+                from paddle.distributed.fleet.utils.fs import LocalFS
+
+                client = LocalFS()
+                client.touch("test_touch")
+                client.delete("test_touch")
+        """
        if self.is_exist(fs_path):
            if exist_ok:
                return
@@ -175,6 +325,26 @@ class LocalFS(FS):
        return Path(fs_path).touch(exist_ok=True)
    def mv(self, src_path, dst_path, overwrite=False, test_exists=False):
+        """
+        Move a local file or directory from `src_path` to `dst_path` .
+
+        Args:
+            src_path(str): Name of the file or directory to be moved.
+            dst_path(str): Name of the file or directory to move to.
+            overwrite(bool): Whether to overwrite `dst_path` if it exists. Default is False.
+            test_exists(bool): Check the existence of `src_path` and `dst_path` .
+            When `test_exists` is set to true, if `src_path` doesn't exist or `dst_path` exists, the program will throw an Exception.
+
+        Examples:
+            .. code-block:: python
+
+                from paddle.distributed.fleet.utils.fs import LocalFS
+
+                client = LocalFS()
+                client.touch("test_mv_src")
+                client.mv("test_mv_src", "test_mv_dst")
+                client.delete("test_mv_dst")
+        """
        if not self.is_exist(src_path):
            raise FSFileNotExistsError
@@ -188,7 +358,21 @@ class LocalFS(FS):
    def list_dirs(self, fs_path):
        """
-        list directory under fs_path, and only give the pure name, not include the fs_path
+        Only list directories under `fs_path` .
+
+        Args:
+            fs_path(str): The local file path.
+
+        Returns:
+            List: A list of all its subdirectories, e.g. [subdirname1, subdirname2, ...].
+
+        Examples:
+            .. code-block:: python
+
+                from paddle.distributed.fleet.utils.fs import LocalFS
+
+                client = LocalFS()
+                subdirs = client.list_dirs("./")
        """
        if not self.is_exist(fs_path):
            return []
@@ -217,7 +401,7 @@ def _handle_errors(max_time_out=None):
        while True:
            try:
                return f(*args, **kwargs)
-            #important: only ExecuteError need to retry
+            # important: only ExecuteError need to retry
            except ExecuteError as e:
                if time.time() - start >= time_out:
                    raise FSTimeOut("args:{} timeout:{}".format(
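The `_handle_errors` decorator wrapped around the HDFS methods below retries only on `ExecuteError` until a deadline expires. A self-contained sketch of that pattern (names, defaults, and the sleep policy are assumptions, not the exact Paddle implementation):

.. code-block:: python

    import functools
    import time

    class ExecuteError(Exception):
        pass

    class FSTimeOut(Exception):
        pass

    def handle_errors(max_time_out=60.0, sleep_inter=1.0):
        # Retry the wrapped call on ExecuteError until max_time_out seconds pass.
        def decorator(f):
            @functools.wraps(f)
            def wrapper(*args, **kwargs):
                start = time.time()
                while True:
                    try:
                        return f(*args, **kwargs)
                    # important: only ExecuteError is retried; anything else propagates
                    except ExecuteError:
                        if time.time() - start >= max_time_out:
                            raise FSTimeOut("args:{} timeout:{}".format(
                                args, max_time_out))
                        time.sleep(sleep_inter)
            return wrapper
        return decorator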
@@ -236,12 +420,36 @@ def _handle_errors(max_time_out=None):

class HDFSClient(FS):
+    """
+    A tool for HDFS.
+
+    Args:
+        hadoop_home(str): Hadoop home.
+        configs(dict): Hadoop config. It is a dictionary that needs to contain the
+        keys "fs.default.name" and "hadoop.job.ugi".
+
+    Examples:
+        .. code-block:: text
+
+            from paddle.distributed.fleet.utils.fs import HDFSClient
+
+            hadoop_home = "/home/client/hadoop-client/hadoop/"
+            configs = {
+                "fs.default.name": "hdfs://xxx.hadoop.com:54310",
+                "hadoop.job.ugi": "hello,hello123"
+            }
+
+            client = HDFSClient(hadoop_home, configs)
+            client.ls_dir("hdfs:/test_hdfs_client")
+    """
    def __init__(
            self,
            hadoop_home,
            configs,
-            time_out=5 * 60 * 1000,  #ms
-            sleep_inter=1000):  #ms
+            time_out=5 * 60 * 1000,  # ms
+            sleep_inter=1000):  # ms
        # Raise exception if JAVA_HOME not exists.
        java_home = os.environ["JAVA_HOME"]
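Besides `hadoop_home` and `configs`, the constructor exposes the two retry knobs visible in the signature, both in milliseconds. For example (paths and config values are placeholders):

.. code-block:: python

    from paddle.distributed.fleet.utils.fs import HDFSClient

    hadoop_home = "/home/client/hadoop-client/hadoop/"  # placeholder path
    configs = {
        "fs.default.name": "hdfs://xxx.hadoop.com:54310",
        "hadoop.job.ugi": "hello,hello123"
    }

    # Give up on a retried command after 10 minutes, sleeping 2s between attempts.
    client = HDFSClient(hadoop_home, configs,
                        time_out=10 * 60 * 1000,  # ms
                        sleep_inter=2000)         # ms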
@@ -272,6 +480,30 @@ class HDFSClient(FS):

    @_handle_errors()
    def list_dirs(self, fs_path):
+        """
+        Only list directories under `fs_path` .
+
+        Args:
+            fs_path(str): The HDFS file path.
+
+        Returns:
+            List: A list of all its subdirectories, e.g. [subdirname1, subdirname2, ...].
+
+        Examples:
+            .. code-block:: text
+
+                from paddle.distributed.fleet.utils.fs import HDFSClient
+
+                hadoop_home = "/home/client/hadoop-client/hadoop/"
+                configs = {
+                    "fs.default.name": "hdfs://xxx.hadoop.com:54310",
+                    "hadoop.job.ugi": "hello,hello123"
+                }
+
+                client = HDFSClient(hadoop_home, configs)
+                subdirs = client.list_dirs("hdfs:/test_hdfs_client")
+        """
        if not self.is_exist(fs_path):
            return []
@@ -281,7 +513,29 @@ class HDFSClient(FS):
    @_handle_errors()
    def ls_dir(self, fs_path):
        """
-        list directory under fs_path, and only give the pure name, not include the fs_path
+        List directories and files under `fs_path` .
+
+        Args:
+            fs_path(str): The HDFS file path.
+
+        Returns:
+            Tuple: Return a 2-tuple, the first element is the list of all its subdirectories,
+            and the second one is the list of all its files, e.g. ([subdirname1, subdirname2, ...], [filename1, filename2, ...]).
+
+        Examples:
+            .. code-block:: text
+
+                from paddle.distributed.fleet.utils.fs import HDFSClient
+
+                hadoop_home = "/home/client/hadoop-client/hadoop/"
+                configs = {
+                    "fs.default.name": "hdfs://xxx.hadoop.com:54310",
+                    "hadoop.job.ugi": "hello,hello123"
+                }
+
+                client = HDFSClient(hadoop_home, configs)
+                subdirs, files = client.ls_dir("hdfs:/test_hdfs_client")
        """
        if not self.is_exist(fs_path):
            return [], []
@@ -320,6 +574,30 @@ class HDFSClient(FS):

    @_handle_errors()
    def is_dir(self, fs_path):
+        """
+        Whether the remote HDFS path is a directory.
+
+        Args:
+            fs_path(str): The HDFS file path.
+
+        Returns:
+            Bool: Return true if the path exists and it's a directory, otherwise return false.
+
+        Examples:
+            .. code-block:: text
+
+                from paddle.distributed.fleet.utils.fs import HDFSClient
+
+                hadoop_home = "/home/client/hadoop-client/hadoop/"
+                configs = {
+                    "fs.default.name": "hdfs://xxx.hadoop.com:54310",
+                    "hadoop.job.ugi": "hello,hello123"
+                }
+
+                client = HDFSClient(hadoop_home, configs)
+                ret = client.is_dir("hdfs:/test_hdfs_client")
+        """
        if not self.is_exist(fs_path):
            return False
@@ -338,6 +616,30 @@ class HDFSClient(FS):
        return True
    def is_file(self, fs_path):
+        """
+        Whether the remote HDFS path is a file.
+
+        Args:
+            fs_path(str): The HDFS file path.
+
+        Returns:
+            Bool: Return true if the path exists and it's a file, otherwise return false.
+
+        Examples:
+            .. code-block:: text
+
+                from paddle.distributed.fleet.utils.fs import HDFSClient
+
+                hadoop_home = "/home/client/hadoop-client/hadoop/"
+                configs = {
+                    "fs.default.name": "hdfs://xxx.hadoop.com:54310",
+                    "hadoop.job.ugi": "hello,hello123"
+                }
+
+                client = HDFSClient(hadoop_home, configs)
+                ret = client.is_file("hdfs:/test_hdfs_client")
+        """
        if not self.is_exist(fs_path):
            return False
@@ -345,6 +647,31 @@ class HDFSClient(FS):
    @_handle_errors()
    def is_exist(self, fs_path):
+        """
+        Whether the remote HDFS path exists.
+
+        Args:
+            fs_path(str): The HDFS file path.
+
+        Returns:
+            Bool: Whether it's a file or directory, return true if the path exists,
+            otherwise return false.
+
+        Examples:
+            .. code-block:: text
+
+                from paddle.distributed.fleet.utils.fs import HDFSClient
+
+                hadoop_home = "/home/client/hadoop-client/hadoop/"
+                configs = {
+                    "fs.default.name": "hdfs://xxx.hadoop.com:54310",
+                    "hadoop.job.ugi": "hello,hello123"
+                }
+
+                client = HDFSClient(hadoop_home, configs)
+                ret = client.is_exist("hdfs:/test_hdfs_client")
+        """
        cmd = "ls {} ".format(fs_path)
        ret, out = self._run_cmd(cmd, redirect_stderr=True)
        if ret != 0:
@@ -357,6 +684,28 @@ class HDFSClient(FS):
    # can't retry
    def upload(self, local_path, fs_path):
+        """
+        Upload the local path to remote HDFS.
+
+        Args:
+            local_path(str): The local path.
+            fs_path(str): The HDFS path.
+
+        Examples:
+            .. code-block:: text
+
+                from paddle.distributed.fleet.utils.fs import HDFSClient
+
+                hadoop_home = "/home/client/hadoop-client/hadoop/"
+                configs = {
+                    "fs.default.name": "hdfs://xxx.hadoop.com:54310",
+                    "hadoop.job.ugi": "hello,hello123"
+                }
+
+                client = HDFSClient(hadoop_home, configs)
+                client.upload("test_hdfs_client", "hdfs:/test_hdfs_client")
+        """
        if self.is_exist(fs_path):
            raise FSFileExistsError("{} exists".format(fs_path))
@@ -380,6 +729,28 @@ class HDFSClient(FS):
    # can't retry
    def download(self, fs_path, local_path):
+        """
+        Download a remote HDFS path to the local file system.
+
+        Args:
+            fs_path(str): The HDFS path.
+            local_path(str): The local path.
+
+        Examples:
+            .. code-block:: text
+
+                from paddle.distributed.fleet.utils.fs import HDFSClient
+
+                hadoop_home = "/home/client/hadoop-client/hadoop/"
+                configs = {
+                    "fs.default.name": "hdfs://xxx.hadoop.com:54310",
+                    "hadoop.job.ugi": "hello,hello123"
+                }
+
+                client = HDFSClient(hadoop_home, configs)
+                client.download("hdfs:/test_hdfs_client", "./")
+        """
        if self.is_exist(local_path):
            raise FSFileExistsError("{} exists".format(local_path))
@@ -403,6 +774,27 @@ class HDFSClient(FS):
    @_handle_errors()
    def mkdirs(self, fs_path):
+        """
+        Create a remote HDFS directory.
+
+        Args:
+            fs_path(str): The HDFS directory path.
+
+        Examples:
+            .. code-block:: text
+
+                from paddle.distributed.fleet.utils.fs import HDFSClient
+
+                hadoop_home = "/home/client/hadoop-client/hadoop/"
+                configs = {
+                    "fs.default.name": "hdfs://xxx.hadoop.com:54310",
+                    "hadoop.job.ugi": "hello,hello123"
+                }
+
+                client = HDFSClient(hadoop_home, configs)
+                client.mkdirs("hdfs:/test_hdfs_client")
+        """
        if self.is_exist(fs_path):
            return
@@ -425,6 +817,30 @@ class HDFSClient(FS):
        raise ExecuteError(cmd)
    def mv(self, fs_src_path, fs_dst_path, overwrite=False, test_exists=True):
+        """
+        Move a remote HDFS file or directory from `fs_src_path` to `fs_dst_path` .
+
+        Args:
+            fs_src_path(str): Name of the file or directory to be moved.
+            fs_dst_path(str): Name of the file or directory to move to.
+            overwrite(bool): Whether to overwrite `fs_dst_path` if it exists. Default is False.
+            test_exists(bool): Check the existence of `fs_src_path` and `fs_dst_path` . When `test_exists` is set to true, if `fs_src_path` doesn't exist or `fs_dst_path` exists, the program will throw an Exception.
+
+        Examples:
+            .. code-block:: text
+
+                from paddle.distributed.fleet.utils.fs import HDFSClient
+
+                hadoop_home = "/home/client/hadoop-client/hadoop/"
+                configs = {
+                    "fs.default.name": "hdfs://xxx.hadoop.com:54310",
+                    "hadoop.job.ugi": "hello,hello123"
+                }
+
+                client = HDFSClient(hadoop_home, configs)
+                client.mv("hdfs:/test_hdfs_client", "hdfs:/test_hdfs_client2")
+        """
        if overwrite and self.is_exist(fs_dst_path):
            self.delete(fs_dst_path)
@@ -467,6 +883,27 @@ class HDFSClient(FS):
    @_handle_errors()
    def delete(self, fs_path):
+        """
+        Delete a remote HDFS path, whether it's a file or directory.
+
+        Args:
+            fs_path(str): The HDFS file path.
+
+        Examples:
+            .. code-block:: text
+
+                from paddle.distributed.fleet.utils.fs import HDFSClient
+
+                hadoop_home = "/home/client/hadoop-client/hadoop/"
+                configs = {
+                    "fs.default.name": "hdfs://xxx.hadoop.com:54310",
+                    "hadoop.job.ugi": "hello,hello123"
+                }
+
+                client = HDFSClient(hadoop_home, configs)
+                client.delete("hdfs:/test_hdfs_client")
+        """
        if not self.is_exist(fs_path):
            return
@@ -477,6 +914,27 @@ class HDFSClient(FS):
        return self._rm(fs_path)
    def touch(self, fs_path, exist_ok=True):
+        """
+        Create a remote HDFS file.
+
+        Args:
+            fs_path(str): The HDFS file path.
+            exist_ok(bool): When `fs_path` exists, if `exist_ok` is set to false,
+            the program will throw an exception. Default is true.
+
+        Examples:
+            .. code-block:: text
+
+                from paddle.distributed.fleet.utils.fs import HDFSClient
+
+                hadoop_home = "/home/client/hadoop-client/hadoop/"
+                configs = {
+                    "fs.default.name": "hdfs://xxx.hadoop.com:54310",
+                    "hadoop.job.ugi": "hello,hello123"
+                }
+
+                client = HDFSClient(hadoop_home, configs)
+                client.touch("hdfs:/test_hdfs_client")
+        """
        if self.is_exist(fs_path):
            if exist_ok:
                return
...
@@ -98,7 +98,7 @@ class AutoCheckpointChecker(object):
        self._fs_cache = os.getenv("PADDLE_EDL_FS_CACHE", ".cache")
        self._save_checkpoint_inter = int(
-            os.getenv("PADDLE_EDL_SAVE_CHECKPOINT_INTER", "900"))  #s
+            os.getenv("PADDLE_EDL_SAVE_CHECKPOINT_INTER", "900"))  # s

        if not self._ce_test:
            assert len(self._hdfs_home) > 3 and \
@@ -132,7 +132,7 @@ class AutoCheckpointChecker(object):
        if in_dygraph_mode():
            return False

        return self._run_env is not None and \
            self._platform is not None and \
            self._job_id is not None and \
            self._hdfs_home is not None and \
...
@@ -19,7 +19,7 @@ from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet
import os
import sys

-from paddle.distributed.fleet.utils import LocalFS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError
+from paddle.distributed.fleet.utils.fs import LocalFS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError

java_home = os.environ["JAVA_HOME"]
...
@@ -67,13 +67,13 @@ class AutoCheckpointTestDist(AutoCheckPointACLBase):
        save_dir = "./run_save_0"
        fs.delete(save_dir)

-        #basic
+        # basic
        exe, main_prog, startup_prog = self._generate()

        compiled, data_loader, optimizer, loss, image, label = \
            self._init_env(exe, main_prog, startup_prog, minimize=False)

-        #fleet
+        # fleet
        os.environ["TRAINING_ROLE"] = "TRAINER"
        os.environ["PADDLE_TRAINER_ID"] = "0"
        os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:6070"
...
@@ -40,9 +40,9 @@ class TestCloudRoleMaker(unittest.TestCase):
            from paddle.fluid.incubate.fleet.parameter_server.pslib import PSLib
            from paddle.fluid.incubate.fleet.base.role_maker import \
                GeneralRoleMaker
-            from paddle.distributed.fleet.utils import KVHandler
-            from paddle.distributed.fleet.utils import KVServer
-            from paddle.distributed.fleet.utils import KVHTTPServer
+            from paddle.distributed.fleet.utils.http_server import KVHandler
+            from paddle.distributed.fleet.utils.http_server import KVServer
+            from paddle.distributed.fleet.utils.http_server import KVHTTPServer
        except:
            print("warning: no fleet, skip test_pslib_4")
            return
...
@@ -81,12 +81,12 @@ class TestFleetUtil(unittest.TestCase):
        self.assertEqual(user_id, 10)

    def test_fs(self):
-        from paddle.distributed.fleet.utils import LocalFS
+        from paddle.distributed.fleet.utils.fs import LocalFS
        fs = LocalFS()
        dirs, files = fs.ls_dir("test_tmp")
        dirs, files = fs.ls_dir("./")
        self.assertFalse(fs.need_upload_download())
-        fleet_util.set_file_system(fs)
+        fleet_util._set_file_system(fs)

    def test_barrier(self):
        try:
...
@@ -20,7 +20,7 @@ import os
import sys
import inspect

-from paddle.distributed.fleet.utils import LocalFS, FS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError
+from paddle.distributed.fleet.utils.fs import LocalFS, FS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError

class FSTest(unittest.TestCase):
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from paddle.fluid.tests.unittests.hdfs_test_utils import FSTestBase
import unittest
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
@@ -19,12 +20,10 @@ from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet
import os
import sys

-from paddle.distributed.fleet.utils import LocalFS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError
+from paddle.distributed.fleet.utils.fs import LocalFS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError

java_home = os.environ["JAVA_HOME"]

-from paddle.fluid.tests.unittests.hdfs_test_utils import FSTestBase

class FSTest1(FSTestBase):
    def test_timeout(self):
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from paddle.fluid.tests.unittests.hdfs_test_utils import FSTestBase
import unittest
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
@@ -19,12 +20,10 @@ from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet
import os
import sys

-from paddle.distributed.fleet.utils import LocalFS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError
+from paddle.distributed.fleet.utils.fs import LocalFS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError

java_home = os.environ["JAVA_HOME"]

-from paddle.fluid.tests.unittests.hdfs_test_utils import FSTestBase

class FSTest2(FSTestBase):
    def test_hdfs(self):
...
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+from paddle.fluid.tests.unittests.hdfs_test_utils import FSTestBase
import unittest
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
@@ -19,12 +20,10 @@ from paddle.fluid.incubate.fleet.collective import CollectiveOptimizer, fleet
import os
import sys

-from paddle.distributed.fleet.utils import LocalFS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError
+from paddle.distributed.fleet.utils.fs import LocalFS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError

java_home = os.environ["JAVA_HOME"]

-from paddle.fluid.tests.unittests.hdfs_test_utils import FSTestBase

class FSTest3(FSTestBase):
    def test_hdfs(self):
...