diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index a674b13bde061c253ea0281f4af0c801cf593206..ffb54861904f4a7c6d54ba958343f9cb00cb559b 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -4046,7 +4046,7 @@ class PipelineOptimizer(object):
         """
         prev_op = []
         for op in ops:
-            if op.type == 'send_v2' or op.type == 'recv_v2':
+            if op.type == 'send_v2' or op.type == 'recv_v2' or op.type == 'c_broadcast':
                 continue
             if op == cur_op:
                 break
@@ -4434,6 +4434,39 @@ class PipelineOptimizer(object):
                     ring_id = self.ring_id + 2 + prev_device_index - cur_device_index - 1
                 if pair not in self._pipeline_pair:
                     self._pipeline_pair.append(pair)
+                if self.schedule_mode == 0:  # GPipe
+                    block._insert_op(
+                        index=index + extra_index,
+                        type='send_v2',
+                        inputs={'X': var},
+                        attrs={
+                            self._op_device_key: prev_device,
+                            self._op_role_key: op_role,
+                            'use_calc_stream': True,
+                            'peer': cur_device_index,
+                            'ring_id': self.ring_id
+                            if cur_device_index > prev_device_index else
+                            self.ring_id + 2,
+                        })
+                    extra_index += 1
+                    block._insert_op(
+                        index=index + extra_index,
+                        type='recv_v2',
+                        outputs={'Out': [var]},
+                        attrs={
+                            'out_shape': var.shape,
+                            'dtype': var.dtype,
+                            self._op_device_key: cur_device,
+                            self._op_role_key: op_role,
+                            'use_calc_stream': True,
+                            'peer': prev_device_index,
+                            'ring_id': self.ring_id
+                            if cur_device_index > prev_device_index else
+                            self.ring_id + 2,
+                        })
+                    extra_index += 1
+                    continue
+                assert self.schedule_mode == 1
                 block._insert_op(
                     index=index + extra_index,
                     #type='send_v2',
@@ -4452,19 +4485,19 @@ class PipelineOptimizer(object):
                         'root': 0,
                     })
                 extra_index += 1
-                block._insert_op(
-                    index=index + extra_index,
-                    type='c_sync_comm_stream',
-                    inputs={'X': [var]},
-                    outputs={'Out': [var]},
-                    attrs={
-                        self._op_device_key: cur_device,
-                        self._op_role_key:
-                        core.op_proto_and_checker_maker.OpRole.Backward,
-                        'ring_id': self.ring_id,
-                        #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
-                    })
-                extra_index += 1
+                #block._insert_op(
+                #    index=index + extra_index,
+                #    type='c_sync_comm_stream',
+                #    inputs={'X': [var]},
+                #    outputs={'Out': [var]},
+                #    attrs={
+                #        self._op_device_key: cur_device,
+                #        self._op_role_key:
+                #        core.op_proto_and_checker_maker.OpRole.Backward,
+                #        'ring_id': self.ring_id,
+                #        #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
+                #    })
+                #extra_index += 1
                 fill_shape = list(var.shape)
                 fill_shape[0] = 1
                 block._insert_op(
@@ -4509,6 +4542,7 @@ class PipelineOptimizer(object):
                     outputs={'Out': [var]},
                     attrs={
                         self._op_device_key: cur_device,
+                        #self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward,
                         self._op_role_key: op_role,
                         'ring_id': self.ring_id,
                         #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2,
                     })
@@ -4987,6 +5021,10 @@ class PipelineOptimizer(object):
             and 'local_rank' in main_block.program._pipeline_opt, \
             'Please use pipeline with fleet.'
         local_rank = main_block.program._pipeline_opt['local_rank']
+        schedule_mode = 0
+        if 'schedule_mode' in main_block.program._pipeline_opt:
+            schedule_mode = main_block.program._pipeline_opt['schedule_mode']
+        self.schedule_mode = schedule_mode
         self.use_sharding = False
         if 'use_sharding' in main_block.program._pipeline_opt:
@@ -5074,6 +5112,7 @@ class PipelineOptimizer(object):
             "inner_parallelism": len(device_list),
             "num_pipeline_stages": len(device_list),
             "pipeline_stage": local_rank,
+            "schedule_mode": schedule_mode,
             "section_program": program_list[local_rank],
             "place": place_list[local_rank],
             "place_id": place_id,
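Not part of the patch: a minimal, assumed sketch of how the new schedule_mode setting reaches PipelineOptimizer through program._pipeline_opt. The dict literal and variable names below are illustrative only; in practice the dict is populated by fleet and also carries 'local_rank' and 'use_sharding', as read in the hunk above.

# Hypothetical illustration of the dispatch added by this patch.
pipeline_opt = {'local_rank': 0, 'schedule_mode': 0, 'use_sharding': False}  # assumed contents

# Mirrors the default in the patch: 0 selects the new GPipe-style path,
# anything else must be 1 (the existing broadcast-based schedule).
schedule_mode = 0
if 'schedule_mode' in pipeline_opt:
    schedule_mode = pipeline_opt['schedule_mode']

if schedule_mode == 0:
    pass  # GPipe: insert plain send_v2/recv_v2 pairs per device pair
else:
    assert schedule_mode == 1  # existing c_broadcast-based path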