Unverified commit 32ad4f90, authored by 123malin, committed by GitHub

【paddle.fleet】 Usages Change: from fleet.util() to fleet.util (#27468)

* test=develop, bug fix
Parent: df7fabee
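The user-facing change in this commit: `fleet.util` becomes a module-level `UtilBase` instance instead of a method on the `Fleet` object, and `fleet.set_util()` is removed (a user-defined util can now be assigned directly to `fleet.util`). A minimal sketch of the new usage, adapted from the docstrings touched below; the log file names are illustrative:

```python
import paddle.distributed.fleet as fleet

fleet.init()

# Old usage (before this commit): util was obtained by calling a method.
#   util = fleet.util()
# New usage: util is an attribute, wired to the role maker by fleet.init().
util = fleet.util

# Shard a list of files across trainers, as in the get_file_shard() docstring.
files = util.get_file_shard(["1.log", "2.log", "3.log", "4.log"])
print(files)
```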
@@ -13,7 +13,7 @@
 # limitations under the License.
 # TODO: define distributed api under this directory,
-from .base.role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker
+from .base.role_maker import Role, UserDefinedRoleMaker, PaddleCloudRoleMaker
 from .base.distributed_strategy import DistributedStrategy
 from .base.fleet_base import Fleet
 from .base.util_factory import UtilBase
@@ -26,6 +26,7 @@ __all__ = [
 "UserDefinedRoleMaker",
 "PaddleCloudRoleMaker",
 "Fleet",
+"Role",
 ]
 fleet = Fleet()
@@ -39,8 +40,7 @@ server_num = fleet.server_num
 server_index = fleet.server_index
 server_endpoints = fleet.server_endpoints
 is_server = fleet.is_server
-set_util = fleet.set_util
-util = fleet.util
+util = UtilBase()
 barrier_worker = fleet.barrier_worker
 init_worker = fleet.init_worker
 init_server = fleet.init_server
...
@@ -23,7 +23,6 @@ from .strategy_compiler import StrategyCompiler
 from .distributed_strategy import DistributedStrategy
 from .meta_optimizer_factory import MetaOptimizerFactory
 from .runtime_factory import RuntimeFactory
-from .util_factory import UtilFactory
 from paddle.fluid.wrapped_decorator import wrap_decorator
 from paddle.fluid.dygraph import parallel_helper
@@ -120,7 +119,6 @@ class Fleet(object):
 self.strategy_compiler = None
 self._is_collective = False
 self._runtime_handle = None
-self._util = None
 def init(self, role_maker=None, is_collective=False):
 """
@@ -182,6 +180,9 @@ class Fleet(object):
 format(type(role_maker)))
 self._role_maker._generate_role()
+import paddle.distributed.fleet as fleet
+fleet.util._set_role_maker(self._role_maker)
 self.strategy_compiler = StrategyCompiler()
 if paddle.fluid.framework.in_dygraph_mode():
 if parallel_helper._is_parallel_ctx_initialized():
@@ -353,29 +354,6 @@ class Fleet(object):
 return self._role_maker._is_server(
 ) or self._role_maker._is_heter_worker()
-def set_util(self, util):
-self._util = util
-def util(self):
-"""
-Utility functions that can be used under certain runtime
-return util
-Returns:
-UtilBase: instance of UtilBase, can use distributed ops/tools easily.
-Examples:
-.. code-block:: python
-import paddle.distributed.fleet as fleet
-fleet.init()
-util = fleet.util
-files = ["1.log", "2.log", "3.log", "4.log"]
-files = util.get_file_shard()
-"""
-return self._util
 def barrier_worker(self):
 """
 barrier all workers
@@ -1102,7 +1080,7 @@ class Fleet(object):
 if self._runtime_handle is None:
 self._runtime_handle = RuntimeFactory()._create_runtime(context)
-if self._util is None:
-self._util = UtilFactory()._create_util(context)
+import paddle.distributed.fleet as fleet
+fleet.util._set_strategy(context["valid_strategy"])
 return optimize_ops, params_grads
@@ -73,11 +73,13 @@ class UtilBase(object):
 .. code-block:: python
 # Save the following code in `train.py` , and then execute the command `fleetrun --server_num 2 --worker_num 2 train.py` .
-from paddle.distributed.fleet.base.util_factory import fleet_util
 import paddle.distributed.fleet as fleet
 from paddle.distributed.fleet import PaddleCloudRoleMaker
 import sys
 import numpy as np
+import os
+os.environ["PADDLE_WITH_GLOO"] = "2"
 def train():
 role = PaddleCloudRoleMaker(
@@ -85,19 +87,18 @@ class UtilBase(object):
 init_gloo=True,
 path="./tmp_gloo")
 fleet.init(role)
-fleet_util._set_role_maker(role)
 if fleet.is_server():
 input = [1, 2]
-output = fleet_util.all_reduce(input, "sum", "server")
+output = fleet.util.all_reduce(input, "sum", "server")
 print(output)
 # [2, 4]
 elif fleet.is_worker():
 input = np.array([3, 4])
-output = fleet_util.all_reduce(input, "sum", "worker")
+output = fleet.util.all_reduce(input, "sum", "worker")
 print(output)
 # [6, 8]
-output = fleet_util.all_reduce(input, "sum", "all")
+output = fleet.util.all_reduce(input, "sum", "all")
 print(output)
 # [8, 12]
 if __name__ == "__main__":
@@ -117,10 +118,12 @@ class UtilBase(object):
 .. code-block:: python
 # Save the following code in `train.py` , and then execute the command `fleetrun --server_num 2 --worker_num 2 train.py` .
-from paddle.distributed.fleet.base.util_factory import fleet_util
 import paddle.distributed.fleet as fleet
 from paddle.distributed.fleet import PaddleCloudRoleMaker
 import sys
+import os
+os.environ["PADDLE_WITH_GLOO"] = "2"
 def train():
 role = PaddleCloudRoleMaker(
@@ -128,15 +131,14 @@ class UtilBase(object):
 init_gloo=True,
 path="./tmp_gloo")
 fleet.init(role)
-fleet_util._set_role_maker(role)
 if fleet.is_server():
-fleet_util.barrier("server")
+fleet.util.barrier("server")
 print("all server arrive here")
 elif fleet.is_worker():
-fleet_util.barrier("worker")
+fleet.util.barrier("worker")
 print("all server arrive here")
-fleet_util.barrier("all")
+fleet.util.barrier("all")
 print("all servers and workers arrive here")
 if __name__ == "__main__":
@@ -160,10 +162,12 @@ class UtilBase(object):
 .. code-block:: python
 # Save the following code in `train.py` , and then execute the command `fleetrun --server_num 2 --worker_num 2 train.py` .
-from paddle.distributed.fleet.base.util_factory import fleet_util
 import paddle.distributed.fleet as fleet
 from paddle.distributed.fleet import PaddleCloudRoleMaker
 import sys
+import os
+os.environ["PADDLE_WITH_GLOO"] = "2"
 def train():
 role = PaddleCloudRoleMaker(
@@ -171,19 +175,18 @@ class UtilBase(object):
 init_gloo=True,
 path="./tmp_gloo")
 fleet.init(role)
-fleet_util._set_role_maker(role)
 if fleet.is_server():
 input = fleet.server_index()
-output = fleet_util.all_gather(input, "server")
+output = fleet.util.all_gather(input, "server")
 print(output)
 # output = [0, 1]
 elif fleet.is_worker():
 input = fleet.worker_index()
-output = fleet_util.all_gather(input, "worker")
+output = fleet.util.all_gather(input, "worker")
 # output = [0, 1]
 print(output)
-output = fleet_util.all_gather(input, "all")
+output = fleet.util.all_gather(input, "all")
 print(output)
 # output = [0, 1, 0, 1]
@@ -220,18 +223,20 @@ class UtilBase(object):
 .. code-block:: python
-from paddle.distributed.fleet.base.util_factory import fleet_util
-import paddle.distributed.fleet.base.role_maker as role_maker
-role = role_maker.UserDefinedRoleMaker(
+import paddle.distributed.fleet as fleet
+from paddle.distributed.fleet import UserDefinedRoleMaker
+role = UserDefinedRoleMaker(
 is_collective=False,
 init_gloo=False,
 current_id=0,
-role=role_maker.Role.WORKER,
+role=fleet.Role.WORKER,
 worker_endpoints=["127.0.0.1:6003", "127.0.0.1:6004"],
 server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
-fleet_util._set_role_maker(role)
-files = fleet_util.get_file_shard(["file1", "file2", "file3"])
+fleet.init(role)
+files = fleet.util.get_file_shard(["file1", "file2", "file3"])
+print(files)
 # files = ["file1", "file2"]
 """
 if not isinstance(files, list):
@@ -267,18 +272,19 @@ class UtilBase(object):
 .. code-block:: python
-from paddle.distributed.fleet.base.util_factory import fleet_util
-import paddle.distributed.fleet.base.role_maker as role_maker
-role = role_maker.UserDefinedRoleMaker(
+import paddle.distributed.fleet as fleet
+from paddle.distributed.fleet import UserDefinedRoleMaker
+role = UserDefinedRoleMaker(
 is_collective=False,
 init_gloo=False,
 current_id=0,
-role=role_maker.Role.WORKER,
+role=fleet.Role.WORKER,
 worker_endpoints=["127.0.0.1:6003", "127.0.0.1:6004"],
 server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
-fleet_util._set_role_maker(role)
-fleet_util.print_on_rank("I'm worker 0", 0)
+fleet.init(role)
+fleet.util.print_on_rank("I'm worker 0", 0)
 """
 if self.role_maker._worker_index() != rank_id:
 return
@@ -577,6 +583,3 @@ class UtilBase(object):
 print("fetch_targets name: %s" % v.name)
 print("fetch_targets: {}".format(results[i]))
 return results
-fleet_util = UtilFactory()._create_util(None)
@@ -181,8 +181,8 @@ def get_gpus(gpus):
 cuda_visible_devices_list = cuda_visible_devices.split(',')
 for x in gpus.split(','):
 assert x in cuda_visible_devices_list, "Can't find "\
 "your gpus %s in CUDA_VISIBLE_DEVICES[%s]."\
 % (x, cuda_visible_devices)
 res_gpus = [
 cuda_visible_devices_list.index(x.strip())
 for x in gpus.split(',')
@@ -348,8 +348,7 @@ def launch_ps(args):
 "PADDLE_PORT": cur_server.endpoint.split(":")[1],
 "TRAINING_ROLE": "PSERVER",
 "PADDLE_TRAINERS_NUM": str(worker_num),
-"POD_IP": cur_server.endpoint.split(":")[0],
-"PADDLE_WITH_GLOO": "1"
+"POD_IP": cur_server.endpoint.split(":")[0]
 }
 current_env.update(proc_env)
@@ -388,8 +387,7 @@ def launch_ps(args):
 "PADDLE_TRAINER_ENDPOINTS": worker_endpoints,
 "PADDLE_TRAINERS_NUM": str(worker_num),
 "TRAINING_ROLE": "TRAINER",
-"PADDLE_TRAINER_ID": str(cur_worker.rank),
-"PADDLE_WITH_GLOO": "1"
+"PADDLE_TRAINER_ID": str(cur_worker.rank)
 }
 current_env.update(proc_env)
...
@@ -11,3 +11,5 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from .fs import LocalFS, HDFSClient
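Because `LocalFS` and `HDFSClient` are now re-exported from the package, the docstrings in the hunks below switch to the shorter import path. A small local sketch of the updated import; the directory names are illustrative:

```python
from paddle.distributed.fleet.utils import LocalFS

client = LocalFS()
client.mkdirs("test_mkdirs")            # create a local directory
subdirs, files = client.ls_dir("./")    # list the current directory
print(subdirs, files)
```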
@@ -120,7 +120,7 @@ class LocalFS(FS):
 Examples:
 .. code-block:: python
-from paddle.distributed.fleet.utils.fs import LocalFS
+from paddle.distributed.fleet.utils import LocalFS
 client = LocalFS()
 subdirs, files = client.ls_dir("./")
@@ -140,7 +140,7 @@ class LocalFS(FS):
 Examples:
 .. code-block:: python
-from paddle.distributed.fleet.utils.fs import LocalFS
+from paddle.distributed.fleet.utils import LocalFS
 client = LocalFS()
 subdirs, files = client.ls_dir("./")
@@ -160,7 +160,7 @@ class LocalFS(FS):
 def mkdirs(self, fs_path):
 """
-Create a remote HDFS directory.
+Create a local directory.
 Args:
 fs_path(str): The local directory path.
@@ -168,7 +168,7 @@ class LocalFS(FS):
 Examples:
 .. code-block:: python
-from paddle.distributed.fleet.utils.fs import LocalFS
+from paddle.distributed.fleet.utils import LocalFS
 client = LocalFS()
 client.mkdirs("test_mkdirs")
@@ -189,7 +189,7 @@ class LocalFS(FS):
 Examples:
 .. code-block:: python
-from paddle.distributed.fleet.utils.fs import LocalFS
+from paddle.distributed.fleet.utils import LocalFS
 client = LocalFS()
 client.touch("test_rename_src")
@@ -217,7 +217,7 @@ class LocalFS(FS):
 Examples:
 .. code-block:: python
-from paddle.distributed.fleet.utils.fs import LocalFS
+from paddle.distributed.fleet.utils import LocalFS
 client = LocalFS()
 client.mkdirs("test_localFS_mkdirs")
@@ -247,7 +247,7 @@ class LocalFS(FS):
 Examples:
 .. code-block:: python
-from paddle.distributed.fleet.utils.fs import LocalFS
+from paddle.distributed.fleet.utils import LocalFS
 client = LocalFS()
 client.touch("test_is_file")
@@ -269,7 +269,7 @@ class LocalFS(FS):
 Examples:
 .. code-block:: python
-from paddle.distributed.fleet.utils.fs import LocalFS
+from paddle.distributed.fleet.utils import LocalFS
 client = LocalFS()
 client.mkdirs("test_is_dir")
@@ -292,7 +292,7 @@ class LocalFS(FS):
 Examples:
 .. code-block:: python
-from paddle.distributed.fleet.utils.fs import LocalFS
+from paddle.distributed.fleet.utils import LocalFS
 client = LocalFS()
 ret = local_fs.is_exist("test_is_exist")
@@ -311,7 +311,7 @@ class LocalFS(FS):
 Examples:
 .. code-block:: python
-from paddle.distributed.fleet.utils.fs import LocalFS
+from paddle.distributed.fleet.utils import LocalFS
 client = LocalFS()
 client.touch("test_touch")
@@ -332,13 +332,11 @@ class LocalFS(FS):
 src_path(str): Name of the file or directory, that's needed to be moved.
 dst_path(str): Name of the file or directory to which to move to.
 overwrite(bool): Whether to re-write `dst_path` if that exists. Default is False.
-test_exists(bool): Check the existence of `src_path` and `dst_path` .
-When `test_exists` is set true, if `src_path` doesn't exist or `dst_path` exists, program will throw an Excetption.
 Examples:
 .. code-block:: python
-from paddle.distributed.fleet.utils.fs import LocalFS
+from paddle.distributed.fleet.utils import LocalFS
 client = LocalFS()
 client.touch("test_mv_src")
@@ -369,7 +367,7 @@ class LocalFS(FS):
 Examples:
 .. code-block:: python
-from paddle.distributed.fleet.utils.fs import LocalFS
+from paddle.distributed.fleet.utils import LocalFS
 client = LocalFS()
 subdirs = client.list_dirs("./")
@@ -432,7 +430,7 @@ class HDFSClient(FS):
 .. code-block:: text
-from paddle.distributed.fleet.utils.fs import HDFSClient
+from paddle.distributed.fleet.utils import HDFSClient
 hadoop_home = "/home/client/hadoop-client/hadoop/"
 configs = {
@@ -493,7 +491,7 @@ class HDFSClient(FS):
 .. code-block:: text
-from paddle.distributed.fleet.utils.fs import HDFSClient
+from paddle.distributed.fleet.utils import HDFSClient
 hadoop_home = "/home/client/hadoop-client/hadoop/"
 configs = {
@@ -526,7 +524,7 @@ class HDFSClient(FS):
 .. code-block:: text
-from paddle.distributed.fleet.utils.fs import HDFSClient
+from paddle.distributed.fleet.utils import HDFSClient
 hadoop_home = "/home/client/hadoop-client/hadoop/"
 configs = {
@@ -587,7 +585,7 @@ class HDFSClient(FS):
 .. code-block:: text
-from paddle.distributed.fleet.utils.fs import HDFSClient
+from paddle.distributed.fleet.utils import HDFSClient
 hadoop_home = "/home/client/hadoop-client/hadoop/"
 configs = {
@@ -629,7 +627,7 @@ class HDFSClient(FS):
 .. code-block:: text
-from paddle.distributed.fleet.utils.fs import HDFSClient
+from paddle.distributed.fleet.utils import HDFSClient
 hadoop_home = "/home/client/hadoop-client/hadoop/"
 configs = {
@@ -661,7 +659,7 @@ class HDFSClient(FS):
 .. code-block:: text
-from paddle.distributed.fleet.utils.fs import HDFSClient
+from paddle.distributed.fleet.utils import HDFSClient
 hadoop_home = "/home/client/hadoop-client/hadoop/"
 configs = {
@@ -695,7 +693,7 @@ class HDFSClient(FS):
 .. code-block:: text
-from paddle.distributed.fleet.utils.fs import HDFSClient
+from paddle.distributed.fleet.utils import HDFSClient
 hadoop_home = "/home/client/hadoop-client/hadoop/"
 configs = {
@@ -740,7 +738,7 @@ class HDFSClient(FS):
 .. code-block:: text
-from paddle.distributed.fleet.utils.fs import HDFSClient
+from paddle.distributed.fleet.utils import HDFSClient
 hadoop_home = "/home/client/hadoop-client/hadoop/"
 configs = {
@@ -784,7 +782,7 @@ class HDFSClient(FS):
 .. code-block:: text
-from paddle.distributed.fleet.utils.fs import HDFSClient
+from paddle.distributed.fleet.utils import HDFSClient
 hadoop_home = "/home/client/hadoop-client/hadoop/"
 configs = {
@@ -830,7 +828,7 @@ class HDFSClient(FS):
 .. code-block:: text
-from paddle.distributed.fleet.utils.fs import HDFSClient
+from paddle.distributed.fleet.utils import HDFSClient
 hadoop_home = "/home/client/hadoop-client/hadoop/"
 configs = {
@@ -893,7 +891,7 @@ class HDFSClient(FS):
 .. code-block:: text
-from paddle.distributed.fleet.utils.fs import HDFSClient
+from paddle.distributed.fleet.utils import HDFSClient
 hadoop_home = "/home/client/hadoop-client/hadoop/"
 configs = {
@@ -919,12 +917,14 @@ class HDFSClient(FS):
 Args:
 fs_path(str): The HDFS file path.
+exist_ok(bool): When `fs_path` exists, if `exist_ok` is set false,
+program will throw an Exception. Default is true.
 Examples:
 .. code-block:: text
-from paddle.distributed.fleet.utils.fs import HDFSClient
+from paddle.distributed.fleet.utils import HDFSClient
 hadoop_home = "/home/client/hadoop-client/hadoop/"
 configs = {
...
@@ -28,7 +28,6 @@ import numpy as np
 import ctr_dataset_reader
 from test_dist_fleet_base import runtime_main, FleetDistRunnerBase
-from paddle.distributed.fleet.base.util_factory import fleet_util
 paddle.enable_static()
@@ -180,13 +179,13 @@ class TestDistCTR2x2(FleetDistRunnerBase):
 fetch_list=[self.avg_cost.name])
 loss_val = np.mean(loss_val)
 # TODO(randomly fail)
-# reduce_output = fleet_util.all_reduce(
+# reduce_output = fleet.util.all_reduce(
 # np.array(loss_val), mode="sum")
-# loss_all_trainer = fleet_util.all_gather(float(loss_val))
+# loss_all_trainer = fleet.util.all_gather(float(loss_val))
 # loss_val = float(reduce_output) / len(loss_all_trainer)
 message = "TRAIN ---> pass: {} loss: {}\n".format(epoch_id,
 loss_val)
-fleet_util.print_on_rank(message, 0)
+fleet.util.print_on_rank(message, 0)
 pass_time = time.time() - pass_start
 except fluid.core.EOFException:
...
@@ -29,7 +29,6 @@ import numpy as np
 import ctr_dataset_reader
 from test_dist_fleet_base import runtime_main, FleetDistRunnerBase
 from dist_fleet_ctr import TestDistCTR2x2, fake_ctr_reader
-from paddle.distributed.fleet.base.util_factory import fleet_util
 # Fix seed for test
 fluid.default_startup_program().random_seed = 1
@@ -76,13 +75,13 @@ class TestDistGpuPsCTR2x2(TestDistCTR2x2):
 loss_val = exe.run(program=fleet.main_program,
 fetch_list=[self.avg_cost.name])
 loss_val = np.mean(loss_val)
-reduce_output = fleet_util.all_reduce(
+reduce_output = fleet.util.all_reduce(
 np.array(loss_val), mode="sum")
-loss_all_trainer = fleet_util.all_gather(float(loss_val))
+loss_all_trainer = fleet.util.all_gather(float(loss_val))
 loss_val = float(reduce_output) / len(loss_all_trainer)
 message = "TRAIN ---> pass: {} loss: {}\n".format(epoch_id,
 loss_val)
-fleet_util.print_on_rank(message, 0)
+fleet.util.print_on_rank(message, 0)
 pass_time = time.time() - pass_start
 except fluid.core.EOFException:
...
@@ -29,7 +29,6 @@ import numpy as np
 import ctr_dataset_reader
 from test_dist_fleet_heter_base import runtime_main, FleetDistHeterRunnerBase
 from dist_fleet_ctr import TestDistCTR2x2, fake_ctr_reader
-from paddle.distributed.fleet.base.util_factory import fleet_util
 paddle.enable_static()
@@ -182,7 +181,7 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
 thread_num = int(os.getenv("CPU_NUM", 2))
 batch_size = 128
-filelist = fleet_util.get_file_shard(train_file_list)
+filelist = fleet.util.get_file_shard(train_file_list)
 print("filelist: {}".format(filelist))
 # config dataset
...
@@ -32,7 +32,6 @@ import os
 import signal
 from functools import reduce
 from test_dist_fleet_base import runtime_main, FleetDistRunnerBase
-from paddle.distributed.fleet.base.util_factory import fleet_util
 paddle.enable_static()
@@ -198,7 +197,7 @@ class TestDistSimnetBow2x2(FleetDistRunnerBase):
 def net(self, args, batch_size=4, lr=0.01):
 avg_cost, _, predict, self.reader = \
 train_network(batch_size=batch_size, is_distributed=False,
 is_sparse=True, is_self_contained_lr=False, is_pyreader=(args.reader == "pyreader"))
 self.avg_cost = avg_cost
 self.predict = predict
@@ -238,7 +237,7 @@ class TestDistSimnetBow2x2(FleetDistRunnerBase):
 loss_val = np.mean(loss_val)
 message = "TRAIN ---> pass: {} loss: {}\n".format(epoch_id,
 loss_val)
-fleet_util.print_on_rank(message, 0)
+fleet.util.print_on_rank(message, 0)
 pass_time = time.time() - pass_start
 except fluid.core.EOFException:
...
@@ -34,8 +34,7 @@ import unittest
 import paddle
 import paddle.fluid as fluid
 import paddle.distributed.fleet.base.role_maker as role_maker
-from paddle.distributed.fleet.base.util_factory import fleet_util
-from paddle.distributed.fleet import fleet
+import paddle.distributed.fleet as fleet
 from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory
 __all__ = ['FleetDistRunnerBase', 'TestFleetBase', 'runtime_main']
@@ -97,7 +96,7 @@ class FleetDistRunnerBase(object):
 self.dump_fields_path = os.getenv("dump_fields_path", "")
 debug = int(os.getenv("Debug", "0"))
 # TODO(update strategy to support dump params)
-if False: #debug:
+if False: # debug:
 self.strategy.set_debug_opt({
 "dump_param": self.dump_param,
 "dump_fields": self.dump_fields,
@@ -372,8 +371,6 @@ def runtime_main(test_class):
 strategy = model.build_strategy(args)
 avg_cost = model.net(args)
 model.build_optimizer(avg_cost, strategy)
-fleet_util._set_strategy(strategy)
-fleet_util._set_role_maker(role)
 if args.role == "pserver":
 model.run_pserver(args)
 else:
...
@@ -34,8 +34,7 @@ import unittest
 import paddle
 import paddle.fluid as fluid
 import paddle.distributed.fleet.base.role_maker as role_maker
-from paddle.distributed.fleet.base.util_factory import fleet_util
-from paddle.distributed.fleet import fleet
+import paddle.distributed.fleet as fleet
 __all__ = ['FleetDistHeterRunnerBase', 'TestFleetHeterBase', 'runtime_main']
@@ -376,8 +375,6 @@ def runtime_main(test_class):
 strategy = model.build_strategy(args)
 avg_cost = model.net(args)
 model.build_optimizer(avg_cost, strategy)
-fleet_util._set_strategy(strategy)
-fleet_util._set_role_maker(role)
 if args.role == "pserver" or args.role == "heter_trainer":
 model.run_pserver(args)
...
@@ -19,7 +19,6 @@ import os
 import math
 import paddle.fluid as fluid
 import paddle.distributed.fleet.base.role_maker as role_maker
-from paddle.distributed.fleet.base.util_factory import fleet_util
 from paddle.distributed.fleet import fleet
 import paddle
...
@@ -107,7 +107,7 @@ class TestFleetBase(unittest.TestCase):
 def test_util(self):
 role = role_maker.PaddleCloudRoleMaker(is_collective=True)
 fleet.init(role)
-self.assertEqual(fleet.util(), None)
+self.assertNotEqual(fleet.util, None)
 def test_barrier_worker(self):
 role = role_maker.PaddleCloudRoleMaker(is_collective=True)
...
@@ -436,12 +436,12 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase):
 optimizer.minimize(avg_cost)
 comm_world = "server"
-fleet.util().barrier(comm_world)
+fleet.util.barrier(comm_world)
-gather = fleet.util().all_gather(1, comm_world)
+gather = fleet.util.all_gather(1, comm_world)
 self.assertEqual(gather[0], 1)
-all_reduce = fleet.util().all_reduce(1, "sum", comm_world)
+all_reduce = fleet.util.all_reduce(1, "sum", comm_world)
 self.assertEqual(1, all_reduce)
 self.clean(tmp)
@@ -752,12 +752,12 @@ class TestGlooWithCloudRoleMaker(unittest.TestCase):
 optimizer.minimize(avg_cost)
 comm_world = "server"
-fleet.util().barrier(comm_world)
+fleet.util.barrier(comm_world)
-gather = fleet.util().all_gather(1, comm_world)
+gather = fleet.util.all_gather(1, comm_world)
 self.assertEqual(gather[0], 1)
-all_reduce = fleet.util().all_reduce(1, "sum", comm_world)
+all_reduce = fleet.util.all_reduce(1, "sum", comm_world)
 self.assertEqual(1, all_reduce)
 self.clean(tmp)
...
@@ -22,7 +22,6 @@ import tempfile
 import os
 import sys
 from paddle.dataset.common import download, DATA_HOME
-from paddle.distributed.fleet.base.util_factory import fleet_util
 import paddle.distributed.fleet.base.role_maker as role_maker
@@ -59,8 +58,7 @@ class TestFleetUtil(unittest.TestCase):
 import paddle.distributed.fleet.base.role_maker as role_maker
 role = role_maker.PaddleCloudRoleMaker(is_collective=True)
 fleet.init(role)
-default_util = fleet.util()
-self.assertEqual(default_util, None)
+self.assertNotEqual(fleet.util, None)
 def test_set_user_defined_util(self):
 import paddle.distributed.fleet as fleet
@@ -76,17 +74,19 @@ class TestFleetUtil(unittest.TestCase):
 role = role_maker.PaddleCloudRoleMaker(is_collective=True)
 fleet.init(role)
 my_util = UserDefinedUtil()
-fleet.set_util(my_util)
-user_id = fleet.util().get_user_id()
+fleet.util = my_util
+user_id = fleet.util.get_user_id()
 self.assertEqual(user_id, 10)
 def test_fs(self):
-from paddle.distributed.fleet.utils.fs import LocalFS
+import paddle.distributed.fleet as fleet
+from paddle.distributed.fleet.utils import LocalFS
 fs = LocalFS()
 dirs, files = fs.ls_dir("test_tmp")
 dirs, files = fs.ls_dir("./")
 self.assertFalse(fs.need_upload_download())
-fleet_util._set_file_system(fs)
+fleet.util._set_file_system(fs)
 def download_files(self):
 path = download(self.proto_data_url, self.module_name,
@@ -98,7 +98,8 @@ class TestFleetUtil(unittest.TestCase):
 return unzip_folder
 def test_get_file_shard(self):
-self.assertRaises(Exception, fleet_util.get_file_shard, "files")
+import paddle.distributed.fleet as fleet
+self.assertRaises(Exception, fleet.util.get_file_shard, "files")
 try:
 import netifaces
 except:
@@ -112,18 +113,20 @@ class TestFleetUtil(unittest.TestCase):
 role=role_maker.Role.WORKER,
 worker_endpoints=["127.0.0.1:6003", "127.0.0.1:6004"],
 server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
-fleet_util._set_role_maker(role)
-files = fleet_util.get_file_shard(["1", "2", "3"])
+fleet.init(role)
+files = fleet.util.get_file_shard(["1", "2", "3"])
 self.assertTrue(len(files) == 2 and "1" in files and "2" in files)
 def test_program_type_trans(self):
+import paddle.distributed.fleet as fleet
 data_dir = self.download_files()
 program_dir = os.path.join(data_dir, self.pruned_dir)
 text_program = "pruned_main_program.pbtxt"
 binary_program = "pruned_main_program.bin"
-text_to_binary = fleet_util._program_type_trans(program_dir,
+text_to_binary = fleet.util._program_type_trans(program_dir,
 text_program, True)
-binary_to_text = fleet_util._program_type_trans(program_dir,
+binary_to_text = fleet.util._program_type_trans(program_dir,
 binary_program, False)
 self.assertTrue(
 os.path.exists(os.path.join(program_dir, text_to_binary)))
@@ -131,6 +134,7 @@ class TestFleetUtil(unittest.TestCase):
 os.path.exists(os.path.join(program_dir, binary_to_text)))
 def test_prams_check(self):
+import paddle.distributed.fleet as fleet
 data_dir = self.download_files()
 class config:
@@ -160,11 +164,11 @@ class TestFleetUtil(unittest.TestCase):
 # test saved var's shape
 conf.dump_program_filename = "pruned_main_program.save_var_shape_not_match"
-self.assertRaises(Exception, fleet_util._params_check)
+self.assertRaises(Exception, fleet.util._params_check)
 # test program.proto without feed_op and fetch_op
 conf.dump_program_filename = "pruned_main_program.no_feed_fetch"
-results = fleet_util._params_check(conf)
+results = fleet.util._params_check(conf)
 self.assertTrue(len(results) == 1)
 np.testing.assert_array_almost_equal(
 results[0], np.array(
@@ -172,11 +176,11 @@ class TestFleetUtil(unittest.TestCase):
 # test feed_var's shape
 conf.dump_program_filename = "pruned_main_program.feed_var_shape_not_match"
-self.assertRaises(Exception, fleet_util._params_check)
+self.assertRaises(Exception, fleet.util._params_check)
 # test correct case with feed_vars_filelist
 conf.dump_program_filename = "pruned_main_program.pbtxt"
-results = fleet_util._params_check(conf)
+results = fleet.util._params_check(conf)
 self.assertTrue(len(results) == 1)
 np.testing.assert_array_almost_equal(
 results[0], np.array(
@@ -186,13 +190,14 @@ class TestFleetUtil(unittest.TestCase):
 conf.feed_config.feeded_vars_filelist = None
 # test feed var with lod_level >= 2
 conf.dump_program_filename = "pruned_main_program.feed_lod2"
-self.assertRaises(Exception, fleet_util._params_check)
+self.assertRaises(Exception, fleet.util._params_check)
 conf.dump_program_filename = "pruned_main_program.pbtxt"
-results = fleet_util._params_check(conf)
+results = fleet.util._params_check(conf)
 self.assertTrue(len(results) == 1)
 def test_proto_check(self):
+import paddle.distributed.fleet as fleet
 data_dir = self.download_files()
 class config:
@@ -210,7 +215,7 @@ class TestFleetUtil(unittest.TestCase):
 "pruned_main_program.save_var_shape_not_match"))
 conf.is_text_pruned_program = True
 conf.draw = False
-res = fleet_util._proto_check(conf)
+res = fleet.util._proto_check(conf)
 self.assertFalse(res)
 # test match
@@ -222,10 +227,11 @@ class TestFleetUtil(unittest.TestCase):
 else:
 conf.draw = True
 conf.draw_out_name = "pruned_check"
-res = fleet_util._proto_check(conf)
+res = fleet.util._proto_check(conf)
 self.assertTrue(res)
 def test_visualize(self):
+import paddle.distributed.fleet as fleet
 if sys.platform == 'win32' or sys.platform == 'sys.platform':
 pass
 else:
@@ -234,10 +240,10 @@ class TestFleetUtil(unittest.TestCase):
 data_dir,
 os.path.join(self.train_dir, "join_main_program.pbtxt"))
 is_text = True
-program = fleet_util._load_program(program_path, is_text)
+program = fleet.util._load_program(program_path, is_text)
 output_dir = os.path.join(data_dir, self.train_dir)
 output_filename = "draw_prog"
-fleet_util._visualize_graphviz(program, output_dir, output_filename)
+fleet.util._visualize_graphviz(program, output_dir, output_filename)
 self.assertTrue(
 os.path.exists(
 os.path.join(output_dir, output_filename + ".dot")))
...