From 1404f73213dc7135f84294f8618ed34d8d5fdda4 Mon Sep 17 00:00:00 2001 From: kangguangli Date: Mon, 13 Mar 2023 11:01:09 +0800 Subject: [PATCH] [with_data_parallel][part6] remove with_data_parallel in distributed optimizer (#50719) * find relevant testcase * remove with_data_parallel * trigger CI * do not apply ParameterServerGraphOptimizer * remove useless optimizer * remove with_data_parallel in test_dist_base * fix test_fleet_base_3 * only reserve changes for GraphExecutionOptimizer * fix bug * fix test_minst_dgc_nccl * fix typo * fix test_dist_mnist_gradient_merge * rm TestDistMnistNCCL2DGCMultiCards * fix optimizer conflicts * fix dist_mnist * fix test_dist_hapi * delete test_fleet_graph_execution_meta_optimizer & test_fleet_graph_executor * temporarily do not delete unittest * fix unittests * fix ci * recover prune in python/paddle/hapi/model.py --- .../framework/distributed_strategy.proto | 2 +- .../fleet/meta_optimizers/__init__.py | 1 - .../fleet/meta_optimizers/amp_optimizer.py | 1 - .../fleet/meta_optimizers/asp_optimizer.py | 1 - .../fp16_allreduce_optimizer.py | 1 - .../gradient_merge_optimizer.py | 1 - .../graph_execution_optimizer.py | 275 ------------------ .../fleet/meta_optimizers/lamb_optimizer.py | 2 +- .../fleet/meta_optimizers/lars_optimizer.py | 2 +- .../meta_optimizers/localsgd_optimizer.py | 2 - .../meta_optimizers/pipeline_optimizer.py | 4 +- .../meta_optimizers/raw_program_optimizer.py | 10 +- .../meta_optimizers/recompute_optimizer.py | 1 - .../meta_optimizers/sharding_optimizer.py | 4 +- .../tensor_parallel_optimizer.py | 4 +- .../unittests/collective/fleet/CMakeLists.txt | 9 +- ...st_fleet_graph_execution_meta_optimizer.py | 32 +- .../fleet/test_fleet_graph_executor.py | 93 +++--- .../tests/unittests/test_fleet_base_3.py | 4 +- python/paddle/hapi/model.py | 9 + 20 files changed, 101 insertions(+), 357 deletions(-) delete mode 100644 python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py diff --git 
a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 27bc7c7030d..5f5e5a3fac5 100755 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -335,7 +335,7 @@ message DistributedStrategy { optional float last_comm_group_size_MB = 27 [ default = 1 ]; optional bool find_unused_parameters = 28 [ default = false ]; optional bool tensor_parallel = 29 [ default = false ]; - optional bool without_graph_optimization = 30 [ default = false ]; + optional bool without_graph_optimization = 30 [ default = true ]; optional int32 fuse_grad_size_in_num = 31 [ default = 8 ]; optional bool calc_comm_same_stream = 32 [ default = false ]; optional bool asp = 33 [ default = false ]; diff --git a/python/paddle/distributed/fleet/meta_optimizers/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/__init__.py index 1e98b3432f0..b2b6a87c526 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/__init__.py +++ b/python/paddle/distributed/fleet/meta_optimizers/__init__.py @@ -16,7 +16,6 @@ from .amp_optimizer import AMPOptimizer from .asp_optimizer import ASPOptimizer from .recompute_optimizer import RecomputeOptimizer from .gradient_merge_optimizer import GradientMergeOptimizer -from .graph_execution_optimizer import GraphExecutionOptimizer from .ps_optimizer import ParameterServerOptimizer from .pipeline_optimizer import PipelineOptimizer from .localsgd_optimizer import LocalSGDOptimizer diff --git a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py index de0aca0aea6..26e6f9065c4 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py @@ -28,7 +28,6 @@ class AMPOptimizer(MetaOptimizerBase): "LarsOptimizer", "LambOptimizer", "RecomputeOptimizer", - "GraphExecutionOptimizer", ] 
self.meta_optimizers_black_list = ["DGCOptimizer"] diff --git a/python/paddle/distributed/fleet/meta_optimizers/asp_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/asp_optimizer.py index 96b38e39395..65d0590dd49 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/asp_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/asp_optimizer.py @@ -29,7 +29,6 @@ class ASPOptimizer(MetaOptimizerBase): "AMPOptimizer", "LarsOptimizer", "LambOptimizer", - "GraphExecutionOptimizer", "RecomputeOptimizer", "GradientMergeOptimizer", ] diff --git a/python/paddle/distributed/fleet/meta_optimizers/fp16_allreduce_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/fp16_allreduce_optimizer.py index 0c08066ea54..618465d401b 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/fp16_allreduce_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/fp16_allreduce_optimizer.py @@ -31,7 +31,6 @@ class FP16AllReduceOptimizer(MetaOptimizerBase): "RecomputeOptimizer", "LocalSGDOptimizer", "GradientMergeOptimizer", - "GraphExecutionOptimizer", "AdaptiveLocalSGDOptimizer", ] self.meta_optimizers_black_list = ["DGCOptimizer"] diff --git a/python/paddle/distributed/fleet/meta_optimizers/gradient_merge_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/gradient_merge_optimizer.py index 524761a01f0..858949b6a44 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/gradient_merge_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/gradient_merge_optimizer.py @@ -27,7 +27,6 @@ class GradientMergeOptimizer(MetaOptimizerBase): "AMPOptimizer", "LarsOptimizer", "LambOptimizer", - "GraphExecutionOptimizer", "RecomputeOptimizer", ] self.meta_optimizers_black_list = [] diff --git a/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py deleted file mode 100644 index a7610ded24a..00000000000 --- 
a/python/paddle/distributed/fleet/meta_optimizers/graph_execution_optimizer.py +++ /dev/null @@ -1,275 +0,0 @@ -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and - -import copy -import logging - -import paddle -from paddle.framework import core -from paddle.static import BuildStrategy - -from ..base.private_helper_function import wait_server_ready -from .meta_optimizer_base import MetaOptimizerBase - -__all__ = [] - - -class GraphExecutionOptimizer(MetaOptimizerBase): - def __init__(self, optimizer): - super().__init__(optimizer) - self.inner_opt = optimizer - # we do not allow meta optimizer to be inner optimizer currently - self.meta_optimizers_white_list = [] - self.meta_optimizers_black_list = [] - - def _is_graph_out(self): - return True - - def _can_apply(self): - """ - Basically, this is PE, and almost all programs can be executed here - """ - if not self.role_maker._is_collective: - # update me. 
currently, if parameter server is used - # graph execution optimizer can not be applied - return False - return not self.user_defined_strategy.without_graph_optimization - - def backward( - self, - loss, - startup_program=None, - parameter_list=None, - no_grad_set=None, - callbacks=None, - ): - pass - - # should fix the variable - def _setup_nccl_op(self, startup_program, main_program, build_strategy): - trainer_endpoints = self.role_maker._get_trainer_endpoints() - other_trainers = copy.copy(trainer_endpoints) - - trainer_id = self.role_maker._worker_index() - current_endpoint = self.role_maker._get_trainer_endpoints()[trainer_id] - other_trainers.remove(current_endpoint) - - trainer_endpoints_env = ",".join(trainer_endpoints) - trainers_num = self.role_maker._worker_num() - - # NOTE(wangxi): npu don't need to wait server ready - if trainer_id == 0 and not paddle.is_compiled_with_npu(): - wait_server_ready(other_trainers) - - if core.is_compiled_with_cuda(): - comm_id_var = startup_program.global_block().create_var( - name="NCCLID", persistable=True, type=core.VarDesc.VarType.RAW - ) - - for i in range(1, build_strategy.nccl_comm_num): - startup_program.global_block().create_var( - name="NCCLID_{}".format(i), - persistable=True, - type=core.VarDesc.VarType.RAW, - ) - - if build_strategy.use_hierarchical_allreduce: - for i in range(0, build_strategy.nccl_comm_num): - startup_program.global_block().create_var( - name="Hierarchical_inter_NCCLID_{}".format(i), - persistable=True, - type=core.VarDesc.VarType.RAW, - ) - startup_program.global_block().create_var( - name="Hierarchical_exter_NCCLID_{}".format(i), - persistable=True, - type=core.VarDesc.VarType.RAW, - ) - - startup_program.global_block().append_op( - type="gen_nccl_id", - inputs={}, - outputs={"NCCLID": comm_id_var}, - attrs={ - "trainers": trainer_endpoints, - "trainer_id": trainer_id, - "nccl_comm_num": build_strategy.nccl_comm_num, - "use_hierarchical_allreduce": 
build_strategy.use_hierarchical_allreduce, - "hierarchical_allreduce_inter_ranks": build_strategy.hierarchical_allreduce_inter_nranks, - }, - ) - elif core.is_compiled_with_xpu(): - comm_id_var = startup_program.global_block().create_var( - name="BKCLID", persistable=True, type=core.VarDesc.VarType.RAW - ) - - # NOTE(liuyuhui) Baidu Kunlun Communication Library(BKCL) currently do not support multi machines. - assert ( - build_strategy.bkcl_comm_num == 1 - ), "Baidu Kunlun Communication Library(BKCL) currently do not support multi machines." - for i in range(1, build_strategy.bkcl_comm_num): - startup_program.global_block().create_var( - name="BKCLID_{}".format(i), - persistable=True, - type=core.VarDesc.VarType.RAW, - ) - - startup_program.global_block().append_op( - type="gen_bkcl_id", - inputs={}, - outputs={"BKCLID": comm_id_var}, - attrs={ - "trainers": trainer_endpoints, - "trainer_id": trainer_id, - "bkcl_comm_num": build_strategy.bkcl_comm_num, - "use_hierarchical_allreduce": build_strategy.use_hierarchical_allreduce, - "hierarchical_allreduce_inter_ranks": build_strategy.hierarchical_allreduce_inter_nranks, - }, - ) - else: - raise ValueError( - "comm_id must be generated in paddlepaddle-xpu or paddlepaddle-gpu." 
- ) - - def _try_to_compile(self, startup_program, main_program, loss): - dist_strategy = self.user_defined_strategy - local_build_strategy = dist_strategy.build_strategy - - local_build_strategy.use_hierarchical_allreduce = ( - dist_strategy.use_hierarchical_allreduce - ) - local_build_strategy.hierarchical_allreduce_inter_nranks = ( - dist_strategy.hierarchical_allreduce_inter_nranks - ) - local_build_strategy.sync_batch_norm = dist_strategy.sync_batch_norm - local_build_strategy.fuse_all_reduce_ops = ( - dist_strategy.fuse_all_reduce_ops - ) - local_build_strategy.nccl_comm_num = dist_strategy.nccl_comm_num - - gradient_scale_configs = ( - self.user_defined_strategy.gradient_scale_configs - ) - scale_strategys = { - 'avg': BuildStrategy.GradientScaleStrategy.CoeffNumDevice, - 'sum': BuildStrategy.GradientScaleStrategy.One, - 'customized': BuildStrategy.GradientScaleStrategy.Customized, - } - assert ( - gradient_scale_configs['scale_strategy'] in scale_strategys - ), "gradient_scale_configs.scale_strategy must be 'avg', 'sum' or 'customized'" - local_build_strategy.gradient_scale_strategy = scale_strategys[ - gradient_scale_configs['scale_strategy'] - ] - - if self.user_defined_strategy.recompute: - logging.warn( - "set enable_sequential_execution=True since you have enable the recompute strategy" - ) - local_build_strategy.enable_sequential_execution = True - - exe_strategy = self.user_defined_strategy.execution_strategy - worker_num = self.role_maker._worker_num() - node_num = self.role_maker._node_num() - - if self.role_maker._is_collective: - assert worker_num >= 1, ( - "nccl2 worker_num must >= 1, now:{}" % worker_num - ) - - if worker_num <= 1: - # local mode - if local_build_strategy.nccl_comm_num > 1: - logging.warn("set nccl_comm_num=1 since you only have 1 node.") - local_build_strategy.nccl_comm_num = 1 - - if node_num <= 1: - if local_build_strategy.use_hierarchical_allreduce: - logging.warn( - "set hierachical_allreduce=False since you only have 1 
node." - ) - local_build_strategy.use_hierarchical_allreduce = False - - sync_allreduce = dist_strategy.sync_nccl_allreduce - if sync_allreduce: - exe_strategy.num_threads = max( - local_build_strategy.nccl_comm_num + 1, exe_strategy.num_threads - ) - if local_build_strategy.nccl_comm_num > 1: - logging.warn( - "nccl_comm_num > 1, you may need to set sync_nccl_allreduce=False to ensure that different nccl comms can overlap" - ) - - sync_batch_norm = local_build_strategy.sync_batch_norm - if sync_batch_norm: - local_build_strategy.nccl_comm_num = 1 - local_build_strategy.use_hierarchical_allreduce = False - exe_strategy.num_threads = 1 - logging.warn( - "use sync_batch_norm will hang when set num_threads > 1, so " - "set num_threads=1, nccl_comm_num=1, hierachical_allreduce=False." - ) - - # NOTE. compatible with compiler, otherwise these values will be overwritten by compiler - main_program._nccl_comm_num = local_build_strategy.nccl_comm_num - main_program._use_hierarchical_allreduce = ( - local_build_strategy.use_hierarchical_allreduce - ) - main_program._hierarchical_allreduce_inter_nranks = ( - local_build_strategy.hierarchical_allreduce_inter_nranks - ) - - # TODO(guru4elephant): should be an independent optimizer - if worker_num > 1: - self._setup_nccl_op( - startup_program, main_program, local_build_strategy - ) - - local_build_strategy.num_trainers = self.role_maker._worker_num() - local_build_strategy.trainer_id = self.role_maker._worker_index() - local_build_strategy.trainers_endpoints = ( - self.role_maker._get_trainer_endpoints() - ) - local_build_strategy.enable_backward_optimizer_op_deps = True - - self._compiled_program = paddle.static.CompiledProgram(main_program) - - self._compiled_program.with_data_parallel( - loss_name=loss.name, - build_strategy=local_build_strategy, - exec_strategy=exe_strategy, - share_vars_from=None, - ) - - return self._compiled_program - - def _disable_strategy(self, dist_strategy): - # TODO(guru4elephant): should close all 
PE related flags here - return - - def _enable_strategy(self, dist_strategy, context): - # by default, graph execution strategy is enabled - return - - def minimize( - self, loss, startup_program=None, parameter_list=None, no_grad_set=None - ): - if startup_program is None: - startup_program = paddle.static.default_startup_program() - compiled_program = self._try_to_compile( - startup_program, loss.block.program, loss - ) - loss.block.program._graph = compiled_program - - # just return self.optimizer_ops and self.param_grads - return None, None diff --git a/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py index 9a0ccde5979..1a8c491fe48 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/lamb_optimizer.py @@ -27,7 +27,7 @@ class LambOptimizer(MetaOptimizerBase): self.inner_opt = optimizer self.lamb_opt = None # we do not allow meta optimizer to be inner optimizer currently - self.meta_optimizers_white_list = ["GraphExecutionOptimizer"] + self.meta_optimizers_white_list = [] self.meta_optimizers_black_list = [] def _set_basic_info( diff --git a/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py index b58bdd446c2..a81305cecaf 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/lars_optimizer.py @@ -26,7 +26,7 @@ class LarsOptimizer(MetaOptimizerBase): self.inner_opt = optimizer self.lars_opt = None # we do not allow meta optimizer to be inner optimizer currently - self.meta_optimizers_white_list = ["GraphExecutionOptimizer"] + self.meta_optimizers_white_list = [] self.meta_optimizers_black_list = [] def _set_basic_info( diff --git a/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py 
b/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py index e9e9f353cfd..2973d4d3130 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/localsgd_optimizer.py @@ -31,7 +31,6 @@ class LocalSGDOptimizer(MetaOptimizerBase): self.inner_opt = optimizer self.meta_optimizers_white_list = ['AMPOptimizer'] self.meta_optimizers_black_list = [ - "GraphExecutionOptimizer", "AdaptiveLocalSGDOptimizer", ] self.snapshot_key = '@SNAPSHOT' @@ -215,7 +214,6 @@ class AdaptiveLocalSGDOptimizer(MetaOptimizerBase): self.inner_opt = optimizer self.meta_optimizers_white_list = ['AMPOptimizer'] self.meta_optimizers_black_list = [ - "GraphExecutionOptimizer", "LocalSGDOptimizer", ] self.snapshot_key = '@SNAPSHOT' diff --git a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py index a501b625688..40a038cdeb7 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py @@ -35,9 +35,7 @@ class PipelineOptimizer(MetaOptimizerBase): "RecomputeOptimizer", "AMPOptimizer", ] - self.meta_optimizers_black_list = [ - "GraphExecutionOptimizer", - ] + self.meta_optimizers_black_list = [] self.global_ring_id = 1 self.dp_ring_id = 2 self.start_pipeline_ring_id = 20 # Just a magic number diff --git a/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py index 5ca078a06b9..e7abec03fc2 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py @@ -40,9 +40,7 @@ class RawProgramOptimizer(MetaOptimizerBase): "DGCOptimizer", "LocalSGDOptimizer", ] - self.meta_optimizers_black_list = [ - "GraphExecutionOptimizer", - ] + 
self.meta_optimizers_black_list = [] self.global_ring_id = 0 def _set_basic_info( @@ -66,6 +64,12 @@ class RawProgramOptimizer(MetaOptimizerBase): def _can_apply(self): if not self.role_maker._is_collective: return False + if self.user_defined_strategy.tensor_parallel: + return False + if self.user_defined_strategy.sharding: + return False + if self.user_defined_strategy.amp: + return False if self.without_graph_optimization: return True diff --git a/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py index f43d1779c19..7c7fbecf700 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py @@ -27,7 +27,6 @@ class RecomputeOptimizer(MetaOptimizerBase): self.meta_optimizers_white_list = [ "LarsOptimizer", "LambOptimizer", - "GraphExecutionOptimizer", "DGCOptimizer", ] self.meta_optimizers_black_list = [] diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index dedcc6f5ac7..980cd283a46 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -73,9 +73,7 @@ class ShardingOptimizer(MetaOptimizerBase): # "ModelParallelOptimizer", # "PipelineOptimizer", ] - self.meta_optimizers_black_list = [ - "GraphExecutionOptimizer", - ] + self.meta_optimizers_black_list = [] self._main_program = None self._startup_program = None self._segments = [] diff --git a/python/paddle/distributed/fleet/meta_optimizers/tensor_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/tensor_parallel_optimizer.py index 59058bb4b27..2d74d6454a0 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/tensor_parallel_optimizer.py +++ 
b/python/paddle/distributed/fleet/meta_optimizers/tensor_parallel_optimizer.py @@ -37,9 +37,7 @@ class TensorParallelOptimizer(MetaOptimizerBase): "LarsOptimizer", "LambOptimizer", ] - self.meta_optimizers_black_list = [ - "GraphExecutionOptimizer", - ] + self.meta_optimizers_black_list = [] self.mp_ring_id = 0 self.global_ring_id = 1 self.dp_ring_id = 2 diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt b/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt index c61705bdc15..e47c72e46cf 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/collective/fleet/CMakeLists.txt @@ -173,8 +173,13 @@ if((WITH_GPU OR WITH_ASCEND_CL ) AND LOCAL_ALL_PLAT) - py_test_modules( - test_fleet_graph_executor MODULES test_fleet_graph_executor ENVS + bash_test_modules( + test_fleet_graph_executor + START_BASH + ../../dist_test.sh + LABELS + "RUN_TYPE=DIST" + ENVS "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python") endif() if((WITH_GPU) AND LOCAL_ALL_PLAT) diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_graph_execution_meta_optimizer.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_graph_execution_meta_optimizer.py index a36c5a1d74c..f41bd7de9fd 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_graph_execution_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_graph_execution_meta_optimizer.py @@ -17,10 +17,6 @@ import unittest from launch_function_helper import _find_free_port, launch_func, wait -import paddle - -paddle.enable_static() - class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase): def setUp(self): @@ -43,6 +39,7 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase): ), "http_proxy": "", "https_proxy": "", + "FLAGS_selected_gpus": "0", } node_b = { @@ -54,9 +51,13 @@ class 
TestFleetGraphExecutionMetaOptimizer(unittest.TestCase): ), "http_proxy": "", "https_proxy": "", + "FLAGS_selected_gpus": "1", } def node_func(): + import paddle + + paddle.enable_static() import paddle.distributed.fleet as fleet fleet.init(is_collective=True) @@ -85,7 +86,7 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase): ) optimizer.minimize(avg_cost) - exe = paddle.fluid.Executor(place=paddle.fluid.CPUPlace()) + exe = paddle.fluid.Executor() exe.run(paddle.fluid.default_startup_program()) proc_a = launch_func(node_func, node_a) @@ -107,6 +108,7 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase): ), "http_proxy": "", "https_proxy": "", + "FLAGS_selected_gpus": "0", } node_b = { @@ -118,9 +120,13 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase): ), "http_proxy": "", "https_proxy": "", + "FLAGS_selected_gpus": "1", } def node_func(): + import paddle + + paddle.enable_static() import paddle.distributed.fleet as fleet fleet.init(is_collective=True) @@ -150,7 +156,7 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase): optimizer, strategy=strategy ) optimizer.minimize(avg_cost) - exe = paddle.fluid.Executor(place=paddle.fluid.CPUPlace()) + exe = paddle.fluid.Executor() exe.run(paddle.fluid.default_startup_program()) import numpy as np @@ -183,6 +189,7 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase): ), "http_proxy": "", "https_proxy": "", + "FLAGS_selected_gpus": "0", } node_b = { @@ -194,9 +201,13 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase): ), "http_proxy": "", "https_proxy": "", + "FLAGS_selected_gpus": "1", } def node_func(): + import paddle + + paddle.enable_static() import paddle.distributed.fleet as fleet fleet.init(is_collective=True) @@ -225,7 +236,7 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase): ) optimizer.minimize(avg_cost) - exe = paddle.fluid.Executor(place=paddle.fluid.CPUPlace()) + exe = paddle.fluid.Executor() 
exe.run(paddle.fluid.default_startup_program()) proc_a = launch_func(node_func, node_a) @@ -246,6 +257,7 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase): ), "http_proxy": "", "https_proxy": "", + "FLAGS_selected_gpus": "0", } node_b = { @@ -257,9 +269,13 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase): ), "http_proxy": "", "https_proxy": "", + "FLAGS_selected_gpus": "1", } def node_func(): + import paddle + + paddle.enable_static() import paddle.distributed.fleet as fleet fleet.init(is_collective=True) @@ -289,7 +305,7 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase): optimizer, strategy=strategy ) optimizer.minimize(avg_cost) - exe = paddle.fluid.Executor(place=paddle.fluid.CPUPlace()) + exe = paddle.fluid.Executor() exe.run(paddle.fluid.default_startup_program()) import numpy as np diff --git a/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_graph_executor.py b/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_graph_executor.py index aab4032afbc..254c8b910bb 100644 --- a/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_graph_executor.py +++ b/python/paddle/fluid/tests/unittests/collective/fleet/test_fleet_graph_executor.py @@ -17,9 +17,50 @@ import unittest from launch_function_helper import launch_func -import paddle -import paddle.distributed.fleet as fleet -import paddle.distributed.fleet.base.role_maker as role_maker + +def node_func(): + import paddle + import paddle.distributed.fleet as fleet + import paddle.distributed.fleet.base.role_maker as role_maker + + paddle.enable_static() + + role = role_maker.PaddleCloudRoleMaker(is_collective=True) + fleet.init(role) + input_x = paddle.static.data(name="x", shape=[-1, 32], dtype='float32') + input_y = paddle.static.data(name="y", shape=[-1, 1], dtype='int64') + + fc_1 = paddle.static.nn.fc(x=input_x, size=64, activation='tanh') + fc_2 = paddle.static.nn.fc(x=fc_1, size=64, activation='tanh') + prediction = 
paddle.static.nn.fc(x=[fc_2], size=2, activation='softmax') + cost = paddle.nn.functional.cross_entropy( + input=prediction, + label=input_y, + reduction='none', + use_softmax=False, + ) + avg_cost = paddle.mean(x=cost) + + strategy = paddle.distributed.fleet.DistributedStrategy() + strategy.nccl_comm_num = 2 + strategy.sync_nccl_allreduce = True + optimizer = paddle.optimizer.SGD(learning_rate=0.01) + optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) + optimizer.minimize(avg_cost) + exe = paddle.fluid.Executor() + exe.run(paddle.fluid.default_startup_program()) + + import numpy as np + + def gen_data(): + return { + "x": np.random.random(size=(128, 32)).astype('float32'), + "y": np.random.randint(2, size=(128, 1)).astype('int64'), + } + + for i in range(5): + cost_val = exe.run(feed=gen_data(), fetch_list=[avg_cost.name]) + print("cost of step[{}] = {}".format(i, cost_val)) class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase): @@ -31,6 +72,7 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase): "PADDLE_TRAINER_ENDPOINTS": "127.0.0.1:36001,127.0.0.1:36002", "http_proxy": "", "https_proxy": "", + "FLAGS_selected_gpus": "0", } node_b = { @@ -40,52 +82,9 @@ class TestFleetGraphExecutionMetaOptimizer(unittest.TestCase): "PADDLE_TRAINER_ENDPOINTS": "127.0.0.1:36001,127.0.0.1:36002", "http_proxy": "", "https_proxy": "", + "FLAGS_selected_gpus": "1", } - def node_func(): - role = role_maker.PaddleCloudRoleMaker(is_collective=True) - fleet.init(role) - input_x = paddle.static.data( - name="x", shape=[-1, 32], dtype='float32' - ) - input_y = paddle.static.data(name="y", shape=[-1, 1], dtype='int64') - - fc_1 = paddle.static.nn.fc(x=input_x, size=64, activation='tanh') - fc_2 = paddle.static.nn.fc(x=fc_1, size=64, activation='tanh') - prediction = paddle.static.nn.fc( - x=[fc_2], size=2, activation='softmax' - ) - cost = paddle.nn.functional.cross_entropy( - input=prediction, - label=input_y, - reduction='none', - use_softmax=False, - 
) - avg_cost = paddle.mean(x=cost) - - strategy = paddle.distributed.fleet.DistributedStrategy() - strategy.nccl_comm_num = 2 - strategy.sync_nccl_allreduce = True - optimizer = paddle.optimizer.SGD(learning_rate=0.01) - optimizer = fleet.distributed_optimizer( - optimizer, strategy=strategy - ) - optimizer.minimize(avg_cost) - exe = paddle.fluid.Executor(place=paddle.fluid.CPUPlace()) - exe.run(paddle.fluid.default_startup_program()) - - import numpy as np - - def gen_data(): - return { - "x": np.random.random(size=(128, 32)).astype('float32'), - "y": np.random.randint(2, size=(128, 1)).astype('int64'), - } - - for i in range(5): - cost_val = exe.run(feed=gen_data(), fetch_list=[avg_cost.name]) - print("cost of step[{}] = {}".format(i, cost_val)) - # rank 1 proc_b = launch_func(node_func, node_b) proc_b.start() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_base_3.py b/python/paddle/fluid/tests/unittests/test_fleet_base_3.py index e24beee28e1..f3eff6ee97c 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_base_3.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_base_3.py @@ -87,8 +87,8 @@ class TestFleetBase(unittest.TestCase): meta_list = fleet._get_applied_meta_list() graph_list = fleet._get_applied_graph_list() - self.assertEqual(len(meta_list), 0) - self.assertEqual(len(graph_list), 1) + self.assertEqual(len(meta_list), 1) + self.assertEqual(len(graph_list), 0) if __name__ == "__main__": diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py index 0ac05f13442..54ac7a4d3db 100644 --- a/python/paddle/hapi/model.py +++ b/python/paddle/hapi/model.py @@ -740,6 +740,15 @@ class StaticGraphAdapter: continue uninitialized.append(var_py) + + # for RawProgramOptimizer, it will insert OP with no outputs like: + # c_comm_init(inputs={X=['comm_id_0']} + # but we cannot prune this op. 
+ block = self._startup_prog.global_block() + for op in block.ops: + if op.type == "c_comm_init": + uninitialized.append(op) + if uninitialized: startup_prog = self._startup_prog._prune(uninitialized) self._executor.run(startup_prog) -- GitLab