diff --git a/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py
index bd1faf140140226cc05e1ee969d34f0b87602259..8919ded2e245c951016f28851b2ab4b16ce2b48f 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/raw_program_optimizer.py
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 
+import os
+
 from paddle import static
 from paddle.fluid import core
 from paddle.framework import _global_flags
@@ -62,6 +64,9 @@ class RawProgramOptimizer(MetaOptimizerBase):
         self.calc_comm_same_stream = (
             user_defined_strategy._calc_comm_same_stream
         )
+        self.sync_before_allreduce = os.environ.get(
+            'FLAGS_sync_before_allreduce', None
+        )
 
     def _can_apply(self):
         if not self.role_maker._is_collective:
@@ -433,17 +438,28 @@ class RawProgramOptimizer(MetaOptimizerBase):
                     OP_ROLE_KEY: OpRole.Backward,
                 },
             )
+            if not self.calc_comm_same_stream and self.sync_before_allreduce:
+                block._insert_op_without_sync(
+                    after_idx + 1,
+                    type='c_sync_calc_stream',
+                    inputs={'X': fused_var},
+                    outputs={'Out': fused_var},
+                    attrs={OP_ROLE_KEY: OpRole.Backward},
+                )
         idx = 0
-        if not self.calc_comm_same_stream:
+        if not self.calc_comm_same_stream and not self.sync_before_allreduce:
             for i in range(len(grad_param_segments)):
-                while block.ops[idx].type != 'c_allreduce_sum':
+                while (
+                    block.ops[idx].type != 'c_allreduce_sum'
+                    or fused_vars[i].name not in block.ops[idx].input_arg_names
+                ):
                     idx += 1
                 grad_segment, param_segment = grad_param_segments[i]
                 for grad in grad_segment:
                     block._insert_op_without_sync(
                         idx + 1,
                         type='depend',
-                        inputs={'X': grad, 'Dep': fused_var},
+                        inputs={'X': grad, 'Dep': fused_vars[i]},
                         outputs={'Out': grad},
                     )
                     idx += 1
@@ -486,6 +502,21 @@ class RawProgramOptimizer(MetaOptimizerBase):
                 },
             )
 
+        if self.calc_comm_same_stream or not self.sync_before_allreduce:
+            block._sync_with_cpp()
+            return
+
+        # insert the sync comm op
+        for idx, op in enumerate(block.ops):
+            if is_optimizer_op(op):
+                block._insert_op_without_sync(
+                    idx,
+                    type='c_sync_comm_stream',
+                    inputs={'X': fused_vars},
+                    outputs={'Out': fused_vars},
+                    attrs={'ring_id': ring_id, OP_ROLE_KEY: OpRole.Backward},
+                )
+                break
         block._sync_with_cpp()
 
     def __get_ouputs_name_to_idx(self, first_backward_idx, block):
diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_raw_program_optimizer.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_raw_program_optimizer.py
index ba826bcf4ff17dca987d3f902f34e37c31a10532..c19791a3c33a8ad53a499a9690c53a1e3ad7b543 100644
--- a/python/paddle/fluid/tests/unittests/test_dist_fleet_raw_program_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_raw_program_optimizer.py
@@ -45,5 +45,22 @@ class TestFleetMetaOptimizerPrecision(TestDistBase):
         )
 
 
+class TestFleetMetaOptimizerPrecisionWithSync(TestFleetMetaOptimizerPrecision):
+    def need_envs(self):
+        return {'FLAGS_sync_before_allreduce': '1'}
+
+    def test_dist_train(self):
+        from paddle import fluid
+
+        if fluid.core.is_compiled_with_cuda():
+            self.check_with_place(
+                "dist_fleet_raw_program_optimizer.py",
+                delta=1e-5,
+                check_error_log=True,
+                log_name=flag_name + 'with_sync',
+                need_envs=self.need_envs(),
+            )
+
+
 if __name__ == '__main__':
     unittest.main()
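
Usage note (a minimal sketch, not part of the patch): RawProgramOptimizer reads
FLAGS_sync_before_allreduce from the environment with os.environ.get(), so the
variable has to be set before minimize() builds the program, which is how the
new unit test passes it through need_envs. The surrounding fleet setup below is
the standard collective recipe (without_graph_optimization=True selects the raw
program optimizer) and is assumed here rather than introduced by this change.

    import os
    os.environ['FLAGS_sync_before_allreduce'] = '1'  # enable the extra sync ops

    import paddle
    from paddle.distributed import fleet

    paddle.enable_static()
    fleet.init(is_collective=True)

    strategy = fleet.DistributedStrategy()
    strategy.without_graph_optimization = True  # use RawProgramOptimizer

    # build the network and loss, then wrap a base optimizer, e.g.:
    # optimizer = paddle.optimizer.SGD(learning_rate=0.01)
    # optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
    # optimizer.minimize(loss)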