未验证 提交 5f65ff91 编写于 作者: W WangXi 提交者: GitHub

[hybrid performance] Optimize pipeline send wait (#34086)

上级 9cda0596
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
// Operator class for `nop`. It inherits the constructors of
// framework::OperatorWithKernel and performs no shape inference.
class NopOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Intentionally empty: this operator does not infer or set any shapes.
  void InferShape(framework::InferShapeContext* ctx) const override {}

 protected:
  // Kernel selection is fixed to FP32 on whatever place the execution
  // context reports; the data type of the inputs is never inspected.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto& place = ctx.GetPlace();
    return framework::OpKernelType(framework::proto::VarType::FP32, place);
  }
};
// Declares the proto (inputs, outputs and description) of the `nop` operator.
class NopOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Registers one duplicable input "X", one duplicable output "Out" and the
  // operator comment. Marked `override` so the compiler verifies that this
  // really overrides the base-class hook, consistent with the other
  // overriding methods in this file.
  void Make() override {
    AddInput("X", "(Tensor) The input tensor of nop op.").AsDuplicable();
    AddOutput("Out", "(Tensor) The output tensor of nop op.").AsDuplicable();
    AddComment(R"DOC(
Nop Operator
Do nothing, except let the input and output tensors occupy the memory and
establish the dependency between input and output tensors.
)DOC");
  }
};
// Kernel for the `nop` operator, parameterized on the element type T.
template <typename T>
class NopKernel : public framework::OpKernel<T> {
 public:
  // Deliberately does nothing: the execution context is ignored and no
  // tensor is read or written.
  void Compute(const framework::ExecutionContext& /*ctx*/) const override {}
};
} // namespace operators
} // namespace paddle
// Short alias used by the registration macros below.
namespace ops = paddle::operators;
// Register the `nop` operator with its op class and proto maker; as the macro
// name indicates, no gradient operator is registered for it.
REGISTER_OP_WITHOUT_GRADIENT(nop, ops::NopOp, ops::NopOpMaker);
// Register the float instantiation of the (empty) kernel for the CPU, CUDA
// and NPU backends respectively.
REGISTER_OP_CPU_KERNEL(nop, ops::NopKernel<float>);
REGISTER_OP_CUDA_KERNEL(nop, ops::NopKernel<float>);
REGISTER_OP_NPU_KERNEL(nop, ops::NopKernel<float>);
......@@ -4221,6 +4221,8 @@ class PipelineOptimizer(object):
self._param_device_map = None
self._pipeline_pair = []
self._pp_ring_map = dict()
self.output_var_to_op = None
self.input_var_to_op = None
# insert allreduce op to sync global information for global
# gradient clip and amp
......@@ -4657,6 +4659,9 @@ class PipelineOptimizer(object):
int(self._op_role.Optimize),
int(self._op_role.Backward) | int(self._op_role.Loss),
]
pre_stage_id = None
decrease_flag = False
in_optimize = False
for op in block.ops:
if not op._has_kernel(op.type):
assert op.type == "conditional_block" and (
......@@ -4666,11 +4671,15 @@ class PipelineOptimizer(object):
assert op.has_attr(self._op_role_key), (
"op ({}) has no {} attribute.".format(op.type,
self._op_role_key))
assert int(op.attr(self._op_role_key)) in valid_op_role_value, \
op_role = op.attr(self._op_role_key)
assert int(op_role) in valid_op_role_value, \
"op_role {} for op {} must be one of {}".format(
op.attr(self._op_role_key),
op_role,
op.type,
valid_op_role_value)
if int(op_role) == int(self._op_role.Optimize):
in_optimize = True
assert op.has_attr(self._op_device_key), (
"op ({}) has no {} attribute.".format(op.type,
self._op_device_key))
......@@ -4678,13 +4687,33 @@ class PipelineOptimizer(object):
device = op.attr(self._op_device_key)
assert device, ("op_device attribute for op "
"{} has not been set.".format(op.type))
if device == "gpu:all": continue
if device == "gpu:all" or device == "npu:all": continue
dev_type = device.split(':')[0]
stage_id = int(device.split(':')[1])
assert dev_type == "gpu" or dev_type == 'npu', (
"Now only gpu and npu devices are supported "
"for pipeline parallelism.")
if not device in device_list:
if device not in device_list:
device_list.append(device)
if not in_optimize:
if pre_stage_id is not None:
interval = stage_id - pre_stage_id
assert abs(interval) <= 1, \
"The stage interval of two consecutive ops in the pipeline must be < = 1," \
"but the interval of op={} and prev op is {}".format(op, interval)
# stage must be in order, such as Forward(0 1 2 3 4), Backward(4 3 2 1 0)
# if stage is unordered, such as Forward(0 1 2 3 4 3 4), will report error
if interval == -1:
decrease_flag = True
if interval == 1:
assert decrease_flag is False, \
"Pipeline stage must be in order, " \
"please check the stage of op={}".format(op)
pre_stage_id = stage_id
return device_list
def _insert_sendrecv_ops_for_boundaries(self, block):
......@@ -4826,6 +4855,7 @@ class PipelineOptimizer(object):
})
extra_index_info['index'] += 1
insert_index = None
if int(op_role) == int(self._op_role.Backward):
insert_index = extra_index_info[
'first_optimize_index']
......@@ -4833,7 +4863,8 @@ class PipelineOptimizer(object):
else:
insert_index = index
new_op_role = self._op_role.Backward
block._insert_op_without_sync(
sync_comm_op = block._insert_op_without_sync(
index=insert_index + extra_index_info['index'],
type='c_sync_comm_stream',
inputs={'X': [var]},
......@@ -4843,8 +4874,11 @@ class PipelineOptimizer(object):
self._op_role_key: new_op_role,
'ring_id': ring_id,
})
if int(op_role) == int(self._op_role.Forward):
sync_comm_op._set_attr('pipeline_flag', '')
extra_index_info['index'] += 1
var_shape = list(var.shape)
var_shape[0] = self.micro_batch_size if var_shape[
0] < 0 else var_shape[0]
......@@ -5153,17 +5187,55 @@ class PipelineOptimizer(object):
Get info of op input and output.
'''
# A map from output var to op which generate it.
self.output_var_to_op = dict()
output_var_to_op = defaultdict(list)
# A map from var to op which takes it as input.
self.input_var_to_op = dict()
input_var_to_op = defaultdict(list)
for index, op in enumerate(list(block.ops)):
for index, op in enumerate(block.ops):
for var_name in op.input_arg_names:
ops = self.input_var_to_op.setdefault(var_name, [])
ops.append([op, index])
input_var_to_op[var_name].append([op, index])
for var_name in op.output_arg_names:
ops = self.output_var_to_op.setdefault(var_name, [])
ops.append([op, index])
output_var_to_op[var_name].append([op, index])
return output_var_to_op, input_var_to_op
def _optimize_forward_send_sync(self, program):
    """
    Reschedule the c_sync_comm_stream ops that wait on forward sends.

    Only active when ``self.schedule_mode`` is '1F1B'. Every
    'c_sync_comm_stream' op that carries the 'pipeline_flag' attribute and
    sits before the first backward 'recv_v2' op in block 0 of ``program``
    is removed, and a 'nop' op on the same variable is inserted at the
    original position of that backward recv instead.

    Args:
        program: the section program whose block 0 is rewritten in place.

    Returns:
        None. The block is synced with its C++ counterpart at the end,
        unless the method returned early because the schedule mode is not
        '1F1B' or no backward recv_v2 op exists.
    """
    if self.schedule_mode != '1F1B':
        return

    block = program.block(0)

    # Locate the first recv_v2 op that belongs to the backward pass.
    backward_recv_index = next(
        (pos for pos, candidate in enumerate(block.ops)
         if candidate.type == 'recv_v2' and self._is_backward_op(candidate)),
        None)
    if backward_recv_index is None:
        return

    # Work on a snapshot of the ops in front of the backward recv, because
    # the block itself is mutated while we walk over them.
    ops_before_recv = list(block.ops)[:backward_recv_index]
    removed = 0
    for pos, candidate in enumerate(ops_before_recv):
        is_flagged_sync = (candidate.type == 'c_sync_comm_stream' and
                           candidate.has_attr('pipeline_flag'))
        if not is_flagged_sync:
            continue

        var = block.var(candidate.input_arg_names[0])
        # Every earlier removal shifted the remaining leading ops one slot
        # to the left, hence the correction by `removed`.
        block._remove_op(pos - removed, sync=False)
        removed += 1
        # NOTE:
        # 1. Once the backward recv has completed, the forward send has
        #    completed as well, so a NOP op is enough to keep the sent
        #    variable's memory from being released early.
        # 2. Removing the sync op moved the recv op one slot forward, so
        #    inserting at the recv's original index places the NOP after it.
        block._insert_op_without_sync(
            index=backward_recv_index,
            type='nop',
            inputs={'X': [var]},
            outputs={'Out': [var]},
            attrs={self._op_role_key: self._op_role.Backward})

    block._sync_with_cpp()
def minimize(self,
loss,
......@@ -5200,7 +5272,8 @@ class PipelineOptimizer(object):
loss, startup_program, parameter_list, no_grad_set)
self._param_device_map = self._origin_optimizer._param_device_map
self._get_input_output_info(main_block)
self.output_var_to_op, self.input_var_to_op = \
self._get_input_output_info(main_block)
# Step1: add default op_device attribute for ops.
self._add_op_device_attr(main_block)
device_list = self._check_validation(main_block)
......@@ -5229,6 +5302,10 @@ class PipelineOptimizer(object):
for p in program_list:
self._create_vars(p.global_block(), main_block)
self.local_rank %= len(device_list)
# Step3.5: optimize forward send sync_comm to overlap send and recv
self._optimize_forward_send_sync(program_list[self.local_rank])
# Step4: Special Case: process persistable vars that exist in
# multiple sections
# FIXME
......@@ -5238,7 +5315,6 @@ class PipelineOptimizer(object):
# Step5: Add sub blocks for section programs
self._add_sub_blocks(main_block, program_list)
self.local_rank %= len(device_list)
place_list = []
for dev in device_list:
dev_index = int(dev.split(":")[1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册