From e0e0c0fa49a433eb09b7ed5abf68d29da1dadfe3 Mon Sep 17 00:00:00 2001
From: Yuang Liu
Date: Mon, 21 Jun 2021 17:06:18 +0800
Subject: [PATCH] add sync calc stream and add ut for fuse on gpu (#33580)

---
 .../framework/distributed_strategy.proto      |   1 +
 .../fleet/base/distributed_strategy.py        |  24 ++
 .../meta_optimizers/raw_program_optimizer.py  | 238 ++++++++----------
 .../fluid/tests/unittests/CMakeLists.txt      |   1 +
 ...et_raw_program_optimizer_fuse_allreduce.py | 112 +++++++++
 ...et_raw_program_optimizer_fuse_allreduce.py |  45 ++++
 6 files changed, 291 insertions(+), 130 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/dist_fleet_raw_program_optimizer_fuse_allreduce.py
 create mode 100644 python/paddle/fluid/tests/unittests/test_dist_fleet_raw_program_optimizer_fuse_allreduce.py

diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index be05941efb5..a63dfd7b091 100644
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -177,6 +177,7 @@ message DistributedStrategy {
   optional bool tensor_parallel = 29 [ default = false ];
   optional bool without_graph_optimization = 30 [ default = false ];
   optional int32 fuse_grad_size_in_num = 31 [ default = 1 ];
+  optional bool calc_comm_same_stream = 32 [ default = false ];
 
   optional RecomputeConfig recompute_configs = 101;
   optional AMPConfig amp_configs = 102;
diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py
index e44a0e0459d..c4aa9213469 100644
--- a/python/paddle/distributed/fleet/base/distributed_strategy.py
+++ b/python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -853,6 +853,30 @@ class DistributedStrategy(object):
                 "WARNING: without_graph_optimization should have value of bool type"
             )
 
+    @property
+    def _calc_comm_same_stream(self):
+        """
+        This is based on the raw_program_optimizer.
+        Set whether to use the same stream for calculation and communication when fusing allreduce.
+        The default value of calc_comm_same_stream is False.
+        Examples:
+          .. code-block:: python
+            import paddle.distributed.fleet as fleet
+            strategy = fleet.DistributedStrategy()
+            strategy._calc_comm_same_stream = True
+        """
+        return self.strategy.calc_comm_same_stream
+
+    @_calc_comm_same_stream.setter
+    @is_strict_auto
+    def _calc_comm_same_stream(self, same):
+        if isinstance(same, bool):
+            self.strategy.calc_comm_same_stream = same
+        else:
+            print(
+                "WARNING: calc_comm_same_stream should have value of bool type"
+            )
+
     @property
     def fuse_grad_size_in_num(self):
         """
diff --git a/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py
index 1333f794cc9..c85242b6a56 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py
@@ -44,6 +44,7 @@ class RawProgramOptimizer(MetaOptimizerBase):
         self.fuse_all_reduce_ops = user_defined_strategy.fuse_all_reduce_ops
         if self.fuse_all_reduce_ops:
             self.fuse_grad_size_in_num = user_defined_strategy.fuse_grad_size_in_num
+            self.calc_comm_same_stream = user_defined_strategy._calc_comm_same_stream
 
     def _can_apply(self):
         if not self.role_maker._is_collective:
@@ -130,8 +131,7 @@
 
     def _transpile_main_program(self, loss):
         self._insert_loss_grad_ops(loss)
-        if self.fuse_all_reduce_ops and core.is_compiled_with_npu():
-            self._calc_stream = True
+        if self.fuse_all_reduce_ops:
             self._allreduce_fusion_program()
         else:
             self._insert_allreduce_ops()
@@ -206,22 +206,30 @@
                         OP_ROLE_KEY: OpRole.Backward})
                 break
 
-    # TODO(Liu yuang): ADD CUDA allreduce_fusion fuction.
-    # This function helps reduce the input of allreduce by integrating can save communication time.
+    # This function helps reduce the number of allreduce ops by fusing them, which can save communication time.
+    # To use allreduce fusion, set the strategy as follows:
+    #     strategy = paddle.distributed.fleet.DistributedStrategy()
+    #     strategy.without_graph_optimization = True
+    #     strategy.fuse_all_reduce_ops = True
+    #     strategy._calc_comm_same_stream = False
+    #     strategy.fuse_grad_size_in_num = 8
     def _allreduce_fusion_program(self):
         block = self.main_program.global_block()
         ring_id = self.global_ring_id
         record_idx, allreduce_input_vars, allreduce_output_vars = [], [], []
-        block_ops = len(list(enumerate(block.ops)))
+        ops = list(enumerate(block.ops))
 
-        for idx, op in reversed(list(enumerate(block.ops))):
+        for idx, op in reversed(ops):
+            # we traverse the ops in reverse order
            if is_backward_op(op) and \
                    OP_ROLE_VAR_KEY in op.attr_names:
                op_role_var = op.attr(OP_ROLE_VAR_KEY)
                if len(op_role_var) == 0:
                    continue
-                assert len(op_role_var) % 2 == 0
+                assert len(op_role_var) % 2 == 0, "vars need to be one param var followed by one grad var, " \
+                    "but got an odd number of vars"
                for i in range(0, len(op_role_var), 2):
+                    # handle the vars in pairs: a param and its grad
                    param_name = op_role_var[i]
                    param = block.var(param_name)
                    grad_name = op_role_var[i + 1]
@@ -229,6 +237,7 @@
                    if param.is_distributed:
                        continue
                    if ".cast_fp16@GRAD" in grad_name:
+                        # when amp is enabled, use the fp16 version of the param
                        param_name = param_name + ".cast_fp16"
                        if not block.has_var(param_name):
                            raise ValueError("op cast name error {}".format(
                    else:
                        param = block.var(param_name)
 
-                    if len(allreduce_output_vars) == 0:
-                        allreduce_output_vars.append([grad])
-                        allreduce_input_vars.append([param])
-                        if self.fuse_grad_size_in_num == 1:
-                            record_idx.append([idx, idx])
-                            continue
-                        record_idx.append([-2, idx])
-                    elif len(allreduce_output_vars[
-                            -1]) == self.fuse_grad_size_in_num:
+                    if len(allreduce_output_vars) == 0 or \
+                            len(allreduce_output_vars[-1]) == \
+                            self.fuse_grad_size_in_num:
+                        # start a new group: first var, or the last group is full
                        allreduce_output_vars.append([grad])
                        allreduce_input_vars.append([param])
-                        if self.fuse_grad_size_in_num == 1:
-                            record_idx.append([idx, idx])
-                            continue
-                        if idx != block_ops - 1:
-                            record_idx.append([-2, idx])
+                        # record the start and end idx of the new group
+                        record_idx.append([idx, idx])
                    else:
+                        # the current group is below the configured size:
+                        # append the grad and param to the last group and
+                        # update its start idx to the current op's idx;
+                        # since we traverse the ops in reverse order, the idx
+                        # is descending, so we update record_idx[-1][0]
                        allreduce_output_vars[-1].append(grad)
                        allreduce_input_vars[-1].append(param)
                        record_idx[-1][0] = idx
 
-        if record_idx[-1][0] == -2:
-            record_idx[-1][0] = record_idx[-1][1]
-
        assert len(allreduce_output_vars) == len(
            record_idx
        ), "It has different lens between the allreduce_output_vars and record_idx."
 
        if not allreduce_output_vars or not allreduce_input_vars:
+            # nothing needs to be allreduced
            return
 
        self.vars = collections.OrderedDict()
-        index, offset_pos, pos, offset = 0, 0, 0, 0
+        index, pos, offset = 0, 0, 0
        start, end = record_idx[index]
-        men_list = [end, start]
-
-        # Here we need to explain the flag. When integrating OP, we will encounter different groups of the same Op.
-        # Because we insert coalesce tensor in reverse ops,
-        # we need to use flag to record whether the current OP has been inserted into coalesce tensor。
-        # For example:
-        # [(3, 2), (2, 2), (1, 0)], (3, 2), (2, 2) using same op, but in different groups.
- - for idx, op in reversed(list(enumerate(block.ops))): + for idx, op in reversed(ops): if idx == start: pos = 0 - flag = True if end == men_list[-1] else False - offset = offset_pos if flag else 0 done_output_vars, done_input_vars = self._split_fuction( - allreduce_output_vars[index], allreduce_input_vars[index]) + allreduce_output_vars[index], # grad + allreduce_input_vars[index] # param + ) for id_, done_output_var in enumerate(done_output_vars): - if flag: - tmp_var = block.create_var( - name=unique_name.generate( - 'FusedOutput_{}_{}'.format(start, id_ + - offset)), - dtype=done_output_var[0].dtype, - persistable=False, - stop_gradient=True) - self.vars['FusedOutput_{}_{}'.format(start, id_ + - offset)] = tmp_var + tmp_var = block.create_var( + name=unique_name.generate('FusedOutput_{}'.format( + done_output_var[0].name)), + dtype=done_output_var[0].dtype, + persistable=False, + stop_gradient=True) + self.vars['FusedOutput_{}'.format(done_output_var[0] + .name)] = tmp_var - block._insert_op( - idx + id_ + offset, - type="coalesce_tensor", - inputs={"Input": done_input_vars[id_]}, - outputs={ - "Output": done_output_var, - "FusedOutput": tmp_var - }, - attrs={ - "copy_data": False, - "use_align": True, - "dtype": done_output_var[0].dtype - }) - pos += 1 - else: - tmp_var = block.create_var( - name=unique_name.generate( - 'FusedOutput_{}_{}'.format(start, id_)), - dtype=done_output_var[0].dtype, - persistable=False, - stop_gradient=True) - self.vars['FusedOutput_{}_{}'.format(start, - id_)] = tmp_var + block._insert_op( + idx + id_, + type="coalesce_tensor", + inputs={"Input": done_input_vars[id_]}, + outputs={ + "Output": done_output_var, + "FusedOutput": tmp_var + }, + attrs={ + "copy_data": False, + "use_align": True, + "dtype": done_output_var[0].dtype, + OP_ROLE_KEY: OpRole.Backward + }) + pos += 1 - block._insert_op( - idx + id_, - type="coalesce_tensor", - inputs={"Input": done_input_vars[id_]}, - outputs={ - "Output": done_output_var, - "FusedOutput": tmp_var - }, - attrs={ - "copy_data": False, - "use_align": True, - "dtype": done_output_var[0].dtype - }) - pos += 1 - offset_pos = pos - - # TODO(Liu yuang): ADD CUDA and NPU's EVENT and c_allreduce_sum. 
                for id_ in range(len(done_output_vars)):
-                    if flag:
-                        block._insert_op(
-                            end + id_ + pos + 1,
-                            type='c_allreduce_sum',
-                            inputs={
-                                'X': self.vars['FusedOutput_{}_{}'.format(
-                                    start, id_ + offset)]
-                            },
-                            outputs={
-                                'Out': self.vars['FusedOutput_{}_{}'.format(
-                                    start, id_ + offset)]
-                            },
-                            attrs={
-                                'ring_id': ring_id,
-                                'use_calc_stream': True
-                                if self._calc_stream else False,
-                                OP_ROLE_KEY: OpRole.Backward
-                            })
-                    else:
+                    x = self.vars['FusedOutput_{}'.format(done_output_vars[id_][
+                        0].name)]
+                    out = x
+
+                    # NOTE: there is still room for optimization here if an EVENT is used instead of a stream sync
+                    if not self.calc_comm_same_stream:
+                        # need a sync if the calc and comm streams are not the same
                        block._insert_op(
                            end + id_ + pos + 1,
-                            type='c_allreduce_sum',
-                            inputs={
-                                'X': self.vars['FusedOutput_{}_{}'.format(start,
-                                    id_)]
-                            },
-                            outputs={
-                                'Out': self.vars['FusedOutput_{}_{}'.format(
-                                    start, id_)]
-                            },
-                            attrs={
-                                'ring_id': ring_id,
-                                'use_calc_stream': True
-                                if self._calc_stream else False,
-                                OP_ROLE_KEY: OpRole.Backward
-                            })
+                            type='c_sync_calc_stream',
+                            inputs={'X': x},
+                            outputs={'Out': out},
+                            attrs={OP_ROLE_KEY: OpRole.Backward})
+
+                    block._insert_op(
+                        end + id_ + pos + 1
+                        if self.calc_comm_same_stream else end + id_ + pos + 2,
+                        type='c_allreduce_sum',
+                        inputs={'X': x},
+                        outputs={'Out': out},
+                        attrs={
+                            'ring_id': ring_id,
+                            'use_calc_stream': self.calc_comm_same_stream,
+                            OP_ROLE_KEY: OpRole.Backward
+                        })
+
                index += 1
-                men_list.append(end)
-                men_list.append(start)
                if len(record_idx) == index:
-                    start = end = -1
-                    continue
+                    break
                start, end = record_idx[index]
 
-        if not self._calc_stream:
+        if not self.calc_comm_same_stream:
+            # need a sync if the calc and comm streams are not the same
            for idx, op in enumerate(block.ops):
                if is_optimizer_op(op):
                    block._insert_op(
@@ -397,34 +354,50 @@
                        })
                    break
 
-    # Integrate grads of the same type to form a combination. If skip_comb is selected, will return grads of the same group.
+    # Integrate grads of the same type to form a combination.
+    # If combination is enabled, grads of the same dtype are returned in the same group.
    # For example:[(fp16, fp16), (fp32), (fp16)] -> [(fp16, fp16, fp16), (fp32)]
    def _split_fuction(self, allreduce_output_vars, allreduce_input_vars,
-                       skip_comb=True):
+                       combination=True):
        input_vars, final_input_vars, output_vars, final_output_vars = [], [], [], []
-        if len(allreduce_output_vars) - 1 == 0:
+        if len(allreduce_output_vars) == 1:
+            # only one var to handle
            final_output_vars.append(allreduce_output_vars)
            final_input_vars.append(allreduce_input_vars)
            return final_output_vars, final_input_vars
 
        for idx in range(len(allreduce_input_vars) - 1):
+            # compare each var with the next one; the last var needs special handling
            if allreduce_input_vars[idx].dtype == allreduce_input_vars[idx + 1].dtype:
+                # the current var and the next var have the same dtype,
+                # so append the current var to input_vars
                input_vars.append(allreduce_input_vars[idx])
                if idx == len(allreduce_input_vars) - 2:
+                    # the current var is the second to last var,
+                    # so also append the last var to input_vars
+                    # and update final_input_vars
                    input_vars.append(allreduce_input_vars[idx + 1])
                    final_input_vars.append(input_vars)
            else:
+                # the current var and the next var have different dtypes:
+                # append the current var to input_vars,
+                # update final_input_vars,
+                # then reset input_vars to receive vars of a new dtype
                input_vars.append(allreduce_input_vars[idx])
                final_input_vars.append(input_vars)
                input_vars = []
                if idx == len(allreduce_input_vars) - 2:
+                    # the current var is the second to last var, so append the
+                    # last var to the freshly reset input_vars since the dtypes differ,
+                    # and update final_input_vars
                    input_vars.append(allreduce_input_vars[idx + 1])
                    final_input_vars.append(input_vars)
 
        for idx in range(len(allreduce_output_vars) - 1):
+            # the procedure for the output vars is the same as that for the input vars
            if allreduce_output_vars[idx].dtype == allreduce_output_vars[
                    idx + 1].dtype:
                output_vars.append(allreduce_output_vars[idx])
@@ -438,10 +411,14 @@
                if idx == len(allreduce_output_vars) - 2:
                    output_vars.append(allreduce_output_vars[idx + 1])
                    final_output_vars.append(output_vars)
-        if skip_comb:
+
+        # at this point, all vars within each group of final_input_vars and final_output_vars have the same dtype
+
+        if combination:
            input_fp16_vars, input_fp32_vars, output_fp16_vars, output_fp32_vars = [], [], [], []
            for final_input_var in final_input_vars:
                if final_input_var[0].dtype == core.VarDesc.VarType.FP16:
+                    # merge the group into the fp16 group
                    input_fp16_vars.extend(final_input_var)
                else:
                    input_fp32_vars.extend(final_input_var)
@@ -451,6 +428,7 @@
                    output_fp16_vars.extend(final_output_var)
                else:
                    output_fp32_vars.extend(final_output_var)
+
            final_output_vars, final_input_vars = [], []
            if output_fp16_vars:
                final_output_vars.append(output_fp16_vars)
diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt
index 03aaf7ed03e..023b092b774 100644
--- a/python/paddle/fluid/tests/unittests/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -718,6 +718,7 @@ if (WITH_DISTRIBUTE)
        set_tests_properties(test_dist_fleet_sparse_embedding_ctr PROPERTIES TIMEOUT 200)
        set_tests_properties(test_dist_fleet_infer PROPERTIES TIMEOUT 200)
        set_tests_properties(test_dist_fleet_raw_program_optimizer PROPERTIES TIMEOUT 120)
+        set_tests_properties(test_dist_fleet_raw_program_optimizer_fuse_allreduce PROPERTIES TIMEOUT 60)
 endif()
 
 if (WITH_DISTRIBUTE AND NOT APPLE)
diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_raw_program_optimizer_fuse_allreduce.py b/python/paddle/fluid/tests/unittests/dist_fleet_raw_program_optimizer_fuse_allreduce.py
new file mode 100644
index 00000000000..aaf33d04e6b
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/dist_fleet_raw_program_optimizer_fuse_allreduce.py
@@ -0,0 +1,112 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from test_dist_base import TestDistRunnerBase, runtime_main
+import unittest
+import paddle
+import os
+import paddle.distributed.fleet as fleet
+import paddle.distributed.fleet.base.role_maker as role_maker
+import numpy as np
+from functools import reduce
+import paddle.fluid as fluid
+
+paddle.enable_static()
+
+DTYPE = "float32"
+paddle.dataset.mnist.fetch()
+
+# Fix seed for test
+fluid.default_startup_program().random_seed = 1
+fluid.default_main_program().random_seed = 1
+
+
+def cnn_model(data):
+    conv_pool_1 = fluid.nets.simple_img_conv_pool(
+        input=data,
+        filter_size=5,
+        num_filters=20,
+        pool_size=2,
+        pool_stride=2,
+        act="relu",
+        param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
+            value=0.01)))
+    conv_pool_2 = fluid.nets.simple_img_conv_pool(
+        input=conv_pool_1,
+        filter_size=5,
+        num_filters=50,
+        pool_size=2,
+        pool_stride=2,
+        act="relu",
+        param_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(
+            value=0.01)))
+
+    SIZE = 10
+    input_shape = conv_pool_2.shape
+    param_shape = [reduce(lambda a, b: a * b, input_shape[1:], 1)] + [SIZE]
+    scale = (2.0 / (param_shape[0]**2 * SIZE))**0.5
+
+    predict = fluid.layers.fc(
+        input=conv_pool_2,
+        size=SIZE,
+        act="softmax",
+        param_attr=fluid.param_attr.ParamAttr(
+            initializer=fluid.initializer.Constant(value=0.01)))
+    return predict
+
+
+class TestFleetMetaOptimizerFuseAllReducePrecision(TestDistRunnerBase):
+    def get_model(self, batch_size=2, single_device=False):
+        # Input data
+        images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype=DTYPE)
+        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+
+        # Train program
+        predict = cnn_model(images)
+        cost = fluid.layers.cross_entropy(input=predict, label=label)
+        avg_cost = fluid.layers.mean(x=cost)
+
+        # Evaluator
+        batch_size_tensor = fluid.layers.create_tensor(dtype='int64')
+        batch_acc = fluid.layers.accuracy(
+            input=predict, label=label, total=batch_size_tensor)
+
+        test_program = fluid.default_main_program().clone(for_test=True)
+
+        # Reader
+        train_reader = paddle.batch(
+            paddle.dataset.mnist.test(), batch_size=batch_size)
+        test_reader = paddle.batch(
+            paddle.dataset.mnist.test(), batch_size=batch_size)
+
+        optimizer = paddle.fluid.optimizer.Adam(0.01)
+        if single_device:
+            optimizer.minimize(avg_cost)
+        else:
+            role = role_maker.PaddleCloudRoleMaker(is_collective=True)
+            fleet.init(role)
+            strategy = paddle.distributed.fleet.DistributedStrategy()
+            strategy.without_graph_optimization = True
+            strategy.fuse_all_reduce_ops = True
+            strategy._calc_comm_same_stream = False
+            strategy.fuse_grad_size_in_num = 8
+            optimizer = fleet.distributed_optimizer(
optimizer, strategy=strategy) + optimizer.minimize(avg_cost) + + return test_program, avg_cost, train_reader, test_reader, batch_acc, predict + + +if __name__ == "__main__": + runtime_main(TestFleetMetaOptimizerFuseAllReducePrecision) diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_raw_program_optimizer_fuse_allreduce.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_raw_program_optimizer_fuse_allreduce.py new file mode 100644 index 00000000000..21b921c52c8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_raw_program_optimizer_fuse_allreduce.py @@ -0,0 +1,45 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from test_dist_base import TestDistBase +import paddle +import os + +paddle.enable_static() +flag_name = os.path.splitext(__file__)[0] + + +class TestFleetMetaOptimizerAllReduceFusePrecision(TestDistBase): + def _setup_config(self): + self._sync_mode = True + self._use_reduce = False + self._use_reader_alloc = False + self._nccl2_mode = True + self._nccl2_reduce_layer = True + self._use_fleet_api = True + self._use_fleet_api_20 = True + + def test_dist_train(self): + import paddle.fluid as fluid + if fluid.core.is_compiled_with_cuda(): + self.check_with_place( + "dist_fleet_raw_program_optimizer_fuse_allreduce.py", + delta=1e-5, + check_error_log=True, + log_name=flag_name) + + +if __name__ == '__main__': + unittest.main() -- GitLab
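
Usage note (not part of the patch itself): the fused-allreduce path added here is driven entirely by DistributedStrategy flags. A minimal sketch of how a fleet training program would enable it, mirroring the settings exercised by the new unit test above; `avg_cost` is assumed to be a loss variable built beforehand:

    import paddle
    import paddle.distributed.fleet as fleet
    import paddle.distributed.fleet.base.role_maker as role_maker

    paddle.enable_static()
    role = role_maker.PaddleCloudRoleMaker(is_collective=True)
    fleet.init(role)

    strategy = paddle.distributed.fleet.DistributedStrategy()
    strategy.without_graph_optimization = True  # use RawProgramOptimizer
    strategy.fuse_all_reduce_ops = True         # fuse grads before allreduce
    strategy.fuse_grad_size_in_num = 8          # up to 8 grads per fused group
    # False inserts a c_sync_calc_stream op before each fused c_allreduce_sum;
    # True runs calc and comm on the same stream and skips the sync.
    strategy._calc_comm_same_stream = False

    optimizer = paddle.fluid.optimizer.Adam(0.01)
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    optimizer.minimize(avg_cost)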