diff --git a/paddle/fluid/operators/nop_op.cc b/paddle/fluid/operators/nop_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..876468f8a7eacaf931e4a76ca0f78f18a4279207 --- /dev/null +++ b/paddle/fluid/operators/nop_op.cc @@ -0,0 +1,66 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include + +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class NopOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override {} + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(framework::proto::VarType::FP32, + ctx.GetPlace()); + } +}; + +class NopOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddInput("X", "(Tensor) The input tensor of nop op.").AsDuplicable(); + AddOutput("Out", "(Tensor) The output tensor of nop op.").AsDuplicable(); + AddComment(R"DOC( +Nop Operator + +Do nothing, except let the input and output tensors occupy the memory and +establish the dependency between input and output tensors. +)DOC"); + } +}; + +template +class NopKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override {} +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_WITHOUT_GRADIENT(nop, ops::NopOp, ops::NopOpMaker); + +REGISTER_OP_CPU_KERNEL(nop, ops::NopKernel); + +REGISTER_OP_CUDA_KERNEL(nop, ops::NopKernel); + +REGISTER_OP_NPU_KERNEL(nop, ops::NopKernel); diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 537e320a46134e8de5def10960b6caca35e2b692..2a777d2ab8148453e973b78d33b6417461d448d6 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -4221,6 +4221,8 @@ class PipelineOptimizer(object): self._param_device_map = None self._pipeline_pair = [] self._pp_ring_map = dict() + self.output_var_to_op = None + self.input_var_to_op = None # insert allreduce op to sync global information for global # gradient clip and amp @@ -4657,6 +4659,9 @@ class PipelineOptimizer(object): int(self._op_role.Optimize), int(self._op_role.Backward) | int(self._op_role.Loss), ] + pre_stage_id = None + decrease_flag = False + in_optimize = False for op in block.ops: if not op._has_kernel(op.type): assert op.type == "conditional_block" and ( @@ -4666,11 +4671,15 @@ class PipelineOptimizer(object): assert op.has_attr(self._op_role_key), ( "op ({}) has no {} attribute.".format(op.type, self._op_role_key)) - assert int(op.attr(self._op_role_key)) in valid_op_role_value, \ + op_role = op.attr(self._op_role_key) + assert int(op_role) in valid_op_role_value, \ "op_role {} for op {} must be one of {}".format( - op.attr(self._op_role_key), + op_role, op.type, valid_op_role_value) + if int(op_role) == int(self._op_role.Optimize): + in_optimize = True + assert op.has_attr(self._op_device_key), ( "op ({}) has no {} attribute.".format(op.type, self._op_device_key)) @@ -4678,13 +4687,33 @@ class PipelineOptimizer(object): device = op.attr(self._op_device_key) assert device, ("op_device attribute for op " "{} has not been set.".format(op.type)) - if device == "gpu:all": continue + if device == "gpu:all" or device == "npu:all": continue + dev_type = device.split(':')[0] + stage_id = int(device.split(':')[1]) assert dev_type == "gpu" or dev_type == 'npu', ( "Now only gpu and npu devices are supported " "for pipeline parallelism.") - if not device in device_list: + + if device not in device_list: device_list.append(device) + + if not in_optimize: + if pre_stage_id is not None: + interval = stage_id - pre_stage_id + assert abs(interval) <= 1, \ + "The stage interval of two consecutive ops in the pipeline must be < = 1," \ + "but the interval of op={} and prev op is {}".format(op, interval) + # stage must be in order, such as Forward(0 1 2 3 4), Backward(4 3 2 1 0) + # if stage is unordered, such as Forward(0 1 2 3 4 3 4), will report error + if interval == -1: + decrease_flag = True + if interval == 1: + assert decrease_flag is False, \ + "Pipeline stage must be in order, " \ + "please check the stage of op={}".format(op) + pre_stage_id = stage_id + return device_list def _insert_sendrecv_ops_for_boundaries(self, block): @@ -4826,6 +4855,7 @@ class PipelineOptimizer(object): }) extra_index_info['index'] += 1 insert_index = None + if int(op_role) == int(self._op_role.Backward): insert_index = extra_index_info[ 'first_optimize_index'] @@ -4833,7 +4863,8 @@ class PipelineOptimizer(object): else: insert_index = index new_op_role = self._op_role.Backward - block._insert_op_without_sync( + + sync_comm_op = block._insert_op_without_sync( index=insert_index + extra_index_info['index'], type='c_sync_comm_stream', inputs={'X': [var]}, @@ -4843,8 +4874,11 @@ class PipelineOptimizer(object): self._op_role_key: new_op_role, 'ring_id': ring_id, }) + if int(op_role) == int(self._op_role.Forward): + sync_comm_op._set_attr('pipeline_flag', '') extra_index_info['index'] += 1 + var_shape = list(var.shape) var_shape[0] = self.micro_batch_size if var_shape[ 0] < 0 else var_shape[0] @@ -5153,17 +5187,55 @@ class PipelineOptimizer(object): Get info of op input and output. ''' # A map from output var to op which generate it. - self.output_var_to_op = dict() + output_var_to_op = defaultdict(list) # A map from var to op which takes it as input. - self.input_var_to_op = dict() + input_var_to_op = defaultdict(list) - for index, op in enumerate(list(block.ops)): + for index, op in enumerate(block.ops): for var_name in op.input_arg_names: - ops = self.input_var_to_op.setdefault(var_name, []) - ops.append([op, index]) + input_var_to_op[var_name].append([op, index]) for var_name in op.output_arg_names: - ops = self.output_var_to_op.setdefault(var_name, []) - ops.append([op, index]) + output_var_to_op[var_name].append([op, index]) + + return output_var_to_op, input_var_to_op + + def _optimize_forward_send_sync(self, program): + """ + optimize forward send's sync_comm_stream schedule + """ + if self.schedule_mode != '1F1B': return + + block = program.block(0) + + backward_recv_index = None + for index, op in enumerate(block.ops): + if op.type == 'recv_v2' and self._is_backward_op(op): + backward_recv_index = index + break + + if backward_recv_index is None: return + + offset = 0 + for index, op in enumerate(list(block.ops)): + if index >= backward_recv_index: break + if op.type == 'c_sync_comm_stream' and op.has_attr('pipeline_flag'): + var_name = op.input_arg_names[0] + var = block.var(var_name) + block._remove_op(index + offset, sync=False) + offset -= 1 + # NOTE: + # 1. When the backward recv is completed, it indicates + # that the forward send is completed too. So we only need + # to use the NOP op to prevent memory release. + # 2. Because we removed sync_comm_op, + # we will insert NOP after recv_op. + block._insert_op_without_sync( + index=backward_recv_index, + type='nop', + inputs={'X': [var]}, + outputs={'Out': [var]}, + attrs={self._op_role_key: self._op_role.Backward}) + block._sync_with_cpp() def minimize(self, loss, @@ -5200,7 +5272,8 @@ class PipelineOptimizer(object): loss, startup_program, parameter_list, no_grad_set) self._param_device_map = self._origin_optimizer._param_device_map - self._get_input_output_info(main_block) + self.output_var_to_op, self.input_var_to_op = \ + self._get_input_output_info(main_block) # Step1: add default op_device attribute for ops. self._add_op_device_attr(main_block) device_list = self._check_validation(main_block) @@ -5229,6 +5302,10 @@ class PipelineOptimizer(object): for p in program_list: self._create_vars(p.global_block(), main_block) + self.local_rank %= len(device_list) + # Step3.5: optimize forward send sync_comm to overlap send and recv + self._optimize_forward_send_sync(program_list[self.local_rank]) + # Step4: Special Case: process persistable vars that exist in # multiple sections # FIXME @@ -5238,7 +5315,6 @@ class PipelineOptimizer(object): # Step5: Add sub blocks for section programs self._add_sub_blocks(main_block, program_list) - self.local_rank %= len(device_list) place_list = [] for dev in device_list: dev_index = int(dev.split(":")[1])