diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index a24cc4cb55c0491ed9be0298e4fbac4f2434b6d0..df482f43346c57cc59af42936b6a7308b76cbd3a 100644
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -97,7 +97,6 @@ message AsyncConfig {
   optional int32 thread_pool_size = 6 [ default = 1 ];
   optional int32 send_wait_times = 7 [ default = 1 ];
   optional bool runtime_split_send_recv = 8 [ default = false ];
-  optional string heter_worker_device = 9 [ default = 'cpu' ];
 }
 
 message PipelineConfig { optional int32 micro_batch = 1 [ default = 1 ]; }
diff --git a/paddle/fluid/operators/distributed/parameter_recv.cc b/paddle/fluid/operators/distributed/parameter_recv.cc
index 4d5f048cf7c89ce55acdec11cb184fcdedeac744..5409ec54987fbb7ad89f61cc1655a4c3ef302ac0 100644
--- a/paddle/fluid/operators/distributed/parameter_recv.cc
+++ b/paddle/fluid/operators/distributed/parameter_recv.cc
@@ -112,6 +112,10 @@ void RecvSelectedRows(const CommContext &rpc_ctx,
 
 template <typename T>
 void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
+  platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+  auto cpu_place = platform::CPUPlace();
+  auto &cpu_ctx = *pool.Get(cpu_place);
+
   distributed::RPCClient *rpc_client =
       distributed::RPCClient::GetInstance<RPCCLIENT_T>(rpc_ctx.trainer_id);
 
@@ -121,14 +125,10 @@ void RecvLodTensor(const CommContext &rpc_ctx, const framework::Scope &scope) {
   if (rpc_ctx.origin_varnames.size() == 1 &&
       rpc_ctx.splited_varnames.size() == 1) {
     auto varname = rpc_ctx.origin_varnames[0];
-    const auto place =
-        scope.FindVar(varname)->Get<framework::LoDTensor>().place();
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &ctx = *pool.Get(place);
-    VLOG(4) << "recv " << varname << " from " << rpc_ctx.epmap[0] << " in gpu? "
-            << platform::is_gpu_place(place);
-    rets.push_back(rpc_client->AsyncGetVarNoBarrier(rpc_ctx.epmap[0], ctx,
+    VLOG(4) << "recv " << varname << " from " << rpc_ctx.epmap[0];
+    rets.push_back(rpc_client->AsyncGetVarNoBarrier(rpc_ctx.epmap[0], cpu_ctx,
                                                     scope, varname, varname));
+
     for (size_t i = 0; i < rets.size(); i++) {
       PADDLE_ENFORCE_NE(
           rets[i]->Wait(), 0U,
diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py
index e86006469f3024c1bdd55841b8f765a9252aeaab..f1c836468daf36db753c67a3e09757be728d37a7 100755
--- a/python/paddle/distributed/fleet/base/distributed_strategy.py
+++ b/python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -107,7 +107,7 @@ class DistributedStrategy(object):
         All of the distributed training configurations can be configured in DistributedStrategy,
         such as automatic mixed precision (AMP), Layer-wise Adaptive Rate Scaling (LARS),
         asynchronous update parameter server(ASGD), etc.
-        
+
         DistributedStrategy can be serialized into protobuf file or deserialized from protobuf file
 
         Users who run local training usually configure BuildStrategy and ExecutionStrategy, and
@@ -129,7 +129,7 @@ class DistributedStrategy(object):
 
         Examples:
           .. code-block:: python
-            
+
             import paddle.distributed.fleet as fleet
             strategy = fleet.DistributedStrategy()
             strategy.dgc = True
@@ -207,7 +207,7 @@ class DistributedStrategy(object):
             build_strategy.fuse_broadcast_ops = True
             build_strategy.fuse_all_optimizer_ops = True
            build_strategy.enable_inplace = True
-            
+
             strategy = paddle.distributed.fleet.DistributedStrategy()
             strategy.build_strategy = build_strategy
         """
@@ -248,7 +248,7 @@ class DistributedStrategy(object):
 
             strategy = fleet.DistributedStrategy()
             strategy.a_sync = True  # by default this is True
-            
+
            # code block for defining loss and local optimizer
            # sgd = fleet.distributed_optimizer(optimizer, strategy)
         """
@@ -259,7 +259,7 @@ class DistributedStrategy(object):
     def a_sync(self, flag):
         if isinstance(flag, bool):
             self.strategy.a_sync = flag
-            self.a_sync_configs = {"k_steps": 0, "worker_device": 'cpu'}
+            self.a_sync_configs = {"k_steps": 0}
         else:
             raise ValueError(
                 "The type of `flag` is invalid, expected type is bool, but received %s".
@@ -472,7 +472,7 @@ class DistributedStrategy(object):
     def sync_batch_norm(self):
         """
         Indicating whether we are using sync_batch_norm to do synchronous batch normalization among all training nodes.
-        
+
         Default value: False
 
         Examples:
@@ -525,7 +525,7 @@ class DistributedStrategy(object):
 
         Examples:
           .. code-block:: python
-            
+
             import paddle.distributed.fleet as fleet
             strategy = fleet.DistributedStrategy()
             strategy.fuse_grad_size_in_MB = 50
@@ -563,7 +563,7 @@ class DistributedStrategy(object):
 
         Examples:
           .. code-block:: python
-            
+
             import paddle.distributed.fleet as fleet
             strategy = fleet.DistributedStrategy()
             strategy.nccl_comm_num = 2
@@ -595,7 +595,7 @@ class DistributedStrategy(object):
 
         Examples:
           .. code-block:: python
-            
+
            import paddle.distributed.fleet as fleet
            strategy = fleet.DistributedStrategy()
            strategy.recompute = True
@@ -621,7 +621,7 @@ class DistributedStrategy(object):
 
         Examples:
           .. code-block:: python
-            
+
             import paddle.distributed.fleet as fleet
             strategy = fleet.DistributedStrategy()
             strategy.pipeline = True
@@ -656,7 +656,7 @@ class DistributedStrategy(object):
 
         Examples:
           .. code-block:: python
-            
+
             import paddle.distributed.fleet as fleet
             strategy = fleet.DistributedStrategy()
             strategy.pipeline = True
@@ -971,7 +971,7 @@ class DistributedStrategy(object):
         [Large Batch Optimization for Deep Learning: Training BERT in 76 minutes](https://arxiv.org/abs/1904.00962).
 
         Default Value: False
-        
+
         Examples:
           .. code-block:: python
 
@@ -1114,7 +1114,7 @@ class DistributedStrategy(object):
 
             optimizer = paddle.optimizer.SGD(learning_rate=0.01)
             optimizer = fleet.distributed_optimizer(optimizer, strategy)
-        
+
         """
         return self.strategy.conv_workspace_size_limit
 
diff --git a/python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py
index 62224d35fe1a836646015045110df6374bcbef59..dfa765364f357b6e685c3983c73cfb4f1b2cce61 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_graph_optimizer.py
@@ -31,10 +31,6 @@ class ParameterServerGraphOptimizer(ParameterServerOptimizer):
         if k_steps < 0:
             return False
 
-        device = self.user_defined_strategy.a_sync_configs["worker_device"]
-        if device.upper() != 'CPU':
-            return False
-
         if self.role_maker._is_server():
             return False
 
diff --git a/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py
index a7b35452b5ac22ed2de8ebda11cad65e4b5ec12d..38ad41f8836b4e8c3b304dbf539b47d5293a8221 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/parameter_server_optimizer.py
@@ -13,7 +13,6 @@
 
 from paddle import fluid
 from .meta_optimizer_base import MetaOptimizerBase
-from ..base.private_helper_function import wait_server_ready
 from paddle.fluid import core
 import subprocess
 import re
@@ -75,8 +74,6 @@ class ParameterServerOptimizer(MetaOptimizerBase):
             _startup = worker.delet_extra_optimizes_pass(_startup,
                                                          compiled_config)
 
-            compiled_config.set_origin_ps_main_program(_main)
-            compiled_config.set_origin_ps_startup_program(_startup)
             # for heter program
             if self.role_maker._is_heter_parameter_server_mode:
                 from paddle.fluid.incubate.fleet.parameter_server.ir import heter_trainer_pass as heter_worker
@@ -94,16 +91,6 @@ class ParameterServerOptimizer(MetaOptimizerBase):
         else:
             _main = worker.append_send_ops_pass(_main, compiled_config)
             _startup = _startup
-            compiled_config.set_origin_ps_main_program(_main)
-            compiled_config.set_origin_ps_startup_program(_startup)
-
-        # for trainer wait server ready
-        wait_server_ready(self.role_maker._get_pserver_endpoints())
-
-        # for ps-heter mode, wait heter worker ready
-        if self.role_maker._is_heter_parameter_server_mode and self.role_maker._is_worker(
-        ):
-            wait_server_ready(self.role_maker._get_heter_worker_endpoints())
 
         return _main, _startup
 
diff --git a/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py b/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py
index 1db9da908a22c96dccba96a0c907ab86f3595d27..ae5c53b8a37c4958e58ed5b09ce7cc8194f1ff52 100644
--- a/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py
+++ b/python/paddle/distributed/fleet/runtime/parameter_server_runtime.py
@@ -94,8 +94,8 @@ class ParameterServerRuntime(RuntimeBase):
                 return False
 
             if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
-                var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
-                var.desc.type() == core.VarDesc.VarType.READER:
+                    var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
+                    var.desc.type() == core.VarDesc.VarType.READER:
                 return False
 
             return var.persistable
@@ -198,21 +198,16 @@ class ParameterServerRuntime(RuntimeBase):
             warnings.warn("communicator has been initialized, skip")
 
     def _get_executor(self):
-
         if self.role_maker._is_heter_worker():
-            heter_worker_device = self.context["valid_strategy"].a_sync_configs[
-                "heter_worker_device"].upper()
-            if heter_worker_device == "GPU":
+            if self.role_maker._get_heter_worker_device() == "GPU":
                 gpu_id = int(os.getenv("FLAGS_selected_gpus", "0"))
                 executor = Executor(fluid.CUDAPlace(gpu_id))
-            elif heter_worker_device == "XPU":
+            elif self.role_maker._get_heter_worker_device() == "XPU":
                 xpu_id = int(os.getenv("FLAGS_selected_xpus", "0"))
                 executor = Executor(fluid.XPUPlace(xpu_id))
-            elif heter_worker_device == "CPU":
-                fluid.Executor(fluid.CPUPlace())
             else:
-                raise ValueError("Heter Worker Not Support Device {}".format(
-                    heter_worker_device))
+                raise ValueError("Not Support Device {}".format(
+                    self.role_maker._get_heter_worker_device()))
         else:
             executor = fluid.Executor(fluid.CPUPlace())
         return executor
@@ -317,7 +312,7 @@ class ParameterServerRuntime(RuntimeBase):
         opts = _get_optimize_ops(self.origin_main_program)
         for op in opts:
             if "Param" in op.input_names and \
-                "LearningRate" in op.input_names and op.input("Param")[0] == param_name:
+                    "LearningRate" in op.input_names and op.input("Param")[0] == param_name:
                 return op
 
     def _save_dense_params(self, executor, dirname, context, main_program):
@@ -463,13 +458,13 @@ class ParameterServerRuntime(RuntimeBase):
 
     def _save_distributed_persistables(self, executor, dirname, main_program):
         dense_ctx = self.compiled_strategy.get_communicator_recv_context(
-            recv_type=1, use_origin_program=True)
+            recv_type=1)
 
         sparse_ctx = self.compiled_strategy.get_communicator_recv_context(
-            recv_type=2, use_origin_program=True)
+            recv_type=2)
 
         distributed_ctx = self.compiled_strategy.get_communicator_recv_context(
-            recv_type=3, use_origin_program=True)
+            recv_type=3)
 
         recv_dense_varnames = self._save_dense_params(executor, dirname,
                                                       dense_ctx, main_program)
@@ -521,7 +516,7 @@ class ParameterServerRuntime(RuntimeBase):
         )
 
         if main_program is None:
-            main_program = self.compiled_strategy.get_origin_ps_main_program()
+            main_program = fluid.default_main_program()
 
         if isinstance(main_program, CompiledProgram):
             raise TypeError(
diff --git a/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py b/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py
index 90847382c86e1c0bfd2cd9fae33342cbdb38e5ce..e348c67ae0461674358fa6d34ee8a73648862a6d 100644
--- a/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py
+++ b/python/paddle/fluid/incubate/fleet/parameter_server/ir/public.py
@@ -133,8 +133,6 @@ class CompileTimeStrategy(object):
 
         self.origin_main_program = main_program
         self.origin_startup_program = startup_program
-        self.origin_ps_main_program = main_program
-        self.origin_ps_startup_program = startup_program
 
         self.strategy = strategy
         self.role_maker = role_maker
@@ -155,11 +153,6 @@ class CompileTimeStrategy(object):
 
         self._build_var_distributed()
 
-        # for heter-ps save variables
-        self.origin_merged_variables_pairs = list(self.merged_variables_pairs)
-        self.origin_merged_dense_pairs = list(self.merged_dense_pairs)
-        self.origin_merged_sparse_pairs = list(self.merged_sparse_pairs)
-
     def get_distributed_mode(self):
         trainer = self.strategy.get_trainer_runtime_config()
         return trainer.mode
@@ -221,18 +214,6 @@ class CompileTimeStrategy(object):
     def get_origin_startup_program(self):
         return self.origin_startup_program
 
-    def set_origin_ps_main_program(self, program):
-        self.origin_ps_main_program = program
-
-    def set_origin_ps_startup_program(self, program):
-        self.origin_ps_startup_program = program
-
-    def get_origin_ps_main_program(self):
-        return self.origin_ps_main_program
-
-    def get_origin_ps_startup_program(self):
-        return self.origin_ps_startup_program
-
     def get_sparse_varname_on_ps(self, is_distributed, endpoint=None):
         if not endpoint:
             endpoint = self.get_ps_endpoint()
@@ -397,9 +378,7 @@ class CompileTimeStrategy(object):
             send_ctx[name] = ctx
         return send_ctx
 
-    def get_communicator_recv_context(self,
-                                      recv_type=1,
-                                      use_origin_program=False):
+    def get_communicator_recv_context(self, recv_type=1):
         # recv_type
         # 1 : DENSE 2. SPARSE 3. DISTRIBUTED 4. ALL
         distibuted_varnames = get_sparse_tablenames(self.origin_main_program,
@@ -413,8 +392,7 @@ class CompileTimeStrategy(object):
         sparse_recv_ctx = {}
         distributed_recv_ctx = {}
 
-        variables_pairs = self.merged_variables_pairs if not use_origin_program else self.origin_merged_variables_pairs
-        for merged in variables_pairs:
+        for merged in self.merged_variables_pairs:
             params = merged[0]
             if params.merged_var.name in sparse_varnames:
                 continue
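For reference, a minimal usage sketch of how the async strategy is configured after this change, assuming only the public fleet API that already appears in the docstring examples inside this patch; the printed value reflects the new `a_sync` setter, and heter device selection is assumed to come from the role maker rather than from `a_sync_configs`:

```python
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.a_sync = True  # setter now resets a_sync_configs to {"k_steps": 0}

# With this patch a_sync_configs no longer carries a "worker_device" /
# "heter_worker_device" entry; the heter worker device is resolved at runtime
# from the role maker (see _get_heter_worker_device() in parameter_server_runtime.py).
print(strategy.a_sync_configs["k_steps"])  # expected: 0

# code block for defining loss and local optimizer
# sgd = fleet.distributed_optimizer(optimizer, strategy)
```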