From 391b210f59d916c921ddc1cb21b8d38ee61b5335 Mon Sep 17 00:00:00 2001
From: qizhaoaoe <10208099+qizhaoaoe@users.noreply.github.com>
Date: Fri, 17 Mar 2023 23:20:16 +0800
Subject: [PATCH] remove fluid.layers.collective and related unittests.
 (#51372)

* remove fluid.layers.collective and related unittests.

* fix layers.init import

* fix all_gather method

* remove unused args in all_gather method
---
 python/paddle/fluid/layers/__init__.py        |   3 +-
 python/paddle/fluid/layers/collective.py      | 203 ------------------
 .../fluid/tests/unittests/CMakeLists.txt      |   2 -
 .../unittests/collective_reducescatter.py     |  41 ----
 .../tests/unittests/test_reducescatter_api.py |  42 ----
 python/paddle/hapi/model.py                   |  20 +-
 tools/parallel_UT_rule.py                     |   2 -
 7 files changed, 12 insertions(+), 301 deletions(-)
 delete mode 100644 python/paddle/fluid/layers/collective.py
 delete mode 100644 python/paddle/fluid/tests/unittests/collective_reducescatter.py
 delete mode 100644 python/paddle/fluid/tests/unittests/test_reducescatter_api.py

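Note on the all_gather change: the hapi helper in python/paddle/hapi/model.py
now goes through the public paddle.distributed API instead of the removed
fluid.layers.collective._c_allgather, which is why the nranks and ring_id
arguments disappear. A minimal dygraph sketch of that pattern, assuming the
script is started with paddle.distributed.launch; the sample tensor and the
init_parallel_env() call are illustrative only and not part of this patch:

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()

    def _all_gather(x):
        # Gather x from every rank, then concatenate along axis 0.
        output = []
        dist.all_gather(output, x)
        return paddle.concat(output, axis=0)

    x = paddle.to_tensor([float(dist.get_rank())])
    gathered = _all_gather(x)  # shape: [world_size]
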
diff --git a/python/paddle/fluid/layers/__init__.py b/python/paddle/fluid/layers/__init__.py
index 40247bde333..6ab688cef37 100644
--- a/python/paddle/fluid/layers/__init__.py
+++ b/python/paddle/fluid/layers/__init__.py
@@ -23,7 +23,8 @@ from .control_flow import *
 from . import math_op_patch
 from .math_op_patch import *
 from .learning_rate_scheduler import *
-from .collective import *
+from ..layer_helper import LayerHelper
+
 __all__ = []
 
 __all__ += nn.__all__
diff --git a/python/paddle/fluid/layers/collective.py b/python/paddle/fluid/layers/collective.py
deleted file mode 100644
index db451bcce98..00000000000
--- a/python/paddle/fluid/layers/collective.py
+++ /dev/null
@@ -1,203 +0,0 @@
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from ..layer_helper import LayerHelper, unique_name
-from ..framework import Variable, in_dygraph_mode
-import paddle
-from paddle import _C_ops, _legacy_C_ops
-
-
-def _allreduce(x, out=None, reduce_type="sum", sync_mode=False):
-    helper = LayerHelper("allreduce", **locals())
-    # Convert string reduce type to op int type
-    red_typ_int = 0
-    if reduce_type == "sum":
-        red_typ_int = 0
-    elif reduce_type == "prod":
-        red_typ_int = 1
-    elif reduce_type == "max":
-        red_typ_int = 2
-    elif reduce_type == "min":
-        red_typ_int = 3
-    else:
-        raise TypeError("reduce type can only be [sum|prod|max|min]")
-
-    if out is None:
-        out = helper.create_variable(
-            name=unique_name.generate_with_ignorable_key(
-                ".".join([x.name, 'tmp'])
-            ),
-            shape=x.shape,
-            dtype=x.dtype,
-            type=x.type,
-            persistable=x.persistable,
-            stop_gradient=True,
-        )
-    helper.append_op(
-        type='allreduce',
-        inputs={'X': [x]},
-        outputs={'Out': [out]},
-        attrs={"reduce_type": red_typ_int, "sync_mode": sync_mode},
-    )
-    return out
-
-
-def _broadcast(x, root, sync_mode=False):
-    helper = LayerHelper("broadcast", **locals())
-    helper.append_op(
-        type='broadcast',
-        inputs={'X': [x]},
-        outputs={'Out': [x]},
-        attrs={"sync_mode": sync_mode, "root": root},
-    )
-    return x
-
-
-def _c_allreduce(
-    x, out=None, reduce_type='sum', ring_id=0, use_calc_stream=False
-):
-    helper = LayerHelper('c_allreduce', **locals())
-
-    if reduce_type not in ['sum', 'prob', 'max', 'min']:
-        raise TypeError('reduce type can only be "sum|prod|max|min]"')
-
-    op_type = 'c_allreduce_' + reduce_type
-    if out is None:
-        out = helper.create_variable(
-            name=unique_name.generate_with_ignorable_key(
-                '.'.join([x.name, op_type])
-            ),
-            shape=x.shape,
-            dtype=x.dtype,
-            type=x.type,
-            persistable=x.persistable,
-        )
-
-    helper.append_op(
-        type=op_type,
-        inputs={'X': [x]},
-        outputs={'Out': [out]},
-        attrs={'ring_id': ring_id, 'use_calc_stream': use_calc_stream},
-    )
-    return out
-
-
-def _c_broadcast(x, root=0, ring_id=0, use_calc_stream=False):
-    op_type = 'c_broadcast'
-    helper = LayerHelper(op_type, **locals())
-    helper.append_op(
-        type=op_type,
-        inputs={'X': [x]},
-        outputs={'Out': [x]},
-        attrs={
-            'root': root,
-            'ring_id': ring_id,
-            'use_calc_stream': use_calc_stream,
-        },
-    )
-    return x
-
-
-def _c_allgather(x, nranks, ring_id=0, use_calc_stream=False):
-    op_type = 'c_allgather'
-
-    if in_dygraph_mode():
-        group = paddle.distributed.collective._get_default_group()
-        tensor_shape = list(x.shape)
-        tensor_shape[0] *= nranks
-        out = paddle.empty(tensor_shape, x.dtype)
-        task = group.process_group.all_gather(x, out)
-        task.wait()
-        return out
-    else:
-        helper = LayerHelper(op_type, **locals())
-        out_shape = list(x.shape[:])
-        if out_shape[0] > 0:
-            out_shape[0] *= nranks
-        out = helper.create_variable(
-            name=unique_name.generate_with_ignorable_key(
-                '.'.join([x.name, op_type])
-            ),
-            shape=out_shape,
-            dtype=x.dtype,
-            type=x.type,
-            persistable=x.persistable,
-        )
-        helper.append_op(
-            type=op_type,
-            inputs={'X': [x]},
-            outputs={'Out': [out]},
-            attrs={
-                'nranks': nranks,
-                'ring_id': ring_id,
-                'use_calc_stream': use_calc_stream,
-            },
-        )
-        return out
-
-
-def _c_reducescatter(x, nranks, ring_id=0, use_calc_stream=False):
-    if not isinstance(x, Variable):
-        raise TypeError('x must be a Variable')
-
-    if x.shape[0] > 0 and x.shape[0] % nranks != 0:
-        raise ValueError(
-            'x.shape[0](%d) cannot be evenly divided by nranks(%d)'
-            % (x.shape[0], nranks)
-        )
-
-    op_type = 'c_reducescatter'
-    helper = LayerHelper(op_type, **locals())
-    out_shape = list(x.shape[:])
-    if out_shape[0] > 0:
-        out_shape[0] //= nranks
-    out = helper.create_variable(
-        name=unique_name.generate_with_ignorable_key(
-            '.'.join([x.name, op_type])
-        ),
-        shape=out_shape,
-        dtype=x.dtype,
-        type=x.type,
-        persistable=x.persistable,
-    )
-    helper.append_op(
-        type=op_type,
-        inputs={'X': [x]},
-        outputs={'Out': [out]},
-        attrs={
-            'nranks': nranks,
-            'ring_id': ring_id,
-            'use_calc_stream': use_calc_stream,
-        },
-    )
-    return out
-
-
-def _c_sync_calc_stream(x):
-    op_type = 'c_sync_calc_stream'
-    helper = LayerHelper(op_type, **locals())
-    helper.append_op(type=op_type, inputs={'X': [x]}, outputs={'Out': [x]})
-    return x
-
-
-def _c_sync_comm_stream(x, ring_id):
-    op_type = 'c_sync_comm_stream'
-    helper = LayerHelper(op_type, **locals())
-    helper.append_op(
-        type=op_type,
-        inputs={'X': [x]},
-        outputs={'Out': [x]},
-        attrs={'ring_id': ring_id},
-    )
-    return x
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 718497311d1..6bde7643a98 100755
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -207,7 +207,6 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
   list(REMOVE_ITEM TEST_OPS test_boxps)
   list(REMOVE_ITEM TEST_OPS test_allgather)
   list(REMOVE_ITEM TEST_OPS test_reducescatter)
-  list(REMOVE_ITEM TEST_OPS test_reducescatter_api)
 endif()
 list(REMOVE_ITEM TEST_OPS test_seq_concat_op)
 # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290
@@ -1142,7 +1141,6 @@ if((WITH_ROCM OR WITH_GPU) AND NOT WIN32)
                        PROPERTIES TIMEOUT 120)
   set_tests_properties(test_pipeline_parallel PROPERTIES LABELS
                                               "RUN_TYPE=DIST")
-  set_tests_properties(test_reducescatter_api PROPERTIES TIMEOUT 120)
   set_tests_properties(test_reducescatter PROPERTIES TIMEOUT 120)
   set_tests_properties(test_allgather PROPERTIES TIMEOUT 120)
 endif()
diff --git a/python/paddle/fluid/tests/unittests/collective_reducescatter.py b/python/paddle/fluid/tests/unittests/collective_reducescatter.py
deleted file mode 100644
index 9813553295f..00000000000
--- a/python/paddle/fluid/tests/unittests/collective_reducescatter.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from test_collective_base import TestCollectiveRunnerBase, runtime_main
-
-import paddle
-import paddle.fluid as fluid
-
-paddle.enable_static()
-
-
-class TestCollectiveReduceScatter(TestCollectiveRunnerBase):
-    def __init__(self):
-        self.global_ring_id = 0
-
-    def get_model(self, main_prog, startup_program):
-        ring_id = 0
-        nranks = 2
-        with fluid.program_guard(main_prog, startup_program):
-            tindata = paddle.static.data(
-                name="tindata", shape=[-1, 10, 1000], dtype='float32'
-            )
-            tindata.desc.set_need_check_feed(False)
-            toutdata = fluid.layers.collective._c_reducescatter(tindata, nranks)
-            toutdata = fluid.layers.collective._c_sync_comm_stream(toutdata, 0)
-            return toutdata
-
-
-if __name__ == "__main__":
-    runtime_main(TestCollectiveReduceScatter, "reducescatter", 0)
diff --git a/python/paddle/fluid/tests/unittests/test_reducescatter_api.py b/python/paddle/fluid/tests/unittests/test_reducescatter_api.py
deleted file mode 100644
index 8153f2f81c0..00000000000
--- a/python/paddle/fluid/tests/unittests/test_reducescatter_api.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import unittest
-
-from test_collective_base import TestDistBase
-
-import paddle
-import paddle.fluid as fluid
-
-paddle.enable_static()
-
-
-class TestReduceScatterAPI(TestDistBase):
-    def _setup_config(self):
-        pass
-
-    def test_reducescatter(self, col_type="reduce_scatter"):
-        self.check_with_place("collective_reducescatter.py", col_type)
-
-    def test_reducescatter_with_error(self):
-        nranks = 2
-        tindata = fluid.data(name="tindata", shape=[5, 1000], dtype='float32')
-        try:
-            toutdata = fluid.layers.collective._c_reducescatter(tindata, nranks)
-        except ValueError:
-            pass
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py
index 54ac7a4d3db..43ac3de160c 100644
--- a/python/paddle/hapi/model.py
+++ b/python/paddle/hapi/model.py
@@ -34,7 +34,6 @@ from paddle.fluid.executor import global_scope
 from paddle.fluid.framework import Variable
 from paddle.fluid.framework import _current_expected_place as _get_device
 from paddle.fluid.framework import _get_paddle_place, _non_static_mode
-from paddle.fluid.layers import collective
 from paddle.framework.io_utils import is_belong_to_optimizer
 from paddle.io import DataLoader, Dataset, DistributedBatchSampler
 from paddle.jit.translated_layer import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
@@ -91,10 +90,11 @@ def extract_args(func):
     return inspect.getfullargspec(func).args
 
 
-def _all_gather(x, nranks, ring_id=0, use_calc_stream=True):
-    return collective._c_allgather(
-        x, nranks, ring_id=ring_id, use_calc_stream=use_calc_stream
-    )
+def _all_gather(x):
+    output = []
+    dist.all_gather(output, x)
+    output = paddle.concat(output, axis=0)
+    return output
 
 
 def wait_server_ready(endpoints):
@@ -658,9 +658,9 @@ class StaticGraphAdapter:
                 losses = self.model._loss(*(outputs + labels))
 
         if self._nranks > 1 and mode != 'train':
-            outputs = [_all_gather(o, self._nranks) for o in outputs]
+            outputs = [_all_gather(o) for o in outputs]
             if mode != 'test':
-                labels = [_all_gather(l, self._nranks) for l in labels]
+                labels = [_all_gather(l) for l in labels]
 
         if mode != 'test':
             for metric in self.model._metrics:
@@ -885,8 +885,8 @@ class DynamicGraphAdapter:
         losses = to_list(losses)
 
         if self._nranks > 1:
-            outputs = [_all_gather(o, self._nranks) for o in to_list(outputs)]
-            labels = [_all_gather(l, self._nranks) for l in labels]
+            outputs = [_all_gather(o) for o in to_list(outputs)]
+            labels = [_all_gather(l) for l in labels]
         metrics = []
         for metric in self.model._metrics:
             # cut off padding value.
@@ -931,7 +931,7 @@ class DynamicGraphAdapter:
         self._input_info = _update_input_info(inputs)
         outputs = self.model.network(*inputs)
         if self._nranks > 1 and isinstance(self.model._place, fluid.CUDAPlace):
-            outputs = [_all_gather(o, self._nranks) for o in to_list(outputs)]
+            outputs = [_all_gather(o) for o in to_list(outputs)]
 
         return [to_numpy(o) for o in to_list(outputs)]
diff --git a/tools/parallel_UT_rule.py b/tools/parallel_UT_rule.py
index c3f3fb15d01..eb65983e62a 100755
--- a/tools/parallel_UT_rule.py
+++ b/tools/parallel_UT_rule.py
@@ -1646,7 +1646,6 @@ SIXTH_PARALLEL_JOB_NEW = [
     'test_collective_wait',
     'test_nn_matmul_v2_grad',
     'test_quant2_int8_resnet50_mkldnn',
-    'test_reducescatter_api',
     'test_collective_sendrecv',
     'test_collective_scatter',
     'test_gru_op',
@@ -1726,7 +1725,6 @@ CPU_PARALLEL_JOB = [
    'test_requantize_mkldnn_op',
    'test_repeated_fc_relu_fuse_pass',
    'test_registry',
-   'test_reducescatter_api',
    'test_reducescatter',
    'test_recurrent_op',
    'test_recommender_system',
-- 
GitLab
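
Note on migrating off the removed helpers: the private functions deleted above
(_allreduce, _c_allreduce, _c_broadcast, _c_allgather, _c_reducescatter, and the
stream-sync helpers) have public counterparts in paddle.distributed. A rough
dygraph sketch of the equivalent calls, assuming a launched process group; the
tensor values and the SUM/root choices are only for illustration:

    import paddle
    import paddle.distributed as dist

    dist.init_parallel_env()

    x = paddle.ones([4], dtype='float32')

    # _allreduce / _c_allreduce(reduce_type='sum') -> all_reduce
    dist.all_reduce(x, op=dist.ReduceOp.SUM)

    # _c_broadcast(root=0) -> broadcast from rank 0
    dist.broadcast(x, src=0)

All-gather is covered by paddle.distributed.all_gather, as used in the
hapi/model.py hunk above, and recent releases also expose
paddle.distributed.reduce_scatter for the reduce-scatter case exercised by the
deleted tests.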