Unverified  Commit 3869a3b4 authored by JZ-LIANG, committed by GitHub

Cherry Pick Random Ctrl (#52778)

Parent 6959eae5
...@@ -19,5 +19,6 @@ from .interface import shard_tensor
from .interface import shard_op
from .interface import recompute
from .interface import fetch
from .random import parallel_manual_seed
__all__ = []
...@@ -51,7 +51,7 @@ from .dist_loader import (
from .process_group import new_process_group, get_all_process_groups
from .dist_context import DistributedContext, get_default_distributed_context
from .strategy import Strategy
from .interface import CollectionNames, get_collection, fetch
from .utils import to_list, get_dist_attr, get_lr, validate_opt
from .utils import initialize_pg_in_full_mode, get_input_split_info
from .cost.estimate_cost import get_cost_from_engine
...@@ -438,6 +438,8 @@ class Engine:
), "user_fetches must be a list, but receive {}".format(
type(user_fetches).__name__
)
else:
user_fetches = []
fetch_names = []
fetch_indices = []
...@@ -464,10 +466,13 @@ class Engine:
_process_fetch_group("metrics_" + str(i), var_list)
if mode == "predict":
_process_fetch_group("outputs", fetch_vars["outputs"])
for usr_fetch in user_fetches:
var_name = _to_name_str(usr_fetch)
fetch(var_name)
user_fetches_collection = [
item[1] for item in get_collection(CollectionNames.FETCHES)
]
var_list = user_fetches_collection or []
_process_fetch_group("fetches", var_list)
return fetch_names, fetch_indices
...@@ -522,10 +527,10 @@ class Engine:
# logging user fetches
collect_fetches = get_collection(CollectionNames.FETCHES)
logs_fetch = {}
for name, var_name in collect_fetches:
if var_name in fetch_names:
idx = fetch_names.index(var_name)
logs_fetch[name or var_name] = outs[idx]
logs["fetches"] = logs_fetch
return logs
......
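With this change, values registered through fetch() come back in the run logs keyed by their registered name (falling back to the variable name). A hedged usage sketch, mirroring how the unit test further below reads fetches back; the engine, data, and variable name are placeholders:

# `dropout_0.tmp_1` is an illustrative variable name from the compiled program
outs = engine.run(data, fetch_list=["dropout_0.tmp_1"], mode="train")
mask = outs["fetches"]["dropout_0.tmp_1"]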
...@@ -258,6 +258,16 @@ def add_to_collection(collection_name, value, name=None):
def fetch(tensor, name=None, logging=False):
if isinstance(tensor, paddle.fluid.framework.Variable):
tensor = tensor.name
elif isinstance(tensor, str):
tensor = tensor
else:
raise TypeError(
"Only support fetch `Variable` or `str`[`Variable`'s name], but got `{}`".format(
type(tensor)
)
)
add_to_collection(CollectionNames.FETCHES, tensor, name)
if logging:
add_to_collection(CollectionNames.LOGGING, tensor, name)
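A minimal sketch of calling the helper directly, assuming the interface module path as of this commit; the variable name is hypothetical:

from paddle.distributed.auto_parallel.interface import fetch

# register an intermediate variable so the engine fetches (and optionally logs) it each step
fetch("attn_dropout_mask", name="attn_mask", logging=True)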
...@@ -36,3 +36,4 @@ from . import dist_reduce_sum_p
from . import dist_shape
from . import dist_assign
from . import dist_scale
from . import dist_dropout
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import logging
import paddle
from paddle.framework import core
from paddle.utils import unique_name
from ...utils.log_utils import get_logger
_logger = get_logger(logging.INFO)
from ..random import determinate_rng, is_enable_auto_rand_ctrl
from ..utils import (
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr,
)
from .common import (
DistributedOperatorImplContainer,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
)
from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0
class DistributedDropout(DistributedOperatorImplContainer):
def __init__(self, op_type):
super().__init__(op_type)
register_distributed_operator_impl_container(DistributedDropout("dropout"))
# Dist Dropout with Random Control
# Dropout reuses the compatibility and cost functions of elementwise
class DistributedDropoutImpl0(DistributedElementwiseImpl0):
def __init__(self, name):
super().__init__(name)
self._forward_implemented = True
self._backward_implemented = True
@staticmethod
def forward(ctx, *args, **kwargs):
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.work_block
startup_block = dist_op_context.startup_block
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
if is_enable_auto_rand_ctrl() and not op_dist_attr.is_recompute:
assert (
op_dist_attr is not None
), f"forward op [{str(src_op)}] don't have dist attribute !"
# check validation of inputs / outputs
assert 'X' in kwargs, "input [{}] is not given".format('X')
assert (
len(kwargs['X']) == 1
), "input X should be only one tensor but got {}".format(
kwargs['X']
)
assert 'Seed' in kwargs, "input [{}] is not given".format('Seed')
if (
src_op.has_attr("fix_seed")
and src_op.attr("fix_seed")
and src_op.has_attr("seed")
and src_op.attr("seed")
):
_logger.info(
"Auto Parallel Random Control Skiped Since manul seed is set by user: {}".format(
src_op
)
)
elif rank_id not in op_dist_attr.process_mesh.process_ids:
pass
# NOTE: adapted for recompute
# If the user already set a seed, we should not modify it. But if the seed was added by the recompute pass, it should be under control.
# TODO: in the future the recompute pass should happen after parallel partition; remove this at that time.
elif len(kwargs['Seed']) > 0 or len(src_op.input("Seed")) > 0:
seed_var_name = kwargs['Seed'][0]
if seed_var_name.startswith('rc_seed'):
pre_op = main_block.ops[-1]
assert (
pre_op.type == "seed"
and len(pre_op.attr("rng_name")) == 0
), f"found exception op {str(pre_op)}"
# determinate rng
X_var = main_block._var_recursive(kwargs['X'][0])
X_dims_mapping = op_dist_attr.get_input_dims_mapping(
X_var.name
)
process_mesh = op_dist_attr.process_mesh
rng_name = determinate_rng(
rank_id, X_dims_mapping, process_mesh
)
# make recompute seed under control
pre_op._set_attr("rng_name", rng_name)
pre_op._set_attr("deterministic", True)
pre_op._set_attr("force_cpu", True)
else:
_logger.info(
"Auto Parallel Random Control Skiped Since manul seed is set by user: {}".format(
src_op
)
)
else:
# determinate rng
X_var = main_block._var_recursive(kwargs['X'][0])
X_dims_mapping = op_dist_attr.get_input_dims_mapping(X_var.name)
process_mesh = op_dist_attr.process_mesh
rng_name = determinate_rng(
rank_id, X_dims_mapping, process_mesh
)
assert rng_name is not None and rng_name != ""
# insert seed op
seed_var = main_block.create_var(
name=paddle.fluid.unique_name.generate_with_ignorable_key(
".".join(["tensor_parallel_seed", 'tmp'])
),
dtype=paddle.int32,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False,
)
# set new seed_var's dist_attr
seed_var_dims_mapping = [-1]
seed_var_dist_attr = set_var_dist_attr(
ctx, seed_var, seed_var_dims_mapping, process_mesh
)
# adapted for recompute
# force_cpu to reduce sync copy from CPU->GPU->CPU, and reduce pipeline hang
seed_op = main_block.append_op(
type='seed',
outputs={'Out': seed_var},
attrs={
'deterministic': True,
'rng_name': rng_name,
'force_cpu': True,
},
)
seed_op._set_attr('op_namescope', 'auto_tensor_parallel_seed')
# set new seed op's dist_attr
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
seed_op, process_mesh, seed_var_dims_mapping, ctx
)
# modify dropout op
src_op.desc.set_input("Seed", [seed_var.name])
src_op._remove_attr("fix_seed")
src_op._remove_attr("seed")
op_dist_attr.set_input_dist_attr(
seed_var.name, seed_var_dist_attr
)
kwargs['Seed'] = [seed_var.name]
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
# dropout backward is deterministic by mask, and not need for random state control
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl(
"dropout", DistributedDropoutImpl0("random_control")
)
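The net effect of this implementation: when auto parallel random control is enabled and no user-fixed seeds are present, every dropout in the transformed main program reads its Seed from a deterministic seed op carrying an rng_name derived from the mesh and dims_mapping. A condensed, hedged check in the spirit of the unit test further below (main_program is a placeholder for the compiled program):

def check_dropout_seeds(main_program):
    # rough invariant after this pass, assuming random control is on and no manual seeds are set
    for op in main_program.global_block().ops:
        if op.type == "seed":
            assert op.attr("deterministic") and op.attr("rng_name") != ""
        if op.type == "dropout":
            assert len(op.input("Seed")) == 1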
...@@ -25,6 +25,7 @@ from .reshard import Resharder
from .partitioner import Partitioner
from .utils import set_grad_var_shape
from .process_group import get_world_process_group
from .random import init_auto_parallel_rng
from ..utils.log_utils import get_logger
...@@ -84,6 +85,9 @@ class Parallelizer:
) = partitioner.partition(
serial_main_program, serial_startup_program, params_grads
)
init_auto_parallel_rng()
self._logger.debug(
"within parallel partitioner time: {}, mode {}".format(
time.time() - time0, self._mode
......
...@@ -19,6 +19,8 @@ import paddle
# Use to store the previous and current process mesh
_g_previous_process_mesh = None
_g_current_process_mesh = None
# {shape_process_ids : unique_id}
_g_unique_process_mesh_map = {}
def get_current_process_mesh():
...@@ -39,6 +41,30 @@ def reset_current_process_mesh():
_g_current_process_mesh = _g_previous_process_mesh
def get_unique_id_for_process_mesh(shape, process_ids):
key = f"shape {shape}, process_ids {process_ids}"
global _g_unique_process_mesh_map
if key in _g_unique_process_mesh_map:
unique_id = _g_unique_process_mesh_map[key]
else:
unique_id = len(_g_unique_process_mesh_map) + 1
_g_unique_process_mesh_map[key] = unique_id
return unique_id
def retrive_unique_id_for_process_mesh(shape, process_ids):
key = f"shape {shape}, process_ids {process_ids}"
global _g_unique_process_mesh_map
assert key in _g_unique_process_mesh_map
return _g_unique_process_mesh_map[key]
def get_unique_process_mesh_map():
global _g_unique_process_mesh_map
return _g_unique_process_mesh_map
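A quick sketch of how these helpers behave, assuming the module-level map starts empty:

# meshes with identical shape and process_ids share one unique id
id_a = get_unique_id_for_process_mesh([2, 4], [0, 1, 2, 3, 4, 5, 6, 7])   # -> 1
id_b = get_unique_id_for_process_mesh([2, 4], [0, 1, 2, 3, 4, 5, 6, 7])   # -> 1 (cached)
id_c = get_unique_id_for_process_mesh([8], [0, 1, 2, 3, 4, 5, 6, 7])      # -> 2
assert retrive_unique_id_for_process_mesh([2, 4], [0, 1, 2, 3, 4, 5, 6, 7]) == id_a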
class ProcessMesh(object):
"""
The `ProcessMesh` object describes the topology of the used processes.
...@@ -113,6 +139,11 @@ class ProcessMesh(object):
pg0 = get_process_group(0)
pg0.add_ranks(self.processes)
# Unique mesh id
self._unique_id = get_unique_id_for_process_mesh(
self._shape, self._process_ids
)
@property
def shape(self):
"""
...@@ -148,6 +179,16 @@ class ProcessMesh(object):
"""
return self._mesh
@property
def unique_id(self):
"""
Get the unique id of ProcessMesh.
NOTE
The unique id only takes process_ids and shape into account.
Different ProcessMesh objects with the same process_ids and shape share the same unique id.
"""
return self._unique_id
@property
def topology(self):
return self._shape
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import paddle
from ..utils.log_utils import get_logger
from .process_mesh import retrive_unique_id_for_process_mesh
from .utils import _get_idx_in_axis
_logger = get_logger(logging.INFO)
_rng_name_to_seed = {}
_inited_rng_name_to_seed = {}
_enable_random_control = False
_basic_seed = 42
# use prime numbers as offsets to avoid conflicts
_mesh_offset = 173
_dim_offsets = [11, 23, 37, 73]
def is_enable_auto_rand_ctrl():
global _enable_random_control
return _enable_random_control
def enable_auto_rand_ctrl():
global _enable_random_control
_enable_random_control = True
def parallel_manual_seed(seed):
"""Enable auto parallel random control.
Random control maintain the randomness when tensor is distributed across devices on a Mesh(any order).
* Independency: If tensor is **Sharded** on a Mesh dimension, Devices along that Mesh dimension should have Different randomness.
* Consistency: Meanwhile if the tensor is **Replicated** on another Mesh dimension, randomness of Devices along that Mesh dimension should be Consistent.
For instance: rank0 ~ rank7 consist a Mesh of shape of [2, 4]; A 2D tensor is distributed in that Mesh using dims_mapping [-1, 1].
Randomness for rank0-rank1-rank2-rank3 (rank4-rank5-rank6-rank7) should be Independent;
Randomness for rank0 and rank4 (rank1 and rank5, ...) should be Consistent.
This function should be called only once before auto parallel compiles the computation graph (e.g. auto_parallel.engine.prepare() or fit()).
This seed only affects how randomness-relative **operators** (dropout, fuse op with dropout inside, etc) are execute amonge mesh, and would NOT affect other processe like Parameter initialization.
Examples:
# seed relative to training step
auto_parallel_random_seed((step + 13) * 257)
...
engine.prepare()
"""
enable_auto_rand_ctrl()
global _basic_seed
_basic_seed = seed
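A hedged usage sketch; only the seeding call is part of this commit (the engine setup is a placeholder), and it must run before the program is compiled:

import paddle

# seed relative to the training step, as described in the docstring above
paddle.distributed.auto_parallel.parallel_manual_seed(100)
# engine = auto.Engine(model, loss, opt, strategy=strategy)  # placeholder setup
# engine.prepare(mode="train")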
def determinate_rng(rank, dims_mapping, process_mesh):
# TODO(JZ-LIANG) Support Mesh of any higher rank:
# use a string-to-unique-integer hashing algorithm for seed computation
# instead of using offsets to coordinate seeds across devices.
if len(process_mesh.shape) > 4:
raise NotImplementedError(
"Auto Parallel Random Control for Mesh's rank > 4 is NOT supported! Got {}".format(
str(process_mesh)
)
)
global _basic_seed
seed_ = _basic_seed
# FIXME
# unique_id = process_mesh.unique_id
unique_id = retrive_unique_id_for_process_mesh(
process_mesh.shape, process_mesh.process_ids
)
sharding_expr = f'mesh:{unique_id}'
seed_ += _mesh_offset * (unique_id + 1)
for i in range(len(process_mesh.shape)):
if i not in dims_mapping:
relative_idx = -1
else:
relative_idx = _get_idx_in_axis(
process_mesh.process_ids,
process_mesh.shape,
i,
rank,
)
sharding_expr += f"_dim{i}:{relative_idx}"
seed_ += _dim_offsets[i] * (relative_idx + 1)
global _rng_name_to_seed
if sharding_expr in _rng_name_to_seed:
assert _rng_name_to_seed[sharding_expr] == seed_
else:
assert (
seed_ not in _rng_name_to_seed.values()
), "Seed Confilt! current seed: {}, current sharding expr: {}, generated seed: {}".format(
seed_, sharding_expr, _rng_name_to_seed
)
_rng_name_to_seed[sharding_expr] = seed_
return sharding_expr
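A standalone sketch of the naming scheme produced above (not the library API; pure Python, no distributed setup needed). For the [2, 4] mesh from the docstring with dims_mapping [-1, 1]:

def sketch_rng_name(unique_id, mesh_shape, dims_mapping, mesh_coord):
    # mesh_coord is this rank's coordinate in the mesh, e.g. (0, 1) for rank 1 in a [2, 4] mesh
    name = f"mesh:{unique_id}"
    for i in range(len(mesh_shape)):
        relative_idx = mesh_coord[i] if i in dims_mapping else -1
        name += f"_dim{i}:{relative_idx}"
    return name

assert sketch_rng_name(1, [2, 4], [-1, 1], (0, 0)) == "mesh:1_dim0:-1_dim1:0"  # rank 0
assert sketch_rng_name(1, [2, 4], [-1, 1], (1, 0)) == "mesh:1_dim0:-1_dim1:0"  # rank 4: consistent
assert sketch_rng_name(1, [2, 4], [-1, 1], (0, 1)) == "mesh:1_dim0:-1_dim1:1"  # rank 1: independent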
def init_auto_parallel_rng():
if not is_enable_auto_rand_ctrl():
return
global _rng_name_to_seed
# NOTE: init rng may be called multiple times; avoid initializing the same rng twice
global _inited_rng_name_to_seed
for rng_name, seed in _rng_name_to_seed.items():
if rng_name in _inited_rng_name_to_seed:
assert _inited_rng_name_to_seed[rng_name] == seed
else:
_logger.info(
f"Init Auto Parallel RNG: {rng_name}, with seed {seed}"
)
paddle.framework.random.set_random_seed_generator(rng_name, seed)
_inited_rng_name_to_seed[rng_name] = seed
...@@ -22,13 +22,20 @@ from paddle.fluid.framework import Variable, Operator
from paddle.fluid.backward import _append_grad_suffix_, _get_no_grad_set_name
from paddle.fluid.backward import ProgramStats, _rename_arg_, _find_op_path_
from paddle.distributed.auto_parallel.process_mesh import ProcessMesh
from paddle.distributed.auto_parallel.dist_attribute import (
OperatorDistributedAttribute,
)
from paddle.distributed.auto_parallel.utils import (
get_loss_op,
set_var_dist_attr,
set_dist_op_desc_original_id,
)
from paddle.distributed.auto_parallel.utils import (
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
)
class RecomputeState(ProgramStats):
def __init__(self, block, ops):
super(RecomputeState, self).__init__(block=block, ops=ops)
self._block = block
...@@ -54,7 +61,7 @@ class RecomputeState(ProgramStats):
self.var_op_deps[name]["var_as_output_ops"] = [i]
def get_recompute_segments(self, checkpoints):
"""get recompute segments from checkpoints"""
segments = []
start_idx = -1
pre_segment_end_idx = -1
...@@ -69,33 +76,43 @@ class RecomputeState(ProgramStats):
segments.append([0, max(op_idx_list) + 1])
else:
flag, min_idx, max_idx = self.is_subgraph(
[checkpoints[start_idx]], [checkpoints[start_idx + 1]]
)
if flag:
min_idx = self._update_segment_start(
min_idx, pre_segment_end_idx
)
segments.append([min_idx, max_idx + 1])
else:
logging.info(
"Could not recompute op range [{}] - [{}] ".format(
min_idx, max_idx + 1
)
)
start_idx += 1
for i, (idx1, idx2) in enumerate(segments):
logging.info("recompute segment[{}]".format(i))
logging.info(
"segment start op: [{}]: [{}] [{}]".format(
self._ops[idx1].desc.type(),
self._ops[idx1].desc.input_arg_names(),
self._ops[idx1].desc.output_arg_names(),
)
)
logging.info(
"segment end op: [{}]: [{}] [{}]".format(
self._ops[idx2 - 1].desc.type(),
self._ops[idx2 - 1].desc.input_arg_names(),
self._ops[idx2 - 1].desc.output_arg_names(),
)
)
return segments
def modify_forward_desc_for_recompute(self, dist_context):
"""
If the program's forward part has a 'dropout' op, this function will insert
a seed op before it to guarantee that the two dropout ops have the same outputs.
"""
op_types = [op.desc.type() for op in self._ops]
...@@ -116,45 +133,52 @@ class RecomputeState(ProgramStats):
cur_op_dist_attr = dist_context.get_op_dist_attr_for_program(cur_op)
# insert seed op to guarantee that the two dropout ops have the same outputs
# NOTE Hack to adapt recompute for random control; for more info see dist_dropout.py
# A new seed added by recompute should have a prefix to distinguish it from seeds added by the user or other modules.
op_unique_name = unique_name.generate("rc_seed")
var_unique_name = unique_name.generate_with_ignorable_key(
".".join([op_unique_name, 'tmp'])
)
seed_var = self._block.create_var(
name=var_unique_name,
dtype='int32',
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False,
)
# set new seed_var's dist_attr
ref_dims_mapping = [-1]
ref_process_mesh = cur_op_dist_attr.process_mesh
seed_var_dist_attr = set_var_dist_attr(
dist_context, seed_var, ref_dims_mapping, ref_process_mesh
)
seed = (
0
if cur_op.attr("fix_seed") is False
else int(cur_op.attr("seed"))
)
seed_op = self._block._insert_op_without_sync(
index=cur_op.idx,
type="seed",
inputs={},
outputs={"Out": seed_var},
attrs={"seed": seed, "force_cpu": True},
)
# set new seed op's dist_attr
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
seed_op, ref_process_mesh, ref_dims_mapping, dist_context
)
# modify dropout op's desc
self._ops.insert(op_idx, seed_op)
cur_op.desc.set_input("Seed", [var_unique_name])
cur_op._remove_attr("fix_seed")
cur_op._remove_attr("seed")
cur_op_dist_attr.set_input_dist_attr(
seed_var.name, seed_var_dist_attr
)
op_idx += 2
self._block._sync_with_cpp()
...@@ -168,7 +192,7 @@ def _find_op_index(block, cur_op):
def _get_stop_gradients(program, no_grad_set):
"""get no grad var"""
if no_grad_set is None:
no_grad_set = set()
else:
...@@ -185,8 +209,9 @@ def _get_stop_gradients(program, no_grad_set):
return no_grad_set_name
def _add_needed_descs_to_block(
descs, block, main_block, in_memory_vars, dist_context
):
"""
Get the recomputed ops which will insert the backward part
"""
...@@ -217,7 +242,6 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars,
@register_pass("auto_parallel_recompute")
class RecomputePass(PassBase):
def __init__(self):
super(RecomputePass, self).__init__()
self.set_attr("checkpoints", None)
...@@ -261,12 +285,15 @@ class RecomputePass(PassBase):
vars_should_be_hold = []
for segment in segments:
vars_should_be_hold.extend(
rc_state.get_out_of_subgraph_vars(segment[0], segment[1])
)
cross_vars = set(vars_should_be_hold) - set(checkpoints)
logging.info(
"found [{}] vars which cross recompute segment: [{}],"
"better checkpoints might be set to reduce those vars".format(
len(cross_vars), cross_vars
)
)
vars_should_be_hold.extend(rc_state.get_reserved_vars())
vars_should_be_hold.extend(rc_state.get_input_nodes())
vars_should_be_hold = list(set(vars_should_be_hold))
...@@ -277,14 +304,15 @@ class RecomputePass(PassBase):
ckpt_ops_dict = {}
buffer_block = main_block.program._create_block()
for i, segment in enumerate(segments[::-1]):
fwd_ops = op_path[segment[0] : segment[1]]
var_suffix = ".subprog_%d" % i
for op in fwd_ops:
input_and_output_names = []
input_and_output_names.extend(op.desc.input_arg_names())
input_and_output_names.extend(op.desc.output_arg_names())
cur_op_dist_attr = (
self._dist_context.get_op_dist_attr_for_program(op)
)
assert cur_op_dist_attr is not None
for name in input_and_output_names:
if main_block.var(name).persistable or name in checkpoints:
...@@ -294,11 +322,13 @@ class RecomputePass(PassBase):
if name not in var_name_dict:
ref_process_mesh = cur_op_dist_attr.process_mesh
if name in op.desc.input_arg_names():
ref_dims_mapping = (
cur_op_dist_attr.get_input_dims_mapping(name)
)
else:
ref_dims_mapping = (
cur_op_dist_attr.get_output_dims_mapping(name)
)
# record recomputed var's old_name and new_name (old_name.subprog_XXX)
# create new var with new name
var_name_dict[name] = name + var_suffix
...@@ -309,15 +339,23 @@ class RecomputePass(PassBase):
dtype=ref_var.dtype,
type=ref_var.type,
persistable=ref_var.persistable,
stop_gradient=ref_var.stop_gradient,
)
# set new recomputed var's dist attr
set_var_dist_attr(
self._dist_context,
rc_var,
ref_dims_mapping,
ref_process_mesh,
)
# get recomputed segment's descs
segment_descs = _add_needed_descs_to_block(
fwd_ops,
buffer_block,
main_block,
vars_in_memory,
self._dist_context,
)
# rename recomputed ops' input and output var name
for key in var_name_dict:
_rename_arg_(segment_descs, key, var_name_dict[key])
...@@ -345,7 +383,10 @@ class RecomputePass(PassBase):
# rename grad op's var_name which is not in 'vars_in_memory'
for key in var_name_dict:
if (
key
not in grad_op.input_arg_names + grad_op.output_arg_names
):
continue
self.reset_op_dist_attr(grad_op, var_name_dict)
_rename_arg_([grad_op.desc], key, var_name_dict[key])
...@@ -360,17 +401,20 @@ class RecomputePass(PassBase):
idx -= 1
segment_descs = ckpt_ops_dict[fwd_op_id][1]
for _, op_desc in reversed(list(enumerate(segment_descs))):
rc_op = main_block._insert_op_without_sync(
idx, type='nop'
)
rc_desc = rc_op.desc
rc_desc.copy_from(op_desc)
rc_desc.set_original_id(rc_desc.id())
# set recomputed ops' dist attr
fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program_with_id(
op_desc.original_id()
)
assert fwd_op_dist_attr is not None
self.set_op_dist_attr(
rc_op, fwd_op_dist_attr, var_name_dict
)
ckpt_ops_dict[fwd_op_id][0] = False
...@@ -382,13 +426,15 @@ class RecomputePass(PassBase):
for input in op.desc.input_arg_names():
if input in var_name_dict.keys():
in_dist_attr = op_dist_attr.get_input_dist_attr(input)
op_dist_attr.set_input_dist_attr(
var_name_dict[input], in_dist_attr
)
for output in op.desc.output_arg_names():
if output in var_name_dict.keys():
out_dist_attr = op_dist_attr.get_output_dist_attr(output)
op_dist_attr.set_output_dist_attr(
var_name_dict[output], out_dist_attr
)
def set_op_dist_attr(self, op, old_dist_attr, var_name_dict):
new_dist_attr = OperatorDistributedAttribute()
...@@ -399,16 +445,18 @@ class RecomputePass(PassBase):
for input in old_dist_attr.inputs_dist_attrs.keys():
if input in var_name_dict.keys():
in_dist_attr = old_dist_attr.inputs_dist_attrs[input]
new_dist_attr.set_input_dist_attr(
var_name_dict[input], in_dist_attr
)
else:
in_dist_attr = old_dist_attr.inputs_dist_attrs[input]
new_dist_attr.set_input_dist_attr(input, in_dist_attr)
for output in old_dist_attr.outputs_dist_attrs.keys():
if output in var_name_dict.keys():
out_dist_attr = old_dist_attr.outputs_dist_attrs[output]
new_dist_attr.set_output_dist_attr(
var_name_dict[output], out_dist_attr
)
else:
out_dist_attr = old_dist_attr.outputs_dist_attrs[output]
new_dist_attr.set_output_dist_attr(output, out_dist_attr)
......
...@@ -37,6 +37,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
${dist_ENVS})
set_tests_properties(test_high_order_grad
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_random_ctrl MODULES test_random_ctrl ENVS ${dist_ENVS})
set_tests_properties(test_random_ctrl PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
py_test_modules(test_iterable_dataset MODULES test_iterable_dataset ENVS
${dist_ENVS})
set_tests_properties(test_iterable_dataset
......
...@@ -21,14 +21,17 @@ from paddle.distributed.fleet import auto
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import (
GPTModel,
GPTForPretraining,
GPTPretrainingCriterion,
)
sequence_len = 512
vocab_size = 1000
class FakeDataset(paddle.io.Dataset):
def __init__(self, num_samples):
self.num_samples = num_samples
self.sequence_len = sequence_len
...@@ -40,8 +43,11 @@ class FakeDataset(paddle.io.Dataset):
random.seed(2021)
tokens = np.random.randint(self.vocab_size, size=self.sequence_len)
position_ids = np.arange(self.sequence_len)
attention_mask = (
np.tril(np.ones(self.sequence_len))
.reshape((1, self.sequence_len, self.sequence_len))
.astype(np.float32)
)
labels = np.random.randint(self.vocab_size, size=self.sequence_len)
loss_mask = np.ones(self.sequence_len).astype(np.float32)
return tokens, position_ids, attention_mask, labels, loss_mask
...@@ -51,30 +57,32 @@ class FakeDataset(paddle.io.Dataset):
def create_data_holder(batch_size):
tokens = paddle.static.InputSpec(
name="tokens", shape=[batch_size, sequence_len], dtype='int64'
)
position_ids = paddle.static.InputSpec(
name="position_ids", shape=[batch_size, sequence_len], dtype='int64'
)
attention_mask = paddle.static.InputSpec(
name="attention_mask",
shape=[batch_size, 1, sequence_len, sequence_len],
dtype='float32',
)
labels = paddle.static.InputSpec(
name="labels", shape=[batch_size, sequence_len], dtype='int64'
)
loss_mask = paddle.static.InputSpec(
name="loss_mask", shape=[batch_size, sequence_len], dtype='float32'
)
return [tokens, position_ids, attention_mask], [labels, loss_mask]
def generate_model(strategy, dropout_prob=0.0):
modeling.init_global()
ranks = list(range(paddle.distributed.get_world_size()))
modeling._global_process_mesh = auto.ProcessMesh(
mesh=ranks, dim_names=["x"]
)
if strategy == "serial":
modeling._global_parallel_strategy = "serial"
elif strategy == "mp":
...@@ -84,24 +92,25 @@ def generate_model(strategy):
else:
raise ValueError("Only support serial, mp2 and dp2.")
gpt = GPTModel(
vocab_size=1000,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=8,
intermediate_size=256,
hidden_act="gelu",
hidden_dropout_prob=dropout_prob,
attention_probs_dropout_prob=dropout_prob,
max_position_embeddings=1024,
type_vocab_size=1,
initializer_range=0.02,
pad_token_id=0,
eos_token_id=7,
bos_token_id=0,
eol_token_id=3,
)
model = GPTForPretraining(
gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
)
criterion = GPTPretrainingCriterion()
return model, criterion
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
import numpy as np
from get_gpt_model import FakeDataset, generate_model
import paddle
paddle.enable_static()
from paddle import _legacy_C_ops
from paddle.distributed.fleet import auto
def dy_broadcast_helper(tensor):
_legacy_C_ops.c_broadcast(
tensor, tensor, 'root', 1, 'use_calc_stream', True, 'ring_id', 1000
)
_legacy_C_ops.c_sync_calc_stream(tensor, tensor)
def apply_pass(use_recompute=False, no_recompute_segments=[]):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_recompute:
recompute = strategy.recompute
recompute.enable = True
recompute.no_recompute_segments = no_recompute_segments
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestRandomControl(unittest.TestCase):
def setUp(self):
self.rtol = 1e-6
self.atol = 1e-8
self.batch_size = 1
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
paddle.distributed.auto_parallel.parallel_manual_seed(100)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(paddle.distributed.ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_recompute=False, no_recompute_segments=[]):
reset_prog()
strategy = apply_pass(use_recompute, no_recompute_segments)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("mp", dropout_prob=0.1)
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def compare_mask_between_ranks(
self, rank, mask_np_list, comapre_idx, equal
):
for np_mask in [mask_np_list[i] for i in comapre_idx]:
mask_tensor_local = paddle.to_tensor(np_mask.astype("float32"))
if rank == 0:
mask_tensor_remote = paddle.ones_like(mask_tensor_local)
dy_broadcast_helper(mask_tensor_remote)
if equal:
assert np.array_equal(
mask_tensor_remote.numpy(), mask_tensor_local.numpy()
)
else:
assert not np.array_equal(
mask_tensor_remote.numpy(),
mask_tensor_local.numpy(),
)
else:
dy_broadcast_helper(mask_tensor_local)
def test_random_ctrl_vanilla(self):
# mp2 training without recompute
rc_engine = self.get_engine(False)
train_dataloader = rc_engine.dataloader(
self.dataset,
batch_size=self.batch_size,
mode="train",
sample_split=3,
)
rc_engine.prepare(mode="train")
mask_name_list = [f'dropout_{i}.tmp_1' for i in range(7)]
mask_var_list = [
rc_engine.main_program.global_block().var(varname)
for varname in mask_name_list
]
for data in train_dataloader:
outs = rc_engine.run(data, fetch_list=mask_var_list, mode="train")
mask_np_list = [outs['fetches'][varname] for varname in mask_name_list]
paddle.disable_static()
rank = paddle.distributed.get_rank()
# check global mask is consistent across ranks
global_index = [0, 2, 3, 5, 6]
self.compare_mask_between_ranks(
rank, mask_np_list, global_index, equal=True
)
local_index = [1, 4]
# check local mask differs across ranks
self.compare_mask_between_ranks(
rank, mask_np_list, local_index, equal=False
)
paddle.enable_static()
# check program
ops = rc_engine.main_program.global_block().ops
rng_names = []
seed_var_names = []
for op in ops:
if op.type == "seed":
rng_names.append(op.attr('rng_name'))
if op.type == "dropout":
seed_var_names.append(op.input("Seed")[0])
rank = paddle.distributed.get_rank()
self.assertEqual(
rng_names,
[
'mesh:1_dim0:-1',
f'mesh:1_dim0:{rank}',
'mesh:1_dim0:-1',
'mesh:1_dim0:-1',
f'mesh:1_dim0:{rank}',
'mesh:1_dim0:-1',
'mesh:1_dim0:-1',
],
)
self.assertEqual(
seed_var_names,
[
'tensor_parallel_seed.tmp_0',
'tensor_parallel_seed.tmp_1',
'tensor_parallel_seed.tmp_2',
'tensor_parallel_seed.tmp_3',
'tensor_parallel_seed.tmp_4',
'tensor_parallel_seed.tmp_5',
'tensor_parallel_seed.tmp_6',
],
)
def test_random_ctrl_with_recompute(self):
# mp2 recompute training
rc_engine = self.get_engine(True)
train_dataloader = rc_engine.dataloader(
self.dataset,
batch_size=self.batch_size,
mode="train",
sample_split=3,
)
rc_engine.prepare(mode="train")
mask_name_list = [f'dropout_{i}.tmp_1' for i in range(7)]
recompute_mask_name_list = [
'dropout_0.tmp_1.subprog_1',
'dropout_1.tmp_1.subprog_1',
'dropout_2.tmp_1.subprog_1',
'dropout_3.tmp_1.subprog_1',
'dropout_4.tmp_1.subprog_0',
'dropout_5.tmp_1.subprog_0',
'dropout_6.tmp_1.subprog_0',
]
mask_var_list = [
rc_engine.main_program.global_block().var(varname)
for varname in mask_name_list + recompute_mask_name_list
]
for data in train_dataloader:
outs = rc_engine.run(data, fetch_list=mask_var_list, mode="train")
mask_np_list = [
outs['fetches'][varname]
for varname in mask_name_list + recompute_mask_name_list
]
# check the recompute mask is the same within the local device
for i in range(7):
mask_fw = mask_np_list[i].astype("float32")
mask_rc = mask_np_list[i + 7].astype("float32")
assert np.array_equal(
mask_fw,
mask_rc,
)
paddle.disable_static()
# check global mask is consistent across ranks
rank = paddle.distributed.get_rank()
global_index = [0, 2, 3, 5, 6]
self.compare_mask_between_ranks(
rank, mask_np_list, global_index, equal=True
)
local_index = [1, 4]
# check local mask differs across ranks
self.compare_mask_between_ranks(
rank, mask_np_list, local_index, equal=False
)
paddle.enable_static()
# check program
rank = paddle.distributed.get_rank()
ops = rc_engine.main_program.global_block().ops
rng_names = []
seed_var_names = []
for op in ops:
if op.type == "seed":
rng_names.append(op.attr('rng_name'))
if op.type == "dropout":
seed_var_names.append(op.input("Seed")[0])
self.assertEqual(
rng_names,
[
'mesh:1_dim0:-1',
f'mesh:1_dim0:{rank}',
'mesh:1_dim0:-1',
'mesh:1_dim0:-1',
f'mesh:1_dim0:{rank}',
'mesh:1_dim0:-1',
'mesh:1_dim0:-1',
],
)
self.assertEqual(
seed_var_names,
[
'rc_seed_0.tmp_0',
'rc_seed_1.tmp_0',
'rc_seed_2.tmp_0',
'rc_seed_3.tmp_0',
'rc_seed_4.tmp_0',
'rc_seed_5.tmp_0',
'rc_seed_6.tmp_0',
'rc_seed_4.tmp_0',
'rc_seed_5.tmp_0',
'rc_seed_6.tmp_0',
'rc_seed_0.tmp_0',
'rc_seed_1.tmp_0',
'rc_seed_2.tmp_0',
'rc_seed_3.tmp_0',
],
)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
import tempfile
import unittest
class TestRandomCtrlPass(unittest.TestCase):
def test_mp2_with_recompute(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "random_control_unittest.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = (
[sys.executable, "-u"]
+ coverage_args
+ [
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"--log_dir",
tmp_dir.name,
launch_model_path,
]
)
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()