Unverified commit 30845734, authored by zhaoyingli, committed by GitHub

[AutoParallel] Recompute Pass (#38920)

* [AutoParallel] Recompute Pass

* update unittest

* reshard for amp

* add comment
Parent 4aa91fd6
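For context, a minimal sketch of how this pass is switched on at the user level through fleet's DistributedStrategy; it mirrors the unit test added at the bottom of this commit, and the checkpoint names ("tmp3", "tmp6") are placeholders that must match tensors in the user's own program.

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()

# Sketch only: mirrors TestRecomputePass.apply_passes() added in this commit.
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True                # enable auto-parallel
dist_strategy.recompute = True                # enable the recompute pass
# "checkpoints" lists activation names to keep; everything between two
# checkpoints is recomputed during the backward pass.
dist_strategy.recompute_configs = {"checkpoints": ["tmp3", "tmp6"]}
fleet.init(is_collective=True, strategy=dist_strategy)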
...@@ -822,6 +822,28 @@ def complete_update_annotation(auto_parallel_main_prog, dist_context):
# TODO to add attribute for moment var
op = ops[idx]
if int(op.attr('op_role')) == int(OpRole.Optimize):
if op.type == "clip_by_norm":
param_grad = vars[op.input("X")[0]]
param_grad_dist_attr = dist_context.get_tensor_dist_attr_for_program(
param_grad)
assert param_grad_dist_attr is not None
ref_process_mesh = param_grad_dist_attr.process_mesh
ref_dims_mapping = param_grad_dist_attr.dims_mapping
out = vars[op.output("Out")[0]]
out_dist_attr = TensorDistributedAttribute()
out_dist_attr.process_mesh = ref_process_mesh
out_dist_attr.dims_mapping = ref_dims_mapping
dist_context.set_tensor_dist_attr_for_program(out,
out_dist_attr)
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = ref_process_mesh
op_dist_attr.set_input_dist_attr(param_grad.name,
param_grad_dist_attr)
op_dist_attr.set_output_dist_attr(out.name, out_dist_attr)
dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
if "Grad" in op.input_names and "Param" in ops[idx].input_names: if "Grad" in op.input_names and "Param" in ops[idx].input_names:
assert len(op.input( assert len(op.input(
......
...@@ -21,7 +21,9 @@ _g_tensor_dist_attr_field_keys = [
"process_mesh", "dims_mapping", "shard_sizes", "device_placement"
]
-_g_op_dist_attr_field_keys = ["process_mesh", "impl_type", "impl_idx"]
+_g_op_dist_attr_field_keys = [
+    "process_mesh", "impl_type", "impl_idx", "is_recompute"
+]
_g_op_input_suffix = "@input"
...@@ -178,6 +180,7 @@ class OperatorDistributedAttribute:
self._inputs_dist_attrs = {}
self._outputs_dist_attrs = {}
self._is_annotated = {}
self._is_recompute = False
@property
def process_mesh(self):
...@@ -214,6 +217,15 @@ class OperatorDistributedAttribute:
if impl_idx is not None:
self._impl_idx = impl_idx
@property
def is_recompute(self):
return self._is_recompute
@is_recompute.setter
def is_recompute(self, is_recompute):
assert isinstance(is_recompute, bool)
self._is_recompute = is_recompute
@property
def inputs_dist_attrs(self):
return self._inputs_dist_attrs
...
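A tiny usage sketch of the new flag (illustrative only; the import path is the one used elsewhere in this commit):

from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute

op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.is_recompute = True      # the setter asserts isinstance(..., bool)
assert op_dist_attr.is_recompute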
...@@ -166,6 +166,13 @@ class DistributedContext:
else:
return None
def get_tensor_dist_attr_for_program_with_id(self, tensor_id):
dist_tensor = self._dist_tensors_for_program.get(tensor_id, None)
if dist_tensor:
return dist_tensor.dist_attr
else:
return None
def set_tensor_dist_attr_for_program(self, serial_tensor, dist_attr):
dist_tensor = DistributedTensor(serial_tensor, dist_attr)
self.add_dist_tensor_for_program(dist_tensor)
...@@ -192,6 +199,13 @@ class DistributedContext:
else:
return None
def get_op_dist_attr_for_program_with_id(self, op_id):
dist_op = self._dist_ops_for_program.get(op_id, None)
if dist_op:
return dist_op.dist_attr
else:
return None
def set_op_dist_attr_for_program(self, serial_op, dist_attr):
dist_op = DistributedOperator(serial_op, dist_attr)
self.add_dist_op_for_program(dist_op)
...
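These id-based getters let the recompute pass recover the dist attr of a forward op after its desc has been copied into the backward region (the copy keeps only the original op id, not the Operator object). A hypothetical helper sketching that lookup, with names taken from the pass code further down in this diff:

def lookup_copied_op_dist_attr(dist_context, copied_op_desc):
    # Hypothetical helper: copied_op_desc had its source op id recorded via
    # set_dist_op_desc_original_id(); returns None if no dist attr was stored.
    return dist_context.get_op_dist_attr_for_program_with_id(
        copied_op_desc.original_id())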
...@@ -99,6 +99,8 @@ class DistributedOperator:
self._dist_attr.impl_type = "default"
if self._dist_attr.impl_idx is None:
self._dist_attr.impl_idx = -2
if self._dist_attr.is_recompute is None:
self._dist_attr.is_recompute = False
def _filter_dist_attr(self, dist_attr):
if dist_attr is None:
...
...@@ -118,6 +118,8 @@ def find_best_compatible_distributed_operator_impl(name, dist_op, fwd=True):
def is_parameter_related(varname, block):
if ".subprog_" in varname:
varname = varname[:varname.index(".subprog_")]
if ".cast_fp" in varname: if ".cast_fp" in varname:
varname = varname[:varname.index(".cast_fp")] varname = varname[:varname.index(".cast_fp")]
assert block.has_var(varname) assert block.has_var(varname)
......
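Recomputed variables carry a ".subprog_<i>" suffix (and AMP-casted parameters a ".cast_fp..." suffix), so parameter lookups strip those suffixes first. A standalone sketch of that normalization, using a made-up helper name:

def strip_auto_parallel_suffixes(varname):
    # Made-up helper mirroring the checks in is_parameter_related() above.
    if ".subprog_" in varname:
        varname = varname[:varname.index(".subprog_")]
    if ".cast_fp" in varname:
        varname = varname[:varname.index(".cast_fp")]
    return varname

assert strip_auto_parallel_suffixes("linear_0.w_0.subprog_1") == "linear_0.w_0"
assert strip_auto_parallel_suffixes("linear_0.w_0.cast_fp16") == "linear_0.w_0"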
...@@ -216,6 +216,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for varname in backward_op.desc.input(input_name):
if "@GRAD" not in varname and is_parameter_related(
varname, main_block):
# NOTE: When the amp and recompute passes are enabled at the same time,
# if a parameter is casted and recomputed, 'parameter@GRAD' cannot
# be found in the grad_op's output.
if "subprog_" in varname:
varname = varname[:varname.index(".subprog_")]
assert len(
backward_op.desc.input(input_name)
) == 1, "parameter input to grad op should be length 1, but got [{}]".format(
...
...@@ -283,7 +283,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
allreduce_op_dist_attr)
# param initialization sync
-if Weight_var.is_parameter:
+if Weight_var.is_parameter and not op_dist_attr.is_recompute:
assert Weight_var.name not in dist_op_context.already_init_sync_vars
dist_op_context.already_init_sync_vars.add(Weight_var.name)
param = startup_block.var(Weight_var.name)
...
...@@ -680,7 +680,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
ctx.set_op_dist_attr_for_program(matmul_op, matmul_op_dist_attr)
# init param sync
-if Weight_var.is_parameter:
+if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)
...@@ -968,7 +968,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
allreduce_op_dist_attr)
# init param sync
-if Weight_var.is_parameter:
+if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)
...@@ -1383,7 +1383,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
ctx.set_op_dist_attr_for_program(matmul_v2_op, matmulv2_op_dist_attr)
# init param sync
-if Weight_var.is_parameter:
+if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)
...@@ -1666,7 +1666,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
allreduce_op_dist_attr)
# init param sync
-if Weight_var.is_parameter:
+if Weight_var.is_parameter and not op_dist_attr.is_recompute:
_init_param_sync(Weight_var, dist_op_context, startup_block, ctx,
rank_id)
...
...@@ -83,9 +83,9 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
assert 'Out' in kwargs, "output [{}] is not given".format('Out')
assert 'LossScaling' in kwargs, "output [{}] is not given".format(
'LossScaling')
-assert 'OutGoodSteps' in kwargs, "input [{}] is not given".format(
+assert 'OutGoodSteps' in kwargs, "output [{}] is not given".format(
'OutGoodSteps')
-assert 'OutBadSteps' in kwargs, "input [{}] is not given".format(
+assert 'OutBadSteps' in kwargs, "output [{}] is not given".format(
'OutBadSteps')
assert len(kwargs['FoundInfinite']) == 1, \
...
...@@ -97,8 +97,8 @@ class AutoParallelizer:
if suffix in attr_name:
op._remove_attr(attr_name)
-def _apply_pre_optimization_passed(self, main_program, startup_program,
-                                   loss, params_grads):
+def _apply_pre_optimization_passes(self, main_program, startup_program,
+                                   loss, params_grads, no_grad_set):
# apply amp pass
if self._dist_strategy.amp:
config = copy.deepcopy(self._dist_strategy.amp_configs)
...@@ -111,11 +111,14 @@ class AutoParallelizer:
# apply recompute pass
if self._dist_strategy.recompute:
-auto_parallel_recompute_pass = new_pass(
-    "auto_parallel_recompute_pass",
-    self._dist_strategy.recompute_configs)
-auto_parallel_recompute_pass.apply(main_program, startup_program,
-                                   self._pass_context)
+config = copy.deepcopy(self._dist_strategy.recompute_configs)
+config["dist_context"] = self._dist_context
+config["no_grad_set"] = copy.deepcopy(no_grad_set)
+config["loss"] = loss
+auto_parallel_recompute_pass = new_pass("auto_parallel_recompute",
+                                        config)
+auto_parallel_recompute_pass.apply(
+    [main_program], [startup_program], self._pass_context)
def _generate_backward(self, main_program, startup_program, loss,
parameter_list, no_grad_set, callbacks):
...@@ -144,7 +147,7 @@ class AutoParallelizer:
return optimize_ops
-def _apply_post_optimization_passed(self, main_program, startup_program,
+def _apply_post_optimization_passes(self, main_program, startup_program,
rank, params_grads):
if self._dist_strategy.sharding:
...@@ -188,9 +191,9 @@ class AutoParallelizer:
self._parameter_list, self._no_grad_set, self._callbacks)
# serial forward pass
-self._apply_pre_optimization_passed(completed_main_program,
+self._apply_pre_optimization_passes(completed_main_program,
serial_startup_program, serial_loss,
-params_grads)
+params_grads, self._no_grad_set)
# Logical partition
partitioner = Partitioner(self._dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
...@@ -207,7 +210,7 @@ class AutoParallelizer:
reshard(dist_main_prog, dist_startup_prog, rank, self._dist_context)
-self._apply_post_optimization_passed(dist_main_prog, dist_startup_prog,
+self._apply_post_optimization_passes(dist_main_prog, dist_startup_prog,
rank, dist_params_grads)
g_process_group_map = None
if not relaunch_phase:
...
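Outside of AutoParallelizer, the same pass can be driven directly through the pass framework. A hedged sketch of that call path (main_program, startup_program, loss, no_grad_set, dist_context and pass_context are assumed to already exist, e.g. as in the code above; the checkpoint names are placeholders):

import copy
from paddle.distributed.passes import new_pass

config = {"checkpoints": ["tmp3", "tmp6"]}   # placeholder checkpoint names
config["dist_context"] = dist_context
config["no_grad_set"] = copy.deepcopy(no_grad_set)
config["loss"] = loss
recompute_pass = new_pass("auto_parallel_recompute", config)
recompute_pass.apply([main_program], [startup_program], pass_context)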
...@@ -25,7 +25,7 @@ from paddle.distributed.auto_parallel.dist_context import DistributedContext, Di
from .dist_attribute import OperatorDistributedAttribute
from .process_group import new_process_group
from .utils import set_dist_op_desc_original_id
-from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_recompute_op
+from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op
from .operators.common import BACKWARD_ONLY_DIST_OPS
__varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]
...@@ -200,7 +200,8 @@ class Partitioner(object):
serial_output_varname] = new_varname
# partition op
-if is_forward_op(op):
+op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
+if is_forward_op(op) or op_dist_attr.is_recompute:
kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_forward_impl = _get_dist_op_forward_implement(
op, self._dist_context)
...@@ -380,9 +381,9 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
# NOTE trick for dist ops that only have backward implement
if backward_op.type in BACKWARD_ONLY_DIST_OPS:
op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
-assert op_dist_attr.impl_idx >= 0
-return get_distributed_operator_impl_container(
-    backward_op.type).get_impl(op_dist_attr.impl_idx)
+dist_op = get_distributed_operator_impl_container(backward_op.type)
+if dist_op and op_dist_attr.impl_idx >= 0:
+    return dist_op.get_impl(op_dist_attr.impl_idx)
dist_op = get_distributed_operator_impl_container("default")
return dist_op.get_impl(0)
...
...@@ -26,6 +26,9 @@ from .dist_context import DistributedContext
from .dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from .process_group import new_process_group, ProcessGroup, _g_process_group_map
# NOTE: If op in _g_special_ops, it will not be resharded.
_g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling']
class AllGatherOpDesc:
"""
...@@ -966,6 +969,17 @@ def reshard(auto_parallel_main_prog, auto_parallel_startup_prog, rank_id,
while idx < len(block.ops):
pre_op_count = len(block.ops)
op = block.ops[idx]
def _is_special_op(op):
global _g_special_ops
if op.type in _g_special_ops:
return True
return False
if _is_special_op(op):
idx += 1
continue
dist_op = dist_context.get_dist_op_for_program(op)
if dist_op is not None:
idx_offset = 0
...
...@@ -1005,8 +1005,8 @@ def set_grad_var_shape(program, dist_context):
assert op_dist_attr is not None
for var_name in op.output_arg_names:
-assert "@GRAD" in var_name
+if "@GRAD" not in var_name:
+    continue
forward_var_name = var_name[:var_name.find("@GRAD")]
if op.type in [
"c_allreduce_sum", "c_identity", "scale", "cast"
...@@ -1076,11 +1076,6 @@ def is_backward_op(op):
int(op.all_attrs()[OP_ROLE_KEY]) & int(OpRole.Backward)
-def is_recompute_op(op):
-    return OP_ROLE_KEY in op.attr_names and \
-        int(op.all_attrs()[OP_ROLE_KEY]) == 9
def is_loss_op(op):
return OP_ROLE_KEY in op.attr_names and \
int(op.all_attrs()[OP_ROLE_KEY]) == (int(core.op_proto_and_checker_maker.OpRole.Forward) | int(core.op_proto_and_checker_maker.OpRole.Loss))
...
...@@ -17,6 +17,7 @@ from .fuse_all_reduce import *
from .auto_parallel_gradient_merge import *
from .auto_parallel_sharding import *
from .auto_parallel_amp import *
from .auto_parallel_recompute import *
from .cpp_pass import *
__all__ = [
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
from .pass_base import PassBase, register_pass
from paddle.fluid import core, unique_name
from paddle.fluid import framework as framework
from paddle.fluid.framework import Variable, Operator
from paddle.fluid.backward import _append_grad_suffix_, _get_no_grad_set_name
from paddle.fluid.backward import ProgramStats, _rename_arg_, _find_op_path_
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute
from paddle.distributed.auto_parallel.utils import get_loss_op, set_var_dist_attr, set_dist_op_desc_original_id
from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping
class RecomputeState(ProgramStats):
def __init__(self, block, ops):
super(RecomputeState, self).__init__(block=block, ops=ops)
self._block = block
self._ops = ops
self.var_op_deps = {}
def build_stats(self):
for i, op in enumerate(self._ops):
for name in op.desc.input_arg_names():
if name in self.var_op_deps:
self.var_op_deps[name]["var_as_input_ops"].extend([i])
else:
self.var_op_deps[name] = {}
self.var_op_deps[name]["var_as_input_ops"] = [i]
self.var_op_deps[name]["var_as_output_ops"] = []
for name in op.desc.output_arg_names():
if name in self.var_op_deps:
self.var_op_deps[name]["var_as_output_ops"].extend([i])
else:
self.var_op_deps[name] = {}
self.var_op_deps[name]["var_as_input_ops"] = []
self.var_op_deps[name]["var_as_output_ops"] = [i]
def get_recompute_segments(self, checkpoints):
""" get recompute segments from checkpoints """
segments = []
start_idx = -1
pre_segment_end_idx = -1
while start_idx + 1 < len(checkpoints):
if start_idx == -1:
ckpt_name = checkpoints[start_idx + 1]
if ckpt_name not in self.var_op_deps:
start_idx += 1
continue
op_idx_list = self.var_op_deps[ckpt_name]["var_as_output_ops"]
if op_idx_list:
segments.append([0, max(op_idx_list) + 1])
else:
flag, min_idx, max_idx = self.is_subgraph(
[checkpoints[start_idx]], [checkpoints[start_idx + 1]])
if flag:
min_idx = self._update_segment_start(min_idx,
pre_segment_end_idx)
segments.append([min_idx, max_idx + 1])
else:
logging.info("Could not recompute op range [{}] - [{}] ".
format(min_idx, max_idx + 1))
start_idx += 1
for i, (idx1, idx2) in enumerate(segments):
logging.info("recompute segment[{}]".format(i))
logging.info("segment start op: [{}]: [{}] [{}]".format(self._ops[
idx1].desc.type(), self._ops[idx1].desc.input_arg_names(
), self._ops[idx1].desc.output_arg_names()))
logging.info("segment end op: [{}]: [{}] [{}]".format(self._ops[
idx2 - 1].desc.type(), self._ops[idx2 - 1].desc.input_arg_names(
), self._ops[idx2 - 1].desc.output_arg_names()))
return segments
def modify_forward_desc_for_recompute(self, dist_context):
"""
If the forward part of the program contains a 'dropout' op, this function inserts
a seed op before it to guarantee that the original dropout and its recomputed copy produce the same outputs.
"""
op_types = [op.desc.type() for op in self._ops]
if "dropout" not in op_types:
return
op_idx = 0
while op_idx < len(self._ops):
cur_op = self._ops[op_idx]
if "grad" in cur_op.type:
break
if cur_op.type != "dropout":
op_idx += 1
continue
if cur_op.input("Seed") is not None and len(cur_op.input("Seed")):
op_idx += 1
continue
cur_op_dist_attr = dist_context.get_op_dist_attr_for_program(cur_op)
# insert seed op to guarantee that two dropout op have the same outputs
op_unique_name = unique_name.generate("seed")
var_unique_name = unique_name.generate_with_ignorable_key(".".join(
[op_unique_name, 'tmp']))
seed_var = self._block.create_var(
name=var_unique_name,
dtype='int32',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False)
# set new seed_var's dist_attr
ref_dims_mapping = [-1]
ref_process_mesh = cur_op_dist_attr.process_mesh
seed_var_dist_attr = set_var_dist_attr(
dist_context, seed_var, ref_dims_mapping, ref_process_mesh)
seed = 0 if cur_op.attr("fix_seed") is False else int(
cur_op.attr("seed"))
seed_op = self._block._insert_op_without_sync(
index=cur_op.idx,
type="seed",
inputs={},
outputs={"Out": seed_var},
attrs={"seed": seed,
"force_cpu": True})
# set new seed op's dist_attr
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
seed_op, ref_process_mesh, ref_dims_mapping, dist_context)
# modify dropout op's desc
self._ops.insert(op_idx, seed_op)
cur_op.desc.set_input("Seed", [var_unique_name])
cur_op.desc.remove_attr("fix_seed")
cur_op.desc.remove_attr("seed")
cur_op_dist_attr.set_input_dist_attr(seed_var.name,
seed_var_dist_attr)
self._block._sync_with_cpp()
op_idx += 2
def _find_op_index(block, cur_op):
for idx in range(block.desc.op_size()):
if cur_op.desc == block.desc.op(idx):
return idx
return -1
def _get_stop_gradients(program, no_grad_set):
""" get no grad var """
if no_grad_set is None:
no_grad_set = set()
else:
no_grad_set = _get_no_grad_set_name(no_grad_set)
no_grad_set_name = set()
for var in program.list_vars():
assert isinstance(var, Variable)
if "@GRAD" in var.name:
break
if var.stop_gradient:
no_grad_set_name.add(_append_grad_suffix_(var.name))
no_grad_set_name.update(list(map(_append_grad_suffix_, no_grad_set)))
return no_grad_set_name
def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars,
dist_context):
"""
Get the descs of the recomputed forward ops that will be inserted into the backward part
"""
if len(descs) == 0:
return []
result_descs = []
op_role_attr_name = core.op_proto_and_checker_maker.kOpRoleAttrName()
backward = core.op_proto_and_checker_maker.OpRole.Backward
for desc in descs:
if isinstance(desc, framework.Operator):
desc = desc.desc
if isinstance(desc, tuple):
desc = desc[0]
is_needed = False
for name in desc.output_arg_names():
if main_block.has_var(name) and main_block.var(name).persistable:
continue
if name not in in_memory_vars:
is_needed = True
if is_needed:
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(desc)
set_dist_op_desc_original_id(new_op_desc, desc, dist_context)
new_op_desc._set_attr(op_role_attr_name, backward)
result_descs.append(new_op_desc)
return result_descs
@register_pass("auto_parallel_recompute")
class RecomputePass(PassBase):
def __init__(self):
super(RecomputePass, self).__init__()
self.set_attr("checkpoints", None)
self.set_attr("loss", None)
self.set_attr("dist_context", None)
self.set_attr("no_grad_set", None)
def _check_self(self):
if self.get_attr("dist_context") is None:
return False
if self.get_attr("loss") is None:
return False
if self.get_attr("checkpoints") is None:
return False
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_programs, startup_programs, context):
checkpoints = self.get_attr("checkpoints")
loss = self.get_attr("loss")
no_grad_set = self.get_attr("no_grad_set")
self._dist_context = self.get_attr("dist_context")
main_block = main_programs.global_block()
no_grad_set_name = _get_stop_gradients(main_programs, no_grad_set)
# get op_path which is related to loss
op_path = _find_op_path_(main_block, [loss], [], no_grad_set_name)
# step 1: build recompute state
rc_state = RecomputeState(main_block, op_path)
rc_state.modify_forward_desc_for_recompute(self._dist_context)
rc_state.build_stats()
checkpoints = rc_state.sort_checkpoints(checkpoints)
segments = rc_state.get_recompute_segments(checkpoints)
if segments == []:
return
# step 2: get vars_should_be_hold
vars_should_be_hold = []
for segment in segments:
vars_should_be_hold.extend(
rc_state.get_out_of_subgraph_vars(segment[0], segment[1]))
cross_vars = set(vars_should_be_hold) - set(checkpoints)
logging.info("found [{}] vars which cross recompute segment: [{}],"
"better checkpoints might be set to reduce those vars".
format(len(cross_vars), cross_vars))
vars_should_be_hold.extend(rc_state.get_reserved_vars())
vars_should_be_hold.extend(rc_state.get_input_nodes())
vars_should_be_hold = list(set(vars_should_be_hold))
vars_in_memory = vars_should_be_hold + checkpoints
# step 3: get recomputed fwd ops desc
var_name_dict = {}
ckpt_ops_dict = {}
buffer_block = main_block.program._create_block()
for i, segment in enumerate(segments[::-1]):
fwd_ops = op_path[segment[0]:segment[1]]
var_suffix = ".subprog_%d" % i
for op in fwd_ops:
input_and_output_names = []
input_and_output_names.extend(op.desc.input_arg_names())
input_and_output_names.extend(op.desc.output_arg_names())
cur_op_dist_attr = self._dist_context.get_op_dist_attr_for_program(
op)
assert cur_op_dist_attr is not None
for name in input_and_output_names:
if main_block.var(name).persistable or name in checkpoints:
continue
if name in vars_should_be_hold:
continue
if name not in var_name_dict:
ref_process_mesh = cur_op_dist_attr.process_mesh
if name in op.desc.input_arg_names():
ref_dims_mapping = cur_op_dist_attr.get_input_dims_mapping(
name)
else:
ref_dims_mapping = cur_op_dist_attr.get_output_dims_mapping(
name)
# record recomputed var's old_name and new_name (old_name.subprog_XXX)
# create new var with new name
var_name_dict[name] = name + var_suffix
ref_var = main_block.var(name)
rc_var = main_block.create_var(
name=var_name_dict[name],
shape=ref_var.shape,
dtype=ref_var.dtype,
type=ref_var.type,
persistable=ref_var.persistable,
stop_gradient=ref_var.stop_gradient)
# set new recomputed var's dist attr
set_var_dist_attr(self._dist_context, rc_var,
ref_dims_mapping, ref_process_mesh)
# get recomputed segment's descs
segment_descs = _add_needed_descs_to_block(
fwd_ops, buffer_block, main_block, vars_in_memory,
self._dist_context)
# rename recomputed ops' input and output var name
for key in var_name_dict:
_rename_arg_(segment_descs, key, var_name_dict[key])
# NOTE: one forward op may correspond to multiple xxx_grad ops.
# When traversing all grad_ops in reverse, a flag is needed to indicate
# whether the ckpt and its segment_descs can still be used.
ckpt_op = op_path[segment[1] - 1]
ckpt_ops_dict[ckpt_op.desc.id()] = [True, segment_descs]
# step 4: insert recomputed fwd ops
ops = main_block.ops
loss_op = get_loss_op(main_block)
loss_op_idx = _find_op_index(main_block, loss_op)
dist_op_context = self._dist_context.dist_op_context
assert loss_op_idx != -1
# Traverse all grad_ops in reverse; if the fwd op corresponding to a grad op
# is a checkpoint, the segment's recomputed ops should be inserted before it.
for i in range(len(ops) - 1, loss_op_idx, -1):
grad_op = ops[i]
# remove some attrs of dropout_grad op's desc
if grad_op.type == "dropout_grad":
grad_op.desc.remove_attr("fix_seed")
grad_op.desc.remove_attr("seed")
main_block._sync_with_cpp()
# rename grad op's var_name which is not in 'vars_in_memory'
for key in var_name_dict:
self.reset_op_dist_attr(grad_op, var_name_dict)
_rename_arg_([grad_op.desc], key, var_name_dict[key])
# insert recomputed ops
if grad_op.desc.id() in dist_op_context.grad_op_id_to_op_id:
fwd_op_id = dist_op_context.grad_op_id_to_op_id[grad_op.desc.id(
)]
if fwd_op_id in ckpt_ops_dict and ckpt_ops_dict[fwd_op_id][0]:
idx = grad_op.idx
while idx - 1 >= 0 and ops[idx - 1].type == "sum":
idx -= 1
segment_descs = ckpt_ops_dict[fwd_op_id][1]
for _, op_desc in reversed(list(enumerate(segment_descs))):
rc_desc = main_block.desc._insert_op(idx)
rc_desc.copy_from(op_desc)
rc_op = Operator(main_block, rc_desc)
main_block.ops.insert(idx, rc_op)
# set recomputed ops' dist attr
fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id(
rc_desc.original_id())
assert fwd_op_dist_attr is not None
self.set_op_dist_attr(rc_op, fwd_op_dist_attr,
var_name_dict)
ckpt_ops_dict[fwd_op_id][0] = False
main_block._sync_with_cpp()
main_programs._sync_with_cpp()
def reset_op_dist_attr(self, op, var_name_dict):
op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
assert op_dist_attr is not None
for input in op.desc.input_arg_names():
if input in var_name_dict.keys():
in_dist_attr = op_dist_attr.get_input_dist_attr(input)
op_dist_attr.set_input_dist_attr(var_name_dict[input],
in_dist_attr)
for output in op.desc.output_arg_names():
if output in var_name_dict.keys():
out_dist_attr = op_dist_attr.get_output_dist_attr(output)
op_dist_attr.set_output_dist_attr(var_name_dict[output],
out_dist_attr)
def set_op_dist_attr(self, op, old_dist_attr, var_name_dict):
new_dist_attr = OperatorDistributedAttribute()
new_dist_attr.is_recompute = True
new_dist_attr.impl_idx = old_dist_attr.impl_idx
new_dist_attr.process_mesh = old_dist_attr.process_mesh
for input in old_dist_attr.inputs_dist_attrs.keys():
if input in var_name_dict.keys():
in_dist_attr = old_dist_attr.inputs_dist_attrs[input]
new_dist_attr.set_input_dist_attr(var_name_dict[input],
in_dist_attr)
else:
in_dist_attr = old_dist_attr.inputs_dist_attrs[input]
new_dist_attr.set_input_dist_attr(input, in_dist_attr)
for output in old_dist_attr.outputs_dist_attrs.keys():
if output in var_name_dict.keys():
out_dist_attr = old_dist_attr.outputs_dist_attrs[output]
new_dist_attr.set_output_dist_attr(var_name_dict[output],
out_dist_attr)
else:
out_dist_attr = old_dist_attr.outputs_dist_attrs[output]
new_dist_attr.set_output_dist_attr(output, out_dist_attr)
self._dist_context.set_op_dist_attr_for_program(op, new_dist_attr)
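To make steps 3 and 4 above concrete: the pass builds a var_name_dict that maps every recomputed activation to its ".subprog_<i>" copy and then rewrites the grad ops' arguments with _rename_arg_. The toy snippet below only mimics that mapping on plain strings; the variable names are invented for illustration.

# Illustration only; the real pass renames OpDesc arguments via _rename_arg_.
var_name_dict = {
    "dropout_0.tmp_0": "dropout_0.tmp_0.subprog_0",
    "matmul_1.tmp_0": "matmul_1.tmp_0.subprog_0",
}

def rename_args(arg_names, name_map):
    # Checkpointed / held vars are absent from name_map and stay untouched;
    # everything else is redirected to its recomputed copy.
    return [name_map.get(name, name) for name in arg_names]

print(rename_args(["matmul_1.tmp_0", "linear_0.w_0"], var_name_dict))
# -> ['matmul_1.tmp_0.subprog_0', 'linear_0.w_0']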
...@@ -894,7 +894,6 @@ class GPTModel(nn.Layer):
"dims_mapping":
[0] + [-1 for i in range(len(input_ids.shape) - 1)]
})
-attention_mask.stop_gradient = True
encoder_outputs = self.decoder(
embedding_output,
memory=None,
...
...@@ -110,14 +110,8 @@ class AutoPallelPassTestBase(DistPassTestBase):
elif strategy == "mp":
modeling._global_parallel_strategy = "mp"
modeling._global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
-elif strategy == "pp":
-    modeling._global_parallel_strategy = "pp"
-    modeling._global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
-    modeling.PP_MESH_LIST = [
-        auto.ProcessMesh(mesh=[0]), auto.ProcessMesh(mesh=[1])
-    ]
else:
-raise ValueError("'get_gpt_model' only support dp, mp and pp.")
+raise ValueError("'get_gpt_model' only support dp and mp.")
tokens = paddle.static.data(
name="tokens", shape=[batch_size, sequence_len], dtype='int64')
...
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import random
import numpy as np
import unittest
import paddle
import paddle.nn as nn
import paddle.distributed.fleet as fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.passes import new_pass, PassManager
from auto_parallel_pass_test_base import AutoPallelPassTestBase
class TestRecomputePass(AutoPallelPassTestBase):
def init(self):
if paddle.is_compiled_with_cuda():
paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
self.rtol = 1e-6
self.atol = 1e-8
rank = paddle.distributed.get_rank()
paddle.seed(rank + 2021)
random.seed(rank + 2021)
np.random.seed(rank + 2021)
def apply_passes(self):
dist_strategy = fleet.DistributedStrategy()
dist_strategy.recompute = True
dist_strategy.recompute_configs = {"checkpoints": ["tmp3", "tmp6"]}
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
def test_bs_8(self):
self.check_main(
gpus=[0, 1], batch_size=8, sequence_len=512, vocab_size=1000)
def get_model(self, place, batch_size, sequence_len, vocab_size):
return self.get_gpt_model("mp", place, batch_size, sequence_len,
vocab_size)
class TestRecomputePassDP(TestRecomputePass):
def get_model(self, place, batch_size, sequence_len, vocab_size):
return self.get_gpt_model("dp", place, batch_size, sequence_len,
vocab_size)
if __name__ == "__main__":
unittest.main()