diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 50d97652f38132cd6065dba1896f2d7f1d0ce622..d9031bfda2627f99dbdabc1c600dbaf20ff75e89 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -4645,143 +4645,128 @@ class PipelineOptimizer(object): if var_name not in input_var_to_device: input_var_to_device[var_name] = [] - if cur_device in input_var_to_device[var_name]: + if (cur_device, prev_device) in input_var_to_device[var_name]: continue - input_var_to_device[var_name].append(cur_device) - - op_role = op.all_attrs()[self._op_role_key] - var = block.vars[var_name] - prev_device_index = int(prev_device.split(':')[1]) - cur_device_index = int(cur_device.split(':')[1]) - pair = (prev_device_index, cur_device_index) - pair_key = prev_device_index * 1000 + cur_device_index - if cur_device_index > prev_device_index: - ring_id = self.ring_id + cur_device_index - prev_device_index - 1 - else: - ring_id = self.ring_id + 2 + prev_device_index - cur_device_index - 1 - print("call xx_insert, schedule_mode:", self.schedule_mode) - if self.schedule_mode == 0: # GPipe + + device_type = cur_device.split(':')[0] + ':' + + def _insert_send_recv(cur_id, prev_id): + nonlocal extra_index + + cur_dev = device_type + str(cur_id) + prev_dev = device_type + str(prev_id) + if (cur_dev, prev_dev) in input_var_to_device[var_name]: + return + + if cur_id - prev_id > 1: + _insert_send_recv(cur_id - 1, prev_id) + _insert_send_recv(cur_id, cur_id - 1) + input_var_to_device[var_name].append( + (cur_dev, prev_dev)) + return + elif cur_id - prev_id < -1: + _insert_send_recv(cur_id + 1, prev_id) + _insert_send_recv(cur_id, cur_id + 1) + input_var_to_device[var_name].append( + (cur_dev, prev_dev)) + return + + assert abs(cur_id - prev_id) == 1 + + input_var_to_device[var_name].append((cur_dev, prev_dev)) + + op_role = op.all_attrs()[self._op_role_key] + var = block.vars[var_name] + + pair = (prev_id, cur_id) + pair_key = prev_id * 1000 + cur_id + if cur_id > prev_id: + ring_id = self.ring_id + cur_id - prev_id - 1 + else: + ring_id = self.ring_id + 2 + prev_id - cur_id - 1 + + print("call xx_insert, schedule_mode:", self.schedule_mode) + assert self.schedule_mode == 1 + if pair not in self._pipeline_pair: + self._pipeline_pair.append(pair) + self._pp_ring_map[pair_key] = self.ring_id + ring_id = self.ring_id + self.ring_id += 1 + else: + ring_id = self._pp_ring_map[pair_key] + + print("opt: pp_pair: {}, ring_id: {}".format(pair, ring_id)) + + block._insert_op_without_sync( + index=index + extra_index, + type='c_sync_calc_stream', + inputs={'X': [var]}, + outputs={'Out': [var]}, + attrs={ + self._op_device_key: prev_dev, + self._op_role_key: op_role, + }) + extra_index += 1 + block._insert_op_without_sync( index=index + extra_index, - type='send_v2', + type="c_broadcast", inputs={'X': var}, + outputs={'Out': var}, attrs={ - self._op_device_key: prev_device, + self._op_device_key: prev_dev, self._op_role_key: op_role, - 'use_calc_stream': True, - 'peer': cur_device_index, - 'ring_id': self.ring_id - if cur_device_index > prev_device_index else - self.ring_id + 2, + 'use_calc_stream': False, + 'ring_id': ring_id, + 'root': 0, }) extra_index += 1 + + fill_shape = list(var.shape) + fill_shape[0] = 4 block._insert_op_without_sync( index=index + extra_index, - type='recv_v2', + type='fill_constant', + inputs={}, outputs={'Out': [var]}, attrs={ - 'out_shape': var.shape, + 'shape': fill_shape, 'dtype': var.dtype, - self._op_device_key: cur_device, + self._op_device_key: cur_dev, + self._op_role_key: op_role, + 'value': float(0.0), + }) + extra_index += 1 + block._insert_op_without_sync( + index=index + extra_index, + type='c_broadcast', + inputs={'X': var}, + outputs={'Out': var}, + attrs={ + self._op_device_key: cur_dev, self._op_role_key: op_role, 'use_calc_stream': True, - 'peer': prev_device_index, - 'ring_id': self.ring_id - if cur_device_index > prev_device_index else - self.ring_id + 2, + 'root': 0, + 'ring_id': ring_id, }) extra_index += 1 - continue - assert self.schedule_mode == 1 - if pair not in self._pipeline_pair: - self._pipeline_pair.append(pair) - self._pp_ring_map[pair_key] = self.ring_id - ring_id = self.ring_id - self.ring_id += 1 - else: - ring_id = self._pp_ring_map[pair_key] - print("opt: pp_pair: {}, ring_id: {}".format(pair, ring_id)) - block._insert_op_without_sync( - index=index + extra_index, - #type='send_v2', - type='c_broadcast', - inputs={'X': var}, - outputs={'Out': var}, - attrs={ - self._op_device_key: prev_device, - self._op_role_key: op_role, - 'use_calc_stream': False, - #'peer': cur_device_index, - #'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2, - 'ring_id': ring_id, - #'ring_id': self.ring_id, - #'root': prev_device_index, - 'root': 0, - }) - extra_index += 1 - #block._insert_op( - # index=index + extra_index, - # type='c_sync_comm_stream', - # inputs={'X': [var]}, - # outputs={'Out': [var]}, - # attrs={ - # self._op_device_key: cur_device, - # self._op_role_key: - # core.op_proto_and_checker_maker.OpRole.Backward, - # 'ring_id': self.ring_id, - # #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2, - # }) - #extra_index += 1 - fill_shape = list(var.shape) - fill_shape[0] = 1 - block._insert_op_without_sync( - index=index + extra_index, - #type='recv_v2', - type='fill_constant', - inputs={}, - outputs={'Out': [var]}, - attrs={ - 'shape': fill_shape, - 'dtype': var.dtype, - self._op_device_key: cur_device, - self._op_role_key: op_role, - 'value': float(0.0), - }) - extra_index += 1 - block._insert_op_without_sync( - index=index + extra_index, - #type='recv_v2', - type='c_broadcast', - inputs={'X': var}, - outputs={'Out': var}, - attrs={ - #'out_shape': var.shape, - #'dtype': var.dtype, - self._op_device_key: cur_device, - self._op_role_key: op_role, - 'use_calc_stream': False, - #'peer': prev_device_index, - #'root': prev_device_index, - 'root': 0, - #'ring_id': self.ring_id, - 'ring_id': ring_id, - #'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2, - #'ring_id': self.ring_id if prev_device_index < cur_device_index else self.ring_id + 2, - }) - extra_index += 1 - block._insert_op_without_sync( - index=index + extra_index, - type='c_sync_comm_stream', - inputs={'X': [var]}, - outputs={'Out': [var]}, - attrs={ - self._op_device_key: cur_device, - #self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward, - self._op_role_key: op_role, - 'ring_id': ring_id, - #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2, - }) - extra_index += 1 + + # block._insert_op_without_sync( + # index=index + extra_index, + # type='c_sync_comm_stream', + # inputs={'X': [var]}, + # outputs={'Out': [var]}, + # attrs={ + # self._op_device_key: cur_dev, + # self._op_role_key: op_role, + # 'ring_id': ring_id, + # }) + # extra_index += 1 + + _insert_send_recv( + int(cur_device.split(':')[1]), + int(prev_device.split(':')[1])) + block._sync_with_cpp() def _clear_gradients(self, main_block, param_names):