diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 9bcd79cd34f07cb38ea28e1068bb6045cb82d27a..cc7d60b148def57ebfe3e4adf0e092c5d2e72779 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -33,22 +33,27 @@ message DistributedStrategy { optional int32 localsgd_k_step = 7 [ default = 4 ]; optional bool dgc = 8 [ default = false ]; optional bool hierachical_allreduce = 9 [ default = false ]; - optional int32 nccl_comm_num = 10 [ default = 1 ]; - optional bool gradient_merge = 11 [ default = false ]; - optional int32 gradient_merge_k_step = 12 [ default = 1 ]; - optional bool sequential_execution = 13 [ default = false ]; - optional bool enable_backward_optimizer_op_deps = 14 [ default = true ]; - optional bool lars = 15 [ default = false ]; - optional bool lamb = 16 [ default = false ]; - optional bool fuse_elewise_add_act_ops = 17 [ default = false ]; - optional bool fuse_bn_act_ops = 18 [ default = false ]; - optional bool enable_auto_fusion = 19 [ default = false ]; - optional bool fuse_relu_depthwise_conv = 20 [ default = false ]; - optional bool enable_inplace = 21 [ default = false ]; - optional bool fuse_all_reduce_ops = 22 [ default = false ]; - optional int32 num_iteration_per_drop_scope = 23 [ default = 1 ]; - optional bool sync_batch_norm = 24 [ default = false ]; - optional bool fuse_all_optimizer_ops = 25 [ default = false ]; + optional int32 hierachical_allreduce_inter_ranks = 10 [ default = 1 ]; + optional int32 nccl_comm_num = 11 [ default = 1 ]; + optional bool gradient_merge = 12 [ default = false ]; + optional int32 gradient_merge_k_step = 13 [ default = 1 ]; + optional bool sequential_execution = 14 [ default = false ]; + optional bool enable_backward_optimizer_op_deps = 15 [ default = true ]; + optional bool lars = 16 [ default = false ]; + optional bool lamb = 17 [ default = false ]; + optional bool fuse_elewise_add_act_ops = 18 [ default = false ]; + optional bool fuse_bn_act_ops = 19 [ default = false ]; + optional bool enable_auto_fusion = 20 [ default = false ]; + optional bool fuse_relu_depthwise_conv = 21 [ default = false ]; + optional bool enable_inplace = 22 [ default = false ]; + optional bool fuse_all_reduce_ops = 23 [ default = false ]; + optional int32 num_iteration_per_drop_scope = 24 [ default = 1 ]; + optional bool sync_batch_norm = 25 [ default = false ]; + optional bool fuse_all_optimizer_ops = 26 [ default = false ]; + optional bool sync_nccl_allreduce = 27 [ default = true ]; + optional bool fuse_broadcast_ops = 28 [ default = true ]; + optional int32 num_threads = 29 [ default = 1 ]; + optional int32 num_iteration_per_run = 30 [ default = 1 ]; // pipeline training optional bool pipeline = 101 [ default = false ]; diff --git a/python/paddle/fleet/__init__.py b/python/paddle/fleet/__init__.py index a5a8d12ed440077714a59773e1c870848e9de229..b25c362ce9301c122d2e2b6915e444da6a90ceca 100644 --- a/python/paddle/fleet/__init__.py +++ b/python/paddle/fleet/__init__.py @@ -14,10 +14,29 @@ # TODO: define distributed api under this directory, from .base.distributed_strategy import DistributedStrategy -#from .base.role_maker import PaddleCloudRoleMaker, UserDefinedRoleMaker -#from .base.fleet_base import Fleet +from .base.fleet_base import Fleet +from .base.util_factory import UtilBase -#__all__ = [ -# "DistributedStrategy", "PaddleCloudRoleMaker", "UserDefinedRoleMaker" -#] -__all__ = ['DistributedStrategy'] +#from 
.base.role_maker import PaddleCloudRoleMaker
+
+__all__ = ["DistributedStrategy", "UtilBase"]
+
+fleet = Fleet()
+init = fleet.init
+is_first_worker = fleet.is_first_worker
+worker_index = fleet.worker_index
+worker_num = fleet.worker_num
+is_worker = fleet.is_worker
+worker_endpoints = fleet.worker_endpoints
+server_num = fleet.server_num
+server_index = fleet.server_index
+server_endpoints = fleet.server_endpoints
+is_server = fleet.is_server
+util = fleet.util
+barrier_worker = fleet.barrier_worker
+init_worker = fleet.init_worker
+init_server = fleet.init_server
+run_server = fleet.run_server
+stop_worker = fleet.stop_worker
+distributed_optimizer = fleet.distributed_optimizer
+minimize = fleet.minimize
diff --git a/python/paddle/fleet/base/distributed_strategy.py b/python/paddle/fleet/base/distributed_strategy.py
index 0ebaff3a0f70c734b97b1da509fdaa0b080c5e3f..fdc5b22ae4c62d96564ad6e56d3e2b4822971e19 100644
--- a/python/paddle/fleet/base/distributed_strategy.py
+++ b/python/paddle/fleet/base/distributed_strategy.py
@@ -14,6 +14,7 @@
 
 from paddle.fleet.proto import distributed_strategy_pb2
 from paddle.fluid.framework import Variable
+import google.protobuf.text_format
 
 
 class DistributedJobInfo(object):
@@ -57,6 +58,15 @@ class DistributedStrategy(object):
     def __init__(self):
         self.strategy = distributed_strategy_pb2.DistributedStrategy()
 
+    def save_to_prototxt(self, output):
+        with open(output, "w") as fout:
+            fout.write(str(self.strategy))
+
+    def load_from_prototxt(self, pb_file):
+        with open(pb_file, 'r') as f:
+            self.strategy = google.protobuf.text_format.Merge(
+                str(f.read()), self.strategy)
+
     @property
     def amp(self):
         return self.strategy.amp
@@ -189,6 +199,19 @@ class DistributedStrategy(object):
             print(
                 "WARNING: hierachical_allreduce should have value of bool type")
 
+    @property
+    def hierachical_allreduce_inter_ranks(self):
+        return self.strategy.hierachical_allreduce_inter_ranks
+
+    @hierachical_allreduce_inter_ranks.setter
+    def hierachical_allreduce_inter_ranks(self, value):
+        if isinstance(value, int):
+            self.strategy.hierachical_allreduce_inter_ranks = value
+        else:
+            print(
+                "WARNING: hierachical_allreduce_inter_ranks should have value of int type"
+            )
+
     @property
     def nccl_comm_num(self):
         return self.strategy.nccl_comm_num
@@ -235,6 +258,17 @@ class DistributedStrategy(object):
             print(
                 "WARNING: sequential_execution should have value of bool type")
 
+    @property
+    def sync_nccl_allreduce(self):
+        return self.strategy.sync_nccl_allreduce
+
+    @sync_nccl_allreduce.setter
+    def sync_nccl_allreduce(self, flag):
+        if isinstance(flag, bool):
+            self.strategy.sync_nccl_allreduce = flag
+        else:
+            print("WARNING: sync_nccl_allreduce should have value of bool type")
+
     @property
     def lars(self):
         return self.strategy.lars
@@ -305,6 +339,17 @@ class DistributedStrategy(object):
             "WARNING: fuse_relu_depthwise_conv should have value of bool type"
         )
 
+    @property
+    def fuse_broadcast_ops(self):
+        return self.strategy.fuse_broadcast_ops
+
+    @fuse_broadcast_ops.setter
+    def fuse_broadcast_ops(self, flag):
+        if isinstance(flag, bool):
+            self.strategy.fuse_broadcast_ops = flag
+        else:
+            print("WARNING: fuse_broadcast_ops should have value of bool type")
+
     @property
     def enable_inplace(self):
         return self.strategy.enable_inplace
@@ -340,6 +385,18 @@ class DistributedStrategy(object):
                 "WARNING: num_iteration_per_drop_scope should have value of int type"
             )
 
+    @property
+    def num_iteration_per_run(self):
+        return self.strategy.num_iteration_per_run
+
+    @num_iteration_per_run.setter
+    def num_iteration_per_run(self,
value): + if isinstance(value, int): + self.strategy.num_iteration_per_run = value + else: + print( + "WARNING: num_iteration_per_run should have value of int type") + @property def sync_batch_norm(self): return self.strategy.sync_batch_norm @@ -499,6 +556,17 @@ class DistributedStrategy(object): else: print("WARNING: elastic should have value of bool type") + @property + def num_threads(self): + return self.strategy.num_threads + + @num_threads.setter + def num_threads(self, value): + if isinstance(value, int): + self.strategy.num_threads = value + else: + print("WARNING: num_threads should have value of int type") + @property def auto(self): return self.strategy.auto diff --git a/python/paddle/fleet/base/fleet_base.py b/python/paddle/fleet/base/fleet_base.py index 881044006479e074283c645c5247efa08c3b37b9..46d06e5d026cae0cb2fcc78360b2b3d0faf0acd9 100644 --- a/python/paddle/fleet/base/fleet_base.py +++ b/python/paddle/fleet/base/fleet_base.py @@ -13,7 +13,330 @@ # limitations under the License. from __future__ import print_function -from paddle.fleet import RoleMakerBase -from . import obj_creator +import paddle +from .strategy_compiler import StrategyCompiler +from .meta_optimizer_factory import MetaOptimizerFactory +from .runtime_factory import RuntimeFactory +from .util_factory import UtilFactory -# __all__ = ['Fleet'] +__all__ = ['Fleet'] + + +class Fleet(object): + """ + Unified API for distributed training of PaddlePaddle + Please reference the https://github.com/PaddlePaddle/Fleet for details + + + Returns: + Fleet: A Fleet instance + + Examples: + .. code-block:: python + + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + strategy = fleet.DistributedStrategy() + optimizer = paddle.optimizer.SGD(learning_rate=0.001) + optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + if fleet.is_first_worker(): + print("this is first worker") + print("current node index: {}".format(fleet.worker_index())) + print("total number of worker num: {}".format(fleet.worker_num())) + if fleet.is_worker(): + print("this is worker") + print("worker endpoints: {}".format(fleet.worker_endpoints(to_string=True))) + print("server num: {}".format(fleet.server_num())) + print("server endpoints: {}".format(fleet.server_endpoints(to_string=True))) + if fleet.is_server(): + print("this is server") + fleet.stop_worker() + """ + + def __init__(self): + self._runtime_handle = None + self._util = None + + def init(self, role_maker): + self._role_maker = role_maker + self.strategy_compiler = StrategyCompiler() + + def is_first_worker(self): + """ + Check whether the node is the first instance of worker. + + Returns: + bool: True if this is the first node of worker, + False if not. + + """ + return self._role_maker.is_first_worker() + + def worker_index(self): + """ + Get current worker index. + + Returns: + int: node id + """ + return self._role_maker.worker_index() + + def worker_num(self): + """ + Get current total worker number. + + Returns: + int: worker numbers + """ + return self._role_maker.worker_num() + + def is_worker(self): + """ + Check whether the node is an instance of worker. + + Returns: + bool: True if this is a node of worker, + False if not. + """ + return self._role_maker.is_worker() + + def worker_endpoints(self, to_string=False): + """ + Get current server endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"]. 
+ + Returns: + list/string: server endpoints + """ + ''' + if to_string: + return ",".join(self._role_maker.get_trainer_endpoints()) + else: + return self._role_maker.get_trainer_endpoints() + ''' + return ["127.0.0.1:1001", "127.0.0.1:1002"] + + def server_num(self): + """ + Get current total worker number. + + Returns: + int: server number + """ + return len(self._role_maker.get_pserver_endpoints()) + + def server_index(self): + """ + Get current server index. + + Returns: + int: node id + """ + return self._role_maker.server_index() + + def server_endpoints(self, to_string=False): + """ + Get current server endpoints, such as ["127.0.0.1:1001", "127.0.0.1:1002"]. + + Returns: + list/string: server endpoints + """ + ''' + if to_string: + return ",".join(self._role_maker.get_pserver_endpoints()) + else: + return self._role_maker.get_pserver_endpoints() + ''' + return ["127.0.0.1:1001", "127.0.0.1:1002"] + + def is_server(self): + """ + Check whether the node is an instance of server. + + Returns: + bool: True if this is a node of server, + False if not. + """ + return self._role_maker.is_server() + + @property + def util(self): + """ + Utility functions that can be used under certain runtime + return util + """ + return self._util + + @util.setter + def util(self, util): + """ + Set Utility functions for userd-defined runtime + set util + """ + self._util = util + + def barrier_worker(self): + """ + barrier between workers + """ + self._role_maker.barrier_worker() + + def init_worker(self): + """ + init worker + """ + assert self._runtime_handle is not None + self._runtime_handle._init_worker() + + def init_server(self, model_dir=None): + """ + init server + """ + assert self._runtime_handle is not None + self._runtime_handle._init_server() + + def run_server(self): + """ + run server + """ + assert self._runtime_handle is not None + self._runtime_handle._run_server() + + def stop_worker(self): + """ + stop worker + """ + assert self._runtime_handle is not None + self._runtime_handle._stop_worker() + + def distributed_optimizer(self, optimizer, strategy): + """ + distirbuted_optimizer + Returns: + Fleet instance with minimize interface like optimizers + + Examples: + .. code-block:: python + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + strategy = fleet.DistributedStrategy() + optimizer = paddle.optimizer.SGD(learning_rate=0.001) + optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + """ + self.user_defined_optimizer = optimizer + self.user_defined_strategy = strategy + return self + + def minimize(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None): + """ + Add distributed operations to minimize ``loss`` by updating ``parameter_list``. + + Args: + loss (Variable): A ``Variable`` containing the value to minimize. + startup_program (Program, optional): :ref:`api_fluid_Program` for + initializing parameters in ``parameter_list``. The default value + is None, at this time :ref:`api_fluid_default_startup_program` will be used. + parameter_list (Iterable, optional): Iterable of ``Variable`` or ``Variable.name`` to update + to minimize ``loss``. The default value is None, at this time all parameters + will be updated. + no_grad_set (set, optional): Set of ``Variable`` or ``Variable.name`` that don't need + to be updated. The default value is None. 
+ + Returns: + tuple: tuple (optimize_ops, params_grads), A list of operators appended + by minimize and a list of (param, grad) variable pairs, param is + ``Parameter``, grad is the gradient value corresponding to the parameter. + The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to + indicate program pruning. If so, the program will be pruned by ``feed`` and + ``fetch_list`` before run, see details in ``Executor``. + + Examples: + import paddle + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + + fc_1 = paddle.layers.fc(input=input_x, size=hid_dim, act='tanh') + fc_2 = paddlen.layers.fc(input=fc_1, size=hid_dim, act='tanh') + prediction = paddle.layers.fc(input=[fc_2], size=label_dim, act='softmax') + cost = paddle.layers.cross_entropy(input=prediction, label=input_y) + avg_cost = paddle.layers.mean(x=cost) + + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + strategy = fleet.DistributedStrategy() + optimizer = paddle.optimizer.SGD(learning_rate=0.001) + optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + optimizer.minimize(avg_cost) + + # for more examples, please reference https://github.com/PaddlePaddle/Fleet + + """ + # cache original feed forward program + self.origin_main_program = loss.block.program + if startup_program == None: + self.origin_startup_program = \ + paddle.default_startup_program().clone(for_test=False) + startup_program = paddle.default_startup_program() + else: + self.origin_startup_program = \ + startup_program.clone(for_test=False) + + # compile time + distributed_optimizer_list = \ + MetaOptimizerFactory()._get_valid_meta_optimizers( + self.user_defined_optimizer) + valid_optimizer_list = [] + valid_graph_optimizer_list = [] + # recall meta optimizers for ranking + for opt in distributed_optimizer_list: + opt._set_basic_info(loss, self._role_maker, + self.user_defined_optimizer, + self.user_defined_strategy) + if opt._can_apply() and not opt._is_graph_out(): + valid_optimizer_list.append(opt) + if opt._can_apply() and opt._is_graph_out(): + valid_graph_optimizer_list.append(opt) + # combine recalled meta optimizers to be a valid meta optimizer + meta_optimizer, graph_optimizer, final_dist_strategy = \ + self.strategy_compiler.generate_optimizer( + loss, self._role_maker, self.user_defined_optimizer, + self.user_defined_strategy, valid_optimizer_list, + valid_graph_optimizer_list) + optimize_ops = [] + params_grads = [] + if meta_optimizer: + optimize_ops, params_grads = meta_optimizer.minimize( + loss, + startup_program=startup_program, + parameter_list=parameter_list, + no_grad_set=no_grad_set) + + if graph_optimizer: + optimizer_ops, params_grads = graph_optimizer.minimize( + loss, + startup_program=startup_program, + parameter_list=parameter_list, + no_grad_set=no_grad_set) + # since we do not encourage users to use graph operations + # if a graph optimizer takes effect, mostly + # optimizers_ops and params_grads are None + # i.e. 
users can not modify current computation graph anymore + + if self._runtime_handle is None: + self._runtime_handle = RuntimeFactory()._create_runtime( + final_dist_strategy, self._role_maker, optimize_ops, + params_grads) + + if self._util is None: + self._util = UtilFactory()._create_util(final_dist_strategy, + self._role_maker, + optimize_ops, params_grads) + + return optimize_ops, params_grads diff --git a/python/paddle/fleet/base/obj_creator.py b/python/paddle/fleet/base/meta_optimizer_factory.py similarity index 54% rename from python/paddle/fleet/base/obj_creator.py rename to python/paddle/fleet/base/meta_optimizer_factory.py index 15a403d79edcf7210863b624074827494684c38a..8d42c2a0c89ef629449cdd61f57261188a6499ca 100644 --- a/python/paddle/fleet/base/obj_creator.py +++ b/python/paddle/fleet/base/meta_optimizer_factory.py @@ -12,12 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from util_base import UtilBase +from ..meta_optimizers import RecomputeOptimizer +from ..meta_optimizers import GraphExecutionOptimizer +__all__ = ["MetaOptimizerFactory"] -def _create_fleet_obj_from_role_maker(role_maker): - pass +meta_optimizer_names = ["RecomputeOptimizer", "GraphExecutionOptimizer"] -def _create_fleet_util_from_role_maker(role_maker): - pass +class MetaOptimizerFactory(object): + def __init__(self): + pass + + def _get_valid_meta_optimizers(self, user_defined_optimizer): + opt_list = [] + for opt_name in meta_optimizer_names: + opt_list.append(globals()[opt_name](user_defined_optimizer)) + return opt_list diff --git a/python/paddle/fleet/base/private_helper_function.py b/python/paddle/fleet/base/private_helper_function.py new file mode 100644 index 0000000000000000000000000000000000000000..6b3232b93b22416982d86d80db4530627bb2493a --- /dev/null +++ b/python/paddle/fleet/base/private_helper_function.py @@ -0,0 +1,55 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import sys +import time +import socket +from contextlib import closing +from six import string_types + + +def wait_server_ready(endpoints): + """ + Wait until parameter servers are ready, use connext_ex to detect + port readiness. + + Args: + endpoints (list): endpoints string list, like: + ["127.0.0.1:8080", "127.0.0.1:8081"] + + Examples: + .. 
code-block:: python + + wait_server_ready(["127.0.0.1:8080", "127.0.0.1:8081"]) + """ + assert not isinstance(endpoints, str) + while True: + all_ok = True + not_ready_endpoints = [] + for ep in endpoints: + ip_port = ep.split(":") + with closing(socket.socket(socket.AF_INET, + socket.SOCK_STREAM)) as sock: + sock.settimeout(2) + result = sock.connect_ex((ip_port[0], int(ip_port[1]))) + if result != 0: + all_ok = False + not_ready_endpoints.append(ep) + if not all_ok: + sys.stderr.write("server not ready, wait 3 sec to retry...\n") + sys.stderr.write("not ready endpoints:" + str(not_ready_endpoints) + + "\n") + sys.stderr.flush() + time.sleep(3) + else: + break diff --git a/python/paddle/fleet/base/runtime_factory.py b/python/paddle/fleet/base/runtime_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..c4d42db4ea993d9241222d42595e2c0d6af0a2d7 --- /dev/null +++ b/python/paddle/fleet/base/runtime_factory.py @@ -0,0 +1,27 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from ..runtime.collective_runtime import CollectiveRuntime + + +class RuntimeFactory(object): + def __init__(self): + pass + + def _create_runtime(self, final_dist_strategy, role_maker, opt_ops, + params_grads): + if role_maker._is_collective: + collective_runtime = CollectiveRuntime() + collective_runtime._set_basic_info(final_dist_strategy, role_maker, + opt_ops, params_grads) + return collective_runtime diff --git a/python/paddle/fleet/base/strategy_compiler.py b/python/paddle/fleet/base/strategy_compiler.py new file mode 100644 index 0000000000000000000000000000000000000000..92b50781f65ba928a47f5d4c0ecda2d739739c56 --- /dev/null +++ b/python/paddle/fleet/base/strategy_compiler.py @@ -0,0 +1,69 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+
+def maximum_path_len_algo(optimizer_list):
+    max_idx = 0
+    max_len = 0
+    candidates = []
+    for idx, opt in enumerate(optimizer_list):
+        local_buffer = [opt]
+        for opt_inner in optimizer_list:
+            if opt._can_update(opt_inner):
+                local_buffer.append(opt_inner)
+        if len(local_buffer) > max_len:
+            max_idx = idx
+            max_len = len(local_buffer)
+        candidates.append(local_buffer)
+    if len(candidates) == 0:
+        return None
+    for idx, opt in enumerate(candidates[max_idx][:-1]):
+        opt._update_inner_optimizer(candidates[max_idx][idx + 1])
+    return candidates[max_idx][0]
+
+
+class StrategyCompilerBase(object):
+    def __init__(self):
+        pass
+
+
+class StrategyCompiler(StrategyCompilerBase):
+    """
+    StrategyCompiler is responsible for combining meta optimizers.
+    Generally, a user can define several distributed strategies, each
+    of which can generate a meta optimizer. These meta optimizers must
+    be applied in the right order through their minimize functions, so
+    this class is responsible for generating the final, executable
+    distributed optimizer from the selected combination of meta
+    optimizers.
+    """
+
+    def __init__(self):
+        super(StrategyCompiler, self).__init__()
+
+    def generate_optimizer(self, loss, role_maker, optimizer,
+                           user_defined_strategy, meta_optimizer_list,
+                           graph_optimizer_list):
+        if len(meta_optimizer_list) == 0 and len(graph_optimizer_list) == 0:
+            return optimizer, None, None
+        else:
+            # currently, we use heuristic algorithm to select
+            # meta optimizers combinations
+            meta_optimizer = maximum_path_len_algo(meta_optimizer_list)
+            graph_optimizer = maximum_path_len_algo(graph_optimizer_list)
+            # should design a distributed strategy update interface
+            # when we have finally decided the combination of meta_optimizer
+            # and graph_optimizer, the corresponding distributed strategy
+            # should be updated.
+ return meta_optimizer, graph_optimizer, None diff --git a/python/paddle/fleet/base/util_base.py b/python/paddle/fleet/base/util_factory.py similarity index 71% rename from python/paddle/fleet/base/util_base.py rename to python/paddle/fleet/base/util_factory.py index 7654d0bcd9cd657ab79e9acf74b8fdfb72c489de..74029f43d10c86dadb052000884fa9df7a667f72 100644 --- a/python/paddle/fleet/base/util_base.py +++ b/python/paddle/fleet/base/util_factory.py @@ -16,13 +16,30 @@ """basic collective operations in python""" """remote file system""" -# __all__ = ['UtilBase'] -''' +__all__ = ['UtilBase'] + + +class UtilFactory(object): + def _create_util(self, dist_strategy, role_maker, optimize_ops, + params_grads): + util = UtilBase() + util._set_strategy(dist_strategy) + util._set_role_maker(role_maker) + return util + + class UtilBase(object): - def __init__(self, role_maker, fleet_obj): - self.role_maker = roke_maker - self.fleet_obj = fleet_obj + def __init__(self): + self.role_maker = None + self.dist_strategy = None + + def _set_strategy(self, dist_strategy): + self.dist_strategy = dist_strategy + + def _set_role_maker(self, role_maker): + self.role_maker = role_maker + ''' def set_file_system(self, fs_client): self.fs_client = fs_client @@ -61,4 +78,4 @@ class UtilBase(object): def print_on_rank(self): pass -''' + ''' diff --git a/python/paddle/fleet/collective/__init__.py b/python/paddle/fleet/meta_optimizers/__init__.py similarity index 79% rename from python/paddle/fleet/collective/__init__.py rename to python/paddle/fleet/meta_optimizers/__init__.py index 8647330f3290f3142cabca9a7e3fe162a9838dda..8a87a31e903894fbfe476d86a03b899ac150d360 100644 --- a/python/paddle/fleet/collective/__init__.py +++ b/python/paddle/fleet/meta_optimizers/__init__.py @@ -10,3 +10,8 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and + +from .recompute_optimizer import RecomputeOptimizer +from .graph_execution_optimizer import GraphExecutionOptimizer + +__all__ = ['RecomputeOptimizer'] diff --git a/python/paddle/fleet/meta_optimizers/graph_execution_optimizer.py b/python/paddle/fleet/meta_optimizers/graph_execution_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..13d62a6d462a2b76a609739a63785922221bc3b0 --- /dev/null +++ b/python/paddle/fleet/meta_optimizers/graph_execution_optimizer.py @@ -0,0 +1,194 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and + +import paddle +from paddle.fluid.framework import core +from paddle.fluid import compiler +from .meta_optimizer_base import MetaOptimizerBase +from ..base.private_helper_function import wait_server_ready + + +def get_build_strategy(dist_strategy): + build_strategy = paddle.BuildStrategy() + build_strategy.enable_sequential_execution = \ + dist_strategy.sequential_execution + build_strategy.remove_unnecessary_lock = True + build_strategy.fuse_elewise_add_act_ops = \ + dist_strategy.fuse_elewise_add_act_ops + build_strategy.fuse_bn_act_ops = \ + dist_strategy.fuse_bn_act_ops + build_strategy.enable_auto_fusion = \ + dist_strategy.enable_auto_fusion + build_strategy.fuse_relu_depthwise_conv = \ + dist_strategy.fuse_relu_depthwise_conv + build_strategy.fuse_broadcast_ops = \ + dist_strategy.fuse_broadcast_ops + build_strategy.sync_batch_norm = \ + dist_strategy.sync_batch_norm + return build_strategy + + +def get_execution_strategy(dist_strategy): + execution_strategy = paddle.ExecutionStrategy() + execution_strategy.num_threads = \ + dist_strategy.num_threads + execution_strategy.num_iteration_per_drop_scope = \ + dist_strategy.num_iteration_per_drop_scope + execution_strategy.num_iteration_per_run = \ + dist_strategy.num_iteration_per_run + execution_strategy.use_thread_barrier = \ + dist_strategy.use_thread_barrier + return execution_strategy + + +class GraphExecutionOptimizer(MetaOptimizerBase): + def __init__(self, optimizer): + super(GraphExecutionOptimizer, self).__init__(optimizer) + self.inner_opt = optimizer + # we do not allow meta optimizer to be inner optimizer currently + self.meta_optimizers_white_list = [] + + def _is_graph_out(self): + return True + + def _can_apply(self): + """ + Basically, this is PE, and almost all programs can be executed here + """ + return True + + def backward(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None, + callbacks=None): + pass + + # should fix the variable + def _setup_nccl_op(self, startup_program, main_program): + trainer_endpoints = self.role_maker.get_trainer_endpoints() + trainers = trainer_endpoints + trainer_id = self.role_maker.worker_index() + current_endpoint = self.role_maker.get_trainer_endpoints()[trainer_id] + trainer_endpoints_env = ",".join(trainer_endpoints) + trainers_num = self.role_maker.worker_num() + trainer_endpoints.remove(current_endpoint) + if trainer_id == 0: + wait_server_ready(trainer_endpoints) + nccl_id_var = startup_program.global_block().create_var( + name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW) + for i in range(1, self.user_defined_strategy.nccl_comm_num): + startup_program.global_block().create_var( + name="NCCLID_{}".format(i), + persistable=True, + type=core.VarDesc.VarType.RAW) + + if self.user_defined_strategy.hierachical_allreduce: + for i in range(0, self.user_defined_strategy.nccl_comm_num): + startup_program.global_block().create_var( + name="Hierarchical_inter_NCCLID_{}".format(i), + persistable=True, + type=core.VarDesc.VarType.RAW) + startup_program.global_block().create_var( + name="Hierarchical_exter_NCCLID_{}".format(i), + persistable=True, + type=core.VarDesc.VarType.RAW) + + startup_program.global_block().append_op( + type="gen_nccl_id", + inputs={}, + outputs={"NCCLID": nccl_id_var}, + attrs={ + "trainers": trainers, + "trainer_id": trainer_id, + "nccl_comm_num": self.user_defined_strategy.nccl_comm_num, + "use_hierarchical_allreduce": + 
self.user_defined_strategy.hierachical_allreduce, + "hierarchical_allreduce_inter_ranks": + self.user_defined_strategy.hierachical_allreduce_inter_ranks + }) + + def _try_to_compile(self, startup_program, main_program, loss): + build_strategy = get_build_strategy(self.user_defined_strategy) + exe_strategy = get_execution_strategy(self.user_defined_strategy) + node_num = self.role_maker.worker_num() + if self.role_maker._is_collective: + assert node_num >= 1, "nccl2 node_num must >= 1, now:{}" % node_num + + if node_num <= 1: + # local mode + if self.user_defined_strategy.nccl_comm_num > 1: + logging.warn("set nccl_comm_num=1 since you only have 1 node.") + self.user_defined_strategy.nccl_comm_num = 1 + + if self.user_defined_strategy.hierachical_allreduce: + logging.warn( + "set hierachical_allreduce=False since you only have 1 node." + ) + self.user_defined_strategy.hierachical_allreduce = False + + sync_allreduce = self.user_defined_strategy.sync_nccl_allreduce + if sync_allreduce: + exe_strategy.num_threads = self.user_defined_strategy.nccl_comm_num + 1 + if self.user_defined_strategy.hierachical_allreduce: + exe_strategy.num_threads = 2 * self.user_defined_strategy.nccl_comm_num + 1 + if exe_strategy.num_threads > 4: + logging.warn( + "if you use hierachical_allreduce or " + "with multi nccl comm, please export FLAGS_sync_nccl_allreduce = 0" + ) + + # TODO(guru4elephant): should be an independent optimizer + sync_batch_norm = self.user_defined_strategy.sync_batch_norm + if sync_batch_norm: + self.user_defined_strategy.nccl_comm_num = 1 + self.user_defined_strategy.hierachical_allreduce = False + exe_strategy.num_threads = 1 + logging.warn( + "use sync_batch_norm will hang when set num_threads > 1, so " + "set num_threads=1, nccl_comm_num=1, hierachical_allreduce=False." + ) + + # TODO(guru4elephant): should be an independent optimizer + self._setup_nccl_op(startup_program, main_program) + + build_strategy.num_trainers = self.role_maker.worker_num() + build_strategy.trainer_id = self.role_maker.worker_index() + build_strategy.trainers_endpoints = self.role_maker.get_trainer_endpoints( + ) + build_strategy.enable_backward_optimizer_op_deps = True + + self._compiled_program = compiler.CompiledProgram(main_program) + + self._compiled_program.with_data_parallel( + loss_name=loss.name, + build_strategy=build_strategy, + exec_strategy=exe_strategy, + share_vars_from=None) + + return self._compiled_program + + def minimize(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None): + if startup_program == None: + startup_program = paddle.default_startup_program() + compiled_program = self._try_to_compile(startup_program, + loss.block.program, loss) + loss.block.program.graph = compiled_program + + # just return self.optimizer_ops and self.param_grads + return None, None diff --git a/python/paddle/fleet/meta_optimizers/meta_optimizer_base.py b/python/paddle/fleet/meta_optimizers/meta_optimizer_base.py new file mode 100644 index 0000000000000000000000000000000000000000..33b7b2bb1e85269d1d91bbc50f996f3e56a84435 --- /dev/null +++ b/python/paddle/fleet/meta_optimizers/meta_optimizer_base.py @@ -0,0 +1,56 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__all__ = ["MetaOptimizerBase"]
+
+
+class MetaOptimizerBase(object):
+    def __init__(self, optimizer):
+        pass
+
+    def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
+                        user_defined_strategy):
+        self.loss = loss
+        self.role_maker = role_maker
+        self.user_defined_optimizer = user_defined_optimizer
+        self.user_defined_strategy = user_defined_strategy
+
+    def _update_inner_optimizer(self, optimizer):
+        self.inner_opt = optimizer
+
+    def _can_apply(self):
+        return False
+
+    def _is_graph_out(self):
+        return False
+
+    def _can_update(self, optimizer):
+        if str(optimizer.__class__.__name__) in self.meta_optimizers_white_list:
+            return True
+
+    def minimize_impl(self,
+                      loss,
+                      startup_program=None,
+                      parameter_list=None,
+                      no_grad_set=None):
+        raise NotImplementedError("meta optimizer not implemented")
+
+    def minimize(self,
+                 loss,
+                 startup_program=None,
+                 parameter_list=None,
+                 no_grad_set=None):
+        optimize_ops, params_grads = self.minimize_impl(
+            loss, startup_program, parameter_list, no_grad_set)
+        return optimize_ops, params_grads
diff --git a/python/paddle/fleet/meta_optimizers/recompute_optimizer.py b/python/paddle/fleet/meta_optimizers/recompute_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..902b8367b34f65e95a3501d5586db056ef5a55f3
--- /dev/null
+++ b/python/paddle/fleet/meta_optimizers/recompute_optimizer.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# See the License for the specific language governing permissions and + +from paddle.fluid.optimizer import RecomputeOptimizer as RO +from .meta_optimizer_base import MetaOptimizerBase + +__all__ = ["RecomputeOptimizer"] + + +class RecomputeOptimizer(MetaOptimizerBase): + def __init__(self, optimizer): + super(RecomputeOptimizer, self).__init__(optimizer) + #self.inner_opt = RO(optimizer) + self.inner_opt = optimizer + self.wrapped_opt = RO(optimizer) + # we do not allow meta optimizer to be inner optimizer currently + self.meta_optimizers_white_list = [] + + def _set_basic_info(self, loss, role_maker, user_defined_optimizer, + user_defined_strategy): + super(RecomputeOptimizer, self)._set_basic_info( + loss, role_maker, user_defined_optimizer, user_defined_strategy) + self.wrapped_opt._set_checkpoints([]) + + def _can_apply(self): + if self.user_defined_strategy.recompute == True: + if len(self.user_defined_strategy.recompute_checkpoints) == 0: + return False + else: + return True + + def backward(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None, + callbacks=None): + return self.wrapped_opt.backward(loss, startup_program, parameter_list, + no_grad_set, callbacks) + + def minimize_impl(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None): + optimize_ops, params_grads = \ + self.wrapped_opt.minimize(loss, startup_program, + parameter_list, no_grad_set) + return optimize_ops, params_grads diff --git a/python/paddle/fleet/parameter_server/__init__.py b/python/paddle/fleet/runtime/__init__.py similarity index 78% rename from python/paddle/fleet/parameter_server/__init__.py rename to python/paddle/fleet/runtime/__init__.py index 847ddc47ac89114f2012bc6b9990a69abfe39fb3..f38287cf51a728011d16f735e58ec54a7cdfe0c8 100644 --- a/python/paddle/fleet/parameter_server/__init__.py +++ b/python/paddle/fleet/runtime/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .collective_runtime import CollectiveRuntime + +__all__ = ["CollectiveRuntime"] diff --git a/python/paddle/fleet/runtime/collective_runtime.py b/python/paddle/fleet/runtime/collective_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..0881c4b52c822908cedc94d3f4de088eed6c65e8 --- /dev/null +++ b/python/paddle/fleet/runtime/collective_runtime.py @@ -0,0 +1,48 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .runtime_base import RuntimeBase +import logging + + +class CollectiveRuntime(RuntimeBase): + def __init__(self): + super(CollectiveRuntime, self).__init__() + + def _init_worker(self): + logging.warn( + "You should not call 'init_worker' method for collective mode.") + pass + + def _run_worker(self): + logging.warn( + "You should not call 'run_worker' method for collective mode.") + pass + + def _init_server(self): + logging.warn( + "You should not call 'init_server' method for collective mode.") + pass + + def _run_server(self): + logging.warn( + "You should not call 'run_server' method for collective mode.") + pass + + def _stop_worker(self): + logging.warn( + "You should not call 'stop_worker' method for collective mode.") + pass + + # save inference model should be added here diff --git a/python/paddle/fleet/runtime/runtime_base.py b/python/paddle/fleet/runtime/runtime_base.py new file mode 100644 index 0000000000000000000000000000000000000000..5610a5305a464e39e9ab5a6bb7594e5e225a12ba --- /dev/null +++ b/python/paddle/fleet/runtime/runtime_base.py @@ -0,0 +1,38 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +__all__ = [] + + +class RuntimeBase(object): + def __init__(self): + pass + + def _set_basic_info(self, loss, role_maker, optimizer, strategy): + self.loss = loss + self.role_maker = role_maker + self.optimizer = optimizer + self.strategy = strategy + + def _run_worker(self): + pass + + def _init_server(self): + pass + + def _run_server(self): + pass + + def _stop_worker(self): + pass diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index c95577561f45158ce4de80753e8f3725cd8673e0..aae15850d1cf8479e3159125a389cd6d5c26eeb1 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -32,6 +32,9 @@ list(APPEND MIXED_DIST_TEST_OPS test_communicator_sync) list(APPEND MIXED_DIST_TEST_OPS test_fleet_api_input) list(APPEND MIXED_DIST_TEST_OPS test_fleet_checkpoint) list(APPEND MIXED_DIST_TEST_OPS test_collective_optimizer) +list(APPEND MIXED_DIST_TEST_OPS test_fleet_base) +list(APPEND MIXED_DIST_TEST_OPS test_fleet_meta_optimizer) +list(APPEND MIXED_DIST_TEST_OPS test_fleet_private_function) foreach(TEST_OP ${MIXED_DIST_TEST_OPS}) list(REMOVE_ITEM TEST_OPS ${TEST_OP}) endforeach() @@ -339,6 +342,11 @@ if(WITH_DISTRIBUTE) py_test_modules(test_communicator_half_async MODULES test_communicator_half_async ENVS ${dist_ENVS} FLAGS_communicator_send_queue_size=1 FLAGS_communicator_max_merge_var_num=1) py_test_modules(test_communicator_sync MODULES test_communicator_sync ENVS ${dist_ENVS} FLAGS_communicator_send_queue_size=1 FLAGS_communicator_max_merge_var_num=1) py_test_modules(test_collective_optimizer MODULES test_collective_optimizer) + if(NOT APPLE) + py_test_modules(test_fleet_base MODULES test_fleet_base ENVS ${dist_ENVS}) + py_test_modules(test_fleet_meta_optimizer MODULES 
test_fleet_meta_optimizer ENVS ${dist_ENVS}) + py_test_modules(test_fleet_private_function MODULES test_fleet_private_function ENVS ${dist_ENVS}) + endif(NOT APPLE) if(WITH_DGC) # if with dgc, test all dgc tests. # NOTE. dist dgc tests is already in DIST_TEST_OPS diff --git a/python/paddle/fluid/tests/unittests/test_fleet_base.py b/python/paddle/fluid/tests/unittests/test_fleet_base.py new file mode 100644 index 0000000000000000000000000000000000000000..20542da3f05ec84b51dee8a9c5913bb20630f4a2 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fleet_base.py @@ -0,0 +1,177 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import os + + +class TestFleetBase(unittest.TestCase): + def setUp(self): + os.environ["POD_IP"] = "127.0.0.1" + os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001" + os.environ["PADDLE_TRAINERS_NUM"] = "2" + os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \ + "127.0.0.1:36001,127.0.0.2:36001" + + def test_init(self): + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + + def test_is_first_worker(self): + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + if fleet.is_first_worker(): + print("test fleet first worker done.") + + def test_worker_index(self): + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + print(fleet.worker_index()) + + def test_worker_num(self): + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + print(fleet.worker_num()) + + def test_is_worker(self): + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + if fleet.is_worker(): + print("test fleet is worker") + + def test_worker_endpoints(self): + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + print(fleet.worker_endpoints(to_string=True)) + + def test_server_num(self): + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + if fleet.is_server(): + print("fleet server num: {}".format(fleet.server_num())) + + def test_server_index(self): + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + if 
fleet.is_server(): + print("fleet server index: {}".format(fleet.server_index())) + + def test_server_endpoints(self): + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + if fleet.is_server(): + print("fleet server index: {}".format( + fleet.server_endpoints(to_string=True))) + + def test_is_server(self): + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + if fleet.is_server(): + print("test fleet is server") + + def test_util(self): + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + self.assertEqual(fleet.util, None) + + def test_barrier_worker(self): + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + if fleet.is_worker(): + fleet.barrier_worker() + + def test_init_worker(self): + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + if fleet.is_worker(): + fleet.init_worker() + + def test_run_server(self): + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + if fleet.is_worker(): + fleet.run_worker() + + def test_stop_worker(self): + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + if fleet.is_worker(): + fleet.stop_worker() + + def test_distributed_optimizer(self): + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + strategy = fleet.DistributedStrategy() + optimizer = paddle.optimizer.SGD(learning_rate=0.001) + optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + + def test_minimize(self): + import paddle + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + + input_x = paddle.fluid.layers.data( + name="x", shape=[32], dtype='float32') + input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') + + fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh') + fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh') + prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax') + cost = paddle.fluid.layers.cross_entropy( + input=prediction, label=input_y) + avg_cost = paddle.fluid.layers.mean(x=cost) + + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + strategy = fleet.DistributedStrategy() + optimizer = paddle.optimizer.SGD(learning_rate=0.001) + optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + optimizer.minimize(avg_cost) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py index 0668546a703bc00369d55e12d4b03c934c9315c2..bac03176c8da1d46485c93af0fc824f2bcda97b3 100644 --- 
a/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_distributed_strategy.py @@ -109,6 +109,13 @@ class TestStrategyConfig(unittest.TestCase): strategy.hierachical_allreduce = "True" self.assertEqual(strategy.hierachical_allreduce, False) + def test_hierachical_allreduce_inter_ranks(self): + strategy = paddle.fleet.DistributedStrategy() + strategy.hierachical_allreduce_inter_ranks = 1 + self.assertEqual(strategy.hierachical_allreduce_inter_ranks, 1) + strategy.hierachical_allreduce_inter_ranks = "2" + self.assertEqual(strategy.hierachical_allreduce_inter_ranks, 1) + def test_nccl_comm_num(self): strategy = paddle.fleet.DistributedStrategy() strategy.nccl_comm_num = 1 @@ -220,6 +227,13 @@ class TestStrategyConfig(unittest.TestCase): strategy.num_iteration_per_drop_scope = 0.1 self.assertEqual(strategy.num_iteration_per_drop_scope, 1) + def test_num_iteration_per_run(self): + strategy = paddle.fleet.DistributedStrategy() + strategy.num_iteration_per_run = 1 + self.assertEqual(strategy.num_iteration_per_run, 1) + strategy.num_iteration_per_run = 0.1 + self.assertEqual(strategy.num_iteration_per_run, 1) + def test_sync_batch_norm(self): strategy = paddle.fleet.DistributedStrategy() strategy.sync_batch_norm = True @@ -336,6 +350,40 @@ class TestStrategyConfig(unittest.TestCase): strategy.auto = "True" self.assertEqual(strategy.auto, False) + def test_sync_nccl_allreduce(self): + strategy = paddle.fleet.DistributedStrategy() + strategy.sync_nccl_allreduce = True + self.assertEqual(strategy.sync_nccl_allreduce, True) + strategy.sync_nccl_allreduce = False + self.assertEqual(strategy.sync_nccl_allreduce, False) + strategy.sync_nccl_allreduce = "True" + self.assertEqual(strategy.sync_nccl_allreduce, False) + + def test_fuse_broadcast_ops(self): + strategy = paddle.fleet.DistributedStrategy() + strategy.fuse_broadcast_ops = True + self.assertEqual(strategy.fuse_broadcast_ops, True) + strategy.fuse_broadcast_ops = False + self.assertEqual(strategy.fuse_broadcast_ops, False) + strategy.fuse_broadcast_ops = "True" + self.assertEqual(strategy.fuse_broadcast_ops, False) + + def test_num_threads(self): + strategy = paddle.fleet.DistributedStrategy() + strategy.num_threads = 1 + self.assertEqual(strategy.num_threads, 1) + strategy.num_threads = 0.1 + self.assertEqual(strategy.num_threads, 1) + + def test_strategy_prototxt(self): + strategy = paddle.fleet.DistributedStrategy() + strategy.sync_nccl_allreduce = True + strategy.save_to_prototxt("dist_strategy.prototxt") + strategy2 = paddle.fleet.DistributedStrategy() + strategy2.load_from_prototxt("dist_strategy.prototxt") + self.assertEqual(strategy.sync_nccl_allreduce, + strategy2.sync_nccl_allreduce) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_meta_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..9cb300f83d9c4ced68f69f06170a32f82f039d8c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fleet_meta_optimizer.py @@ -0,0 +1,76 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import os + + +class TestFleetMetaOptimizer(unittest.TestCase): + def setUp(self): + os.environ["POD_IP"] = "127.0.0.1" + os.environ["PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001" + os.environ["PADDLE_TRAINERS_NUM"] = "2" + os.environ["PADDLE_PSERVERS_IP_PORT_LIST"] = \ + "127.0.0.1:36001,127.0.0.2:36001" + + def test_graph_execution_optimizer(self): + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + input_x = paddle.fluid.layers.data( + name="x", shape=[32], dtype='float32') + input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') + + fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh') + fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh') + prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax') + cost = paddle.fluid.layers.cross_entropy( + input=prediction, label=input_y) + avg_cost = paddle.fluid.layers.mean(x=cost) + + strategy = paddle.fleet.DistributedStrategy() + + optimizer = paddle.optimizer.SGD(learning_rate=0.01) + optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + optimizer.minimize(avg_cost) + + def test_recompute_optimizer(self): + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + input_x = paddle.fluid.layers.data( + name="x", shape=[32], dtype='float32') + input_y = paddle.fluid.layers.data(name="y", shape=[1], dtype='int64') + + fc_1 = paddle.fluid.layers.fc(input=input_x, size=64, act='tanh') + fc_2 = paddle.fluid.layers.fc(input=fc_1, size=64, act='tanh') + prediction = paddle.fluid.layers.fc(input=[fc_2], size=2, act='softmax') + cost = paddle.fluid.layers.cross_entropy( + input=prediction, label=input_y) + avg_cost = paddle.fluid.layers.mean(x=cost) + + strategy = paddle.fleet.DistributedStrategy() + strategy.recompute = True + strategy.recompute_checkpoints = [fc_2] + + optimizer = paddle.optimizer.SGD(learning_rate=0.01) + optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + optimizer.minimize(avg_cost) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_private_function.py b/python/paddle/fluid/tests/unittests/test_fleet_private_function.py new file mode 100644 index 0000000000000000000000000000000000000000..ec99acf109816570db48d9f15bbbdd897133006a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fleet_private_function.py @@ -0,0 +1,47 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import os +import paddle +import socket +import threading + + +class TestFleetPrivateFunction(unittest.TestCase): + def test_wait_port(self): + def init_server(port): + import time + time.sleep(5) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(("127.0.0.1", port)) + sock.listen(10) + while True: + c, addr = sock.accept() + c.send("0") + c.close() + break + + thr = threading.Thread(target=init_server, args=(9292, )) + thr.start() + + import paddle.fleet as fleet + ep = ["127.0.0.1:9292"] + fleet.base.private_helper_function.wait_server_ready(ep) + + thr.join() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_runtime.py b/python/paddle/fluid/tests/unittests/test_fleet_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..474e5da1c219c4b6e5a35a59ee235fdcbdb34cce --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fleet_runtime.py @@ -0,0 +1,40 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import os + + +class TestFleetRuntime(unittest.TestCase): + def test_fleet_runtime_base(self): + import paddle.fleet.runtime + base = paddle.fleet.runtime.runtime_base.RuntimeBase() + base._run_worker() + base._init_server() + base._run_server() + base._stop_worker() + + def test_fleet_collective_runtime(self): + import paddle.fleet.runtime + collective_runtime = paddle.fleet.runtime.CollectiveRuntime() + collective_runtime._init_worker() + collective_runtime._run_worker() + collective_runtime._init_worker() + collective_runtime._run_server() + collective_runtime._stop_worker() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_util.py b/python/paddle/fluid/tests/unittests/test_fleet_util.py new file mode 100644 index 0000000000000000000000000000000000000000..4825035d123df1767fe7845b2515f7d42253446c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fleet_util.py @@ -0,0 +1,68 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import paddle +import os + + +class TestFleetUtil(unittest.TestCase): + def test_util_base(self): + import paddle.fleet as fleet + util = fleet.UtilBase() + strategy = fleet.DistributedStrategy() + util._set_strategy(strategy) + role_maker = None # should be fleet.PaddleCloudRoleMaker() + util._set_role_maker(role_maker) + + def test_util_factory(self): + import paddle.fleet as fleet + factory = fleet.base.util_factory.UtilFactory() + strategy = fleet.DistributedStrategy() + role_maker = None # should be fleet.PaddleCloudRoleMaker() + optimize_ops = [] + params_grads = [] + util = factory._create_util(strategy, role_maker, optimize_ops, + params_grads) + self.assertEqual(util.role_maker, None) + + def test_get_util(self): + import paddle.fleet as fleet + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + default_util = fleet.util + self.assertEqual(default_util, None) + + def test_set_user_defined_util(self): + import paddle.fleet as fleet + + class UserDefinedUtil(fleet.UtilBase): + def __init__(self): + super(UserDefinedUtil, self).__init__() + + def get_user_id(self): + return 10 + + import paddle.fluid.incubate.fleet.base.role_maker as role_maker + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + my_util = UserDefinedUtil() + fleet.util = my_util + user_id = fleet.util.get_user_id() + self.assertEqual(user_id, 10) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/setup.py.in b/python/setup.py.in index ba61499d254f4850a89ec04a3cdcef0f8d5cb9d9..f3dc1035fc14df2297752eae9c49a4a21f7579a4 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -145,10 +145,10 @@ packages=['paddle', 'paddle.incubate.complex.tensor', 'paddle.fleet', 'paddle.fleet.base', - 'paddle.fleet.collective', + 'paddle.fleet.meta_optimizers', + 'paddle.fleet.runtime', 'paddle.fleet.dataset', 'paddle.fleet.metrics', - 'paddle.fleet.parameter_server', 'paddle.fleet.proto', 'paddle.framework', 'paddle.fluid',