Unverified commit 03afb41c authored by JZ-LIANG, committed by GitHub

[Auto Parallel] Randomness Control for Distributed Training (#52554)

* unique id for mesh

* rng ctrl

* support dropout

* register op

* adopt for recompute

* update unittest

* support pp
Parent commit: 349a059d
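For orientation, below is a minimal usage sketch of the API this PR adds. It is a sketch, not code from the PR: the toy MLPWithDropout model, the layer sizes, and the seed value are illustrative placeholders; only paddle.distributed.auto_parallel.parallel_manual_seed and the existing auto.Strategy / auto.Engine entry points come from the code below, and the script would still need to be started with paddle.distributed.launch on multiple devices, as the unit test in this PR does.

import paddle
from paddle.distributed.fleet import auto

# Enable auto parallel randomness control once, before the engine compiles the graph.
paddle.distributed.auto_parallel.parallel_manual_seed(2023)

class MLPWithDropout(paddle.nn.Layer):
    # Toy network (placeholder); any model containing dropout is handled the same way.
    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(64, 64)
        self.dropout = paddle.nn.Dropout(p=0.1)

    def forward(self, x):
        return self.dropout(self.linear(x))

strategy = auto.Strategy()
strategy.auto_mode = "semi"
opt = paddle.optimizer.AdamW(learning_rate=1e-4)
engine = auto.Engine(MLPWithDropout(), paddle.nn.MSELoss(), opt, strategy=strategy)
# engine.prepare() / engine.fit() then partitions the graph; at that point each rank's
# dropout seed is derived from its mesh and dims_mapping, so sharded dropouts differ
# across ranks while replicated dropouts stay in sync.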
@@ -19,5 +19,6 @@ from .interface import shard_tensor
from .interface import shard_op
from .interface import recompute
from .interface import fetch
from .random import parallel_manual_seed
__all__ = []
@@ -45,7 +45,7 @@ from .dist_loader import (
from .dist_op import DistributedOperator
from .dist_saver import DistributedSaver
from .helper import ProgramHelper
from .interface import CollectionNames, fetch, get_collection
from .parallelizer_v2 import Parallelizer
from .planner_v2 import Planner
from .process_group import get_all_process_groups, new_process_group
@@ -410,6 +410,8 @@ class Engine:
), "user_fetches must be a list, but receive {}".format(
type(user_fetches).__name__
)
else:
user_fetches = []
fetch_names = []
fetch_indices = []
@@ -434,10 +436,13 @@ class Engine:
_process_fetch_group("metrics_" + str(i), var_list)
if mode == "predict":
_process_fetch_group("outputs", self._fetch_vars[mode]["outputs"])
for usr_fetch in user_fetches:
var_name = _to_name_str(usr_fetch)
fetch(var_name)
user_fetches_collection = [
item[1] for item in get_collection(CollectionNames.FETCHES)
]
var_list = user_fetches_collection or []
_process_fetch_group("fetches", var_list)
return fetch_names, fetch_indices
...
@@ -36,3 +36,4 @@ from . import dist_reduce_sum_p
from . import dist_shape
from . import dist_assign
from . import dist_scale
from . import dist_dropout
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
import logging
import paddle
from paddle.framework import core
from paddle.utils import unique_name
from ...utils.log_utils import get_logger
_logger = get_logger(logging.INFO)
from ..random import determinate_rng, is_enable_auto_rand_ctrl
from ..utils import (
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr,
)
from .common import (
DistributedOperatorImplContainer,
register_distributed_operator_impl,
register_distributed_operator_impl_container,
)
from .dist_eltwise import DistributedDefaultImpl0, DistributedElementwiseImpl0
class DistributedDropout(DistributedOperatorImplContainer):
def __init__(self, op_type):
super().__init__(op_type)
register_distributed_operator_impl_container(DistributedDropout("dropout"))
# Dist Dropout with Random Control
# Dropout reuses the compatibility and cost functions of elementwise
class DistributedDropoutImpl0(DistributedElementwiseImpl0):
def __init__(self, name):
super().__init__(name)
self._forward_implemented = True
self._backward_implemented = True
@staticmethod
def forward(ctx, *args, **kwargs):
dist_op_context = ctx.dist_op_context
main_block = dist_op_context.work_block
startup_block = dist_op_context.startup_block
src_op = dist_op_context.cur_src_op
rank_id = dist_op_context.rank_id
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
if is_enable_auto_rand_ctrl() and not op_dist_attr.is_recompute:
assert (
op_dist_attr is not None
), f"forward op [{str(src_op)}] don't have dist attribute !"
# check validation of inputs / outputs
assert 'X' in kwargs, "input [{}] is not given".format('X')
assert (
len(kwargs['X']) == 1
), "input X should be only one tensor but got {}".format(
kwargs['X']
)
assert 'Seed' in kwargs, "input [{}] is not given".format('Seed')
if (
src_op.has_attr("fix_seed")
and src_op.attr("fix_seed")
and src_op.has_attr("seed")
and src_op.attr("seed")
):
_logger.info(
"Auto Parallel Random Control Skiped Since manul seed is set by user: {}".format(
src_op
)
)
elif rank_id not in op_dist_attr.process_mesh.process_ids:
pass
# NOTE Adapted for recompute.
# If the user has already set the seed, we should not modify it. But if the seed was added by the recompute pass, it should be under control.
# TODO: in the future the recompute pass should happen after the parallel partition; remove this at that time.
elif len(kwargs['Seed']) > 0 or len(src_op.input("Seed")) > 0:
seed_var_name = kwargs['Seed'][0]
if seed_var_name.startswith('rc_seed'):
pre_op = main_block.ops[-1]
assert (
pre_op.type == "seed"
and len(pre_op.attr("rng_name")) == 0
), f"found exception op {str(pre_op)}"
# determinate rng
X_var = main_block._var_recursive(kwargs['X'][0])
X_dims_mapping = op_dist_attr.get_input_dims_mapping(
X_var.name
)
process_mesh = op_dist_attr.process_mesh
rng_name = determinate_rng(
rank_id, X_dims_mapping, process_mesh
)
# make recompute seed under control
pre_op._set_attr("rng_name", rng_name)
pre_op._set_attr("deterministic", True)
pre_op._set_attr("force_cpu", True)
else:
_logger.info(
"Auto Parallel Random Control Skiped Since manul seed is set by user: {}".format(
src_op
)
)
else:
# determinate rng
X_var = main_block._var_recursive(kwargs['X'][0])
X_dims_mapping = op_dist_attr.get_input_dims_mapping(X_var.name)
process_mesh = op_dist_attr.process_mesh
rng_name = determinate_rng(
rank_id, X_dims_mapping, process_mesh
)
assert rng_name is not None and rng_name != ""
# insert seed op
seed_var = main_block.create_var(
name=unique_name.generate_with_ignorable_key(
".".join(["tensor_parallel_seed", 'tmp'])
),
dtype=paddle.int32,
type=core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
stop_gradient=False,
)
# set new seed_var's dist_attr
seed_var_dims_mapping = [-1]
seed_var_dist_attr = set_var_dist_attr(
ctx, seed_var, seed_var_dims_mapping, process_mesh
)
# adapted for recompute
# force_cpu to reduce sync copies from CPU->GPU->CPU and to reduce pipeline hangs
seed_op = main_block.append_op(
type='seed',
outputs={'Out': seed_var},
attrs={
'deterministic': True,
'rng_name': rng_name,
'force_cpu': True,
},
)
seed_op._set_attr('op_namescope', 'auto_tensor_parallel_seed')
# set new seed op's dist_attr
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
seed_op, process_mesh, seed_var_dims_mapping, ctx
)
# modify dropout op
src_op.desc.set_input("Seed", [seed_var.name])
src_op._remove_attr("fix_seed")
src_op._remove_attr("seed")
op_dist_attr.set_input_dist_attr(
seed_var.name, seed_var_dist_attr
)
kwargs['Seed'] = [seed_var.name]
DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
@staticmethod
def backward(ctx, *args, **kwargs):
# dropout backward is made deterministic by the mask and does not need random state control
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
register_distributed_operator_impl(
"dropout", DistributedDropoutImpl0("random_control")
)
@@ -362,7 +362,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
assert (
op_dist_attr is not None
), f"forward op [{str(src_op)}] don't have dist attribute !"
# check validation of inputs / outputs
assert 'Ids' in kwargs, "input [{}] is not given".format('Ids')
...
@@ -23,6 +23,7 @@ from paddle.utils import unique_name
from ..utils.log_utils import get_logger
from .partitioner import Partitioner
from .process_group import get_world_process_group
from .random import init_auto_parallel_rng
from .reshard import Resharder
from .utils import set_grad_var_shape
@@ -83,6 +84,9 @@ class Parallelizer:
) = partitioner.partition(
serial_main_program, serial_startup_program, params_grads
)
init_auto_parallel_rng()
self._logger.debug(
"within parallel partitioner time: {}, mode {}".format(
time.time() - time0, self._mode
...
@@ -22,6 +22,8 @@ from paddle.framework import core
# Use to store the previous and current process mesh
_g_previous_process_mesh = None
_g_current_process_mesh = None
# {shape_process_ids : unique_id}
_g_unique_process_mesh_map = {}
def get_current_process_mesh():
@@ -42,6 +44,30 @@ def reset_current_process_mesh():
_g_current_process_mesh = _g_previous_process_mesh
def get_unique_id_for_process_mesh(shape, process_ids):
key = f"shape {shape}, process_ids {process_ids}"
global _g_unique_process_mesh_map
if key in _g_unique_process_mesh_map:
unique_id = _g_unique_process_mesh_map[key]
else:
unique_id = len(_g_unique_process_mesh_map) + 1
_g_unique_process_mesh_map[key] = unique_id
return unique_id
def retrive_unique_id_for_process_mesh(shape, process_ids):
key = f"shape {shape}, process_ids {process_ids}"
global _g_unique_process_mesh_map
assert key in _g_unique_process_mesh_map
return _g_unique_process_mesh_map[key]
def get_unique_process_mesh_map():
global _g_unique_process_mesh_map
return _g_unique_process_mesh_map
class ProcessMesh(core.ProcessMesh):
"""
The `ProcessMesh` object describes the Cartesian topology of the used processes.
@@ -124,6 +150,11 @@ class ProcessMesh(core.ProcessMesh):
pg0 = get_process_group(0)
pg0.add_ranks(self.process_ids)
# Unique mesh id
self._unique_id = get_unique_id_for_process_mesh(
self._shape, self._process_ids
)
@property
def mesh(self):
"""
@@ -131,6 +162,16 @@ class ProcessMesh(core.ProcessMesh):
"""
return self._mesh
@property
def unique_id(self):
"""
Get the unique id of ProcessMesh.
NOTE: the unique id only takes process_ids and shape into account.
Different ProcessMesh objects with the same process_ids and shape share the same unique id.
"""
return self._unique_id
def __getitem__(self, index):
if isinstance(index, tuple):
new_dim_names = []
...
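A quick illustration of the unique_id semantics documented above. This is a sketch under the assumption that ProcessMesh objects can be constructed eagerly outside a launched job (as the unit test in this PR does via auto.ProcessMesh); the dim names are arbitrary.

from paddle.distributed.fleet import auto

mesh_a = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])
mesh_b = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])
mesh_c = auto.ProcessMesh([[0, 2], [1, 3]], dim_names=["dp", "mp"])

# Same shape and process_ids -> same entry in _g_unique_process_mesh_map -> same id.
assert mesh_a.unique_id == mesh_b.unique_id
# Same shape but a different process_ids order -> a new key -> a different id.
assert mesh_a.unique_id != mesh_c.unique_id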
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import paddle
from ..utils.log_utils import get_logger
from .process_mesh import retrive_unique_id_for_process_mesh
from .utils import _get_idx_in_axis
_logger = get_logger(logging.INFO)
_rng_name_to_seed = {}
_inited_rng_name_to_seed = {}
_enable_random_control = False
_basic_seed = 42
# use prime numbers as offsets to avoid conflicts
_mesh_offset = 173
_dim_offsets = [11, 23, 37, 73]
def is_enable_auto_rand_ctrl():
global _enable_random_control
return _enable_random_control
def enable_auto_rand_ctrl():
global _enable_random_control
_enable_random_control = True
def parallel_manual_seed(seed):
"""Enable auto parallel random control.
Random control maintains the randomness when a tensor is distributed across devices on a Mesh (in any order).
* Independence: if a tensor is **Sharded** along a Mesh dimension, devices along that Mesh dimension should have different randomness.
* Consistency: meanwhile, if the tensor is **Replicated** along another Mesh dimension, the randomness of devices along that Mesh dimension should be consistent.
For instance: rank0 ~ rank7 form a Mesh of shape [2, 4]; a 2D tensor is distributed on that Mesh with dims_mapping [-1, 1].
The randomness of rank0-rank1-rank2-rank3 (and of rank4-rank5-rank6-rank7) should be independent;
the randomness of rank0 and rank4 (rank1 and rank5, ...) should be consistent.
This function should be called only once, before auto parallel compiles the computation graph (e.g. auto_parallel.engine.prepare() or fit()).
This seed only affects how randomness-related **operators** (dropout, fused ops with dropout inside, etc.) are executed among the mesh, and does NOT affect other processes such as parameter initialization.
Examples:
# seed relative to the training step
parallel_manual_seed((step + 13) * 257)
...
engine.prepare()
"""
enable_auto_rand_ctrl()
global _basic_seed
_basic_seed = seed
def determinate_rng(rank, dims_mapping, process_mesh):
# TODO(JZ-LIANG) Support Mesh of arbitrarily high rank:
# use a string-to-unique-integer hashing algorithm for seed computation
# instead of using offsets to coordinate seeds across devices.
if len(process_mesh.shape) > 4:
raise NotImplementedError(
"Auto Parallel Random Control for Mesh's rank > 4 is NOT supported! Got {}".format(
str(process_mesh)
)
)
global _basic_seed
seed_ = _basic_seed
# FIXME
# unique_id = process_mesh.unique_id
unique_id = retrive_unique_id_for_process_mesh(
process_mesh.shape, process_mesh.process_ids
)
sharding_expr = f'mesh:{unique_id}'
seed_ += _mesh_offset * (unique_id + 1)
for i in range(len(process_mesh.shape)):
if i not in dims_mapping:
relative_idx = -1
else:
relative_idx = _get_idx_in_axis(
process_mesh.process_ids,
process_mesh.shape,
i,
rank,
)
sharding_expr += f"_dim{i}:{relative_idx}"
seed_ += _dim_offsets[i] * (relative_idx + 1)
global _rng_name_to_seed
if sharding_expr in _rng_name_to_seed:
assert _rng_name_to_seed[sharding_expr] == seed_
else:
assert (
seed_ not in _rng_name_to_seed.values()
), "Seed Confilt! current seed: {}, current sharding expr: {}, generated seed: {}".format(
seed_, sharding_expr, _rng_name_to_seed
)
_rng_name_to_seed[sharding_expr] = seed_
return sharding_expr
def init_auto_parallel_rng():
if not is_enable_auto_rand_ctrl():
return
global _rng_name_to_seed
# NOTE init rng may be called multiple times; avoid initializing the same rng twice
global _inited_rng_name_to_seed
for rng_name, seed in _rng_name_to_seed.items():
if rng_name in _inited_rng_name_to_seed:
assert _inited_rng_name_to_seed[rng_name] == seed
else:
_logger.info(
f"Init Auto Parallel RNG: {rng_name}, with seed {seed}"
)
paddle.framework.random.set_random_seed_generator(rng_name, seed)
_inited_rng_name_to_seed[rng_name] = seed
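To make the rng_name/seed derivation in determinate_rng above concrete, here is a small self-contained sketch for the mp2 case exercised by the unit tests (a flat mesh of shape [2] over ranks [0, 1] with unique_id 1). It re-implements the arithmetic with the constants defined above; _get_idx_in_axis is replaced by a simplified lookup that only works for a 1-D mesh, so this is an illustration, not the library routine.

_basic_seed = 42
_mesh_offset = 173
_dim_offsets = [11, 23, 37, 73]

def sketch_determinate_rng(rank, dims_mapping, mesh_shape, process_ids, unique_id):
    # seed starts from the user seed plus a per-mesh offset
    seed = _basic_seed + _mesh_offset * (unique_id + 1)
    rng_name = f"mesh:{unique_id}"
    for i in range(len(mesh_shape)):
        if i not in dims_mapping:
            relative_idx = -1  # tensor replicated along mesh dim i -> shared randomness
        else:
            relative_idx = process_ids.index(rank)  # stands in for _get_idx_in_axis on a 1-D mesh
        rng_name += f"_dim{i}:{relative_idx}"
        seed += _dim_offsets[i] * (relative_idx + 1)
    return rng_name, seed

# Replicated tensor: both ranks get the same rng name and seed.
print(sketch_determinate_rng(0, [-1, -1], [2], [0, 1], 1))  # ('mesh:1_dim0:-1', 388)
print(sketch_determinate_rng(1, [-1, -1], [2], [0, 1], 1))  # ('mesh:1_dim0:-1', 388)
# Tensor sharded along mesh dim 0: rank-dependent names/seeds, matching the
# 'mesh:1_dim0:{rank}' strings checked in the test_random_ctrl unit test below.
print(sketch_determinate_rng(0, [0, -1], [2], [0, 1], 1))   # ('mesh:1_dim0:0', 399)
print(sketch_determinate_rng(1, [0, -1], [2], [0, 1], 1))   # ('mesh:1_dim0:1', 410)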
@@ -136,7 +136,9 @@ class RecomputeState(ProgramStats):
cur_op_dist_attr = dist_context.get_op_dist_attr_for_program(cur_op)
# insert seed op to guarantee that two dropout op have the same outputs
# NOTE Hack to adapt recompute for random control; for more info see dist_dropout.py
# A new seed added by recompute should have a prefix to distinguish it from seeds added by the user or other modules.
op_unique_name = unique_name.generate("rc_seed")
var_unique_name = unique_name.generate_with_ignorable_key(
".".join([op_unique_name, 'tmp'])
)
...
@@ -37,6 +37,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_pass_grad_clip MODULES test_pass_grad_clip)
set_tests_properties(test_pass_grad_clip
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_random_ctrl MODULES test_random_ctrl)
set_tests_properties(test_random_ctrl PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
py_test_modules(test_pass_gradient_merge MODULES test_pass_gradient_merge)
set_tests_properties(test_pass_gradient_merge
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
...
@@ -75,7 +75,7 @@ def create_data_holder(batch_size, vocab_size=1000, sequence_len=512):
return [tokens, position_ids, attention_mask], [labels, loss_mask]
def generate_model(strategy, dropout_prob=0.0):
modeling.init_global()
ranks = list(range(paddle.distributed.get_world_size()))
modeling._global_process_mesh = auto.ProcessMesh(
@@ -97,8 +97,8 @@ def generate_model(strategy):
num_attention_heads=8,
intermediate_size=256,
hidden_act="gelu",
hidden_dropout_prob=dropout_prob,
attention_probs_dropout_prob=dropout_prob,
max_position_embeddings=1024,
type_vocab_size=1,
initializer_range=0.02,
...
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
import numpy as np
from get_gpt_model import FakeDataset, generate_model
import paddle
paddle.enable_static()
from paddle import _legacy_C_ops
from paddle.distributed.fleet import auto
def dy_broadcast_helper(tensor):
_legacy_C_ops.c_broadcast(
tensor, tensor, 'root', 1, 'use_calc_stream', True, 'ring_id', 0
)
_legacy_C_ops.c_sync_calc_stream(tensor, tensor)
def apply_pass(use_recompute=False, no_recompute_segments=[]):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_recompute:
recompute = strategy.recompute
recompute.enable = True
recompute.no_recompute_segments = no_recompute_segments
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestRandomControl(unittest.TestCase):
def setUp(self):
self.rtol = 1e-6
self.atol = 1e-8
self.batch_size = 1
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
paddle.distributed.auto_parallel.parallel_manual_seed(100)
def init(self, engine):
paddle.seed(2022)
np.random.seed(2022)
random.seed(2022)
place = paddle.fluid.CUDAPlace(paddle.distributed.ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_recompute=False, no_recompute_segments=[]):
reset_prog()
strategy = apply_pass(use_recompute, no_recompute_segments)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("mp", dropout_prob=0.1)
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def compare_mask_between_ranks(
self, rank, mask_np_list, compare_idx, equal
):
for np_mask in [mask_np_list[i] for i in compare_idx]:
mask_tensor_local = paddle.to_tensor(np_mask.astype("float32"))
if rank == 0:
mask_tensor_remote = paddle.ones_like(mask_tensor_local)
dy_broadcast_helper(mask_tensor_remote)
if equal:
assert np.array_equal(
mask_tensor_remote.numpy(), mask_tensor_local.numpy()
)
else:
assert not np.array_equal(
mask_tensor_remote.numpy(),
mask_tensor_local.numpy(),
)
else:
dy_broadcast_helper(mask_tensor_local)
def test_random_ctrl_vanilla(self):
# mp2 training without recompute
rc_engine = self.get_engine(False)
train_dataloader = rc_engine.dataloader(
self.dataset,
batch_size=self.batch_size,
mode="train",
sample_split=3,
)
rc_engine.prepare(mode="train")
mask_name_list = [f'dropout_{i}.tmp_1' for i in range(7)]
mask_var_list = [
rc_engine.main_program.global_block().var(varname)
for varname in mask_name_list
]
for data in train_dataloader:
outs = rc_engine.run(data, fetch_list=mask_var_list, mode="train")
mask_np_list = [outs['fetches'][varname] for varname in mask_name_list]
paddle.disable_static()
rank = paddle.distributed.get_rank()
# check global masks are consistent across ranks
global_index = [0, 2, 3, 5, 6]
self.compare_mask_between_ranks(
rank, mask_np_list, global_index, equal=True
)
local_index = [1, 4]
# check local masks differ across ranks
self.compare_mask_between_ranks(
rank, mask_np_list, local_index, equal=False
)
paddle.enable_static()
# check program
ops = rc_engine.main_program.global_block().ops
rng_names = []
seed_var_names = []
for op in ops:
if op.type == "seed":
rng_names.append(op.attr('rng_name'))
if op.type == "dropout":
seed_var_names.append(op.input("Seed")[0])
rank = paddle.distributed.get_rank()
self.assertEqual(
rng_names,
[
'mesh:1_dim0:-1',
f'mesh:1_dim0:{rank}',
'mesh:1_dim0:-1',
'mesh:1_dim0:-1',
f'mesh:1_dim0:{rank}',
'mesh:1_dim0:-1',
'mesh:1_dim0:-1',
],
)
self.assertEqual(
seed_var_names,
[
'tensor_parallel_seed.tmp_0',
'tensor_parallel_seed.tmp_1',
'tensor_parallel_seed.tmp_2',
'tensor_parallel_seed.tmp_3',
'tensor_parallel_seed.tmp_4',
'tensor_parallel_seed.tmp_5',
'tensor_parallel_seed.tmp_6',
],
)
def test_random_ctrl_with_recompute(self):
# mp2 recompute training
rc_engine = self.get_engine(True)
train_dataloader = rc_engine.dataloader(
self.dataset,
batch_size=self.batch_size,
mode="train",
sample_split=3,
)
rc_engine.prepare(mode="train")
mask_name_list = [f'dropout_{i}.tmp_1' for i in range(7)]
recompute_mask_name_list = [
'dropout_0.tmp_1.subprog_1',
'dropout_1.tmp_1.subprog_1',
'dropout_2.tmp_1.subprog_1',
'dropout_3.tmp_1.subprog_1',
'dropout_4.tmp_1.subprog_0',
'dropout_5.tmp_1.subprog_0',
'dropout_6.tmp_1.subprog_0',
]
mask_var_list = [
rc_engine.main_program.global_block().var(varname)
for varname in mask_name_list + recompute_mask_name_list
]
for data in train_dataloader:
outs = rc_engine.run(data, fetch_list=mask_var_list, mode="train")
mask_np_list = [
outs['fetches'][varname]
for varname in mask_name_list + recompute_mask_name_list
]
# check that recompute masks are the same within the local device
for i in range(7):
mask_fw = mask_np_list[i].astype("float32")
mask_rc = mask_np_list[i + 7].astype("float32")
assert np.array_equal(
mask_fw,
mask_rc,
)
paddle.disable_static()
# check global masks are consistent across ranks
rank = paddle.distributed.get_rank()
global_index = [0, 2, 3, 5, 6]
self.compare_mask_between_ranks(
rank, mask_np_list, global_index, equal=True
)
local_index = [1, 4]
# check local masks differ across ranks
self.compare_mask_between_ranks(
rank, mask_np_list, local_index, equal=False
)
paddle.enable_static()
# check program
rank = paddle.distributed.get_rank()
ops = rc_engine.main_program.global_block().ops
rng_names = []
seed_var_names = []
for op in ops:
if op.type == "seed":
rng_names.append(op.attr('rng_name'))
if op.type == "dropout":
seed_var_names.append(op.input("Seed")[0])
self.assertEqual(
rng_names,
[
'mesh:1_dim0:-1',
f'mesh:1_dim0:{rank}',
'mesh:1_dim0:-1',
'mesh:1_dim0:-1',
f'mesh:1_dim0:{rank}',
'mesh:1_dim0:-1',
'mesh:1_dim0:-1',
],
)
self.assertEqual(
seed_var_names,
[
'rc_seed_0.tmp_0',
'rc_seed_1.tmp_0',
'rc_seed_2.tmp_0',
'rc_seed_3.tmp_0',
'rc_seed_4.tmp_0',
'rc_seed_5.tmp_0',
'rc_seed_6.tmp_0',
'rc_seed_4.tmp_0',
'rc_seed_5.tmp_0',
'rc_seed_6.tmp_0',
'rc_seed_0.tmp_0',
'rc_seed_1.tmp_0',
'rc_seed_2.tmp_0',
'rc_seed_3.tmp_0',
],
)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
import tempfile
import unittest
class TestRandomCtrlPass(unittest.TestCase):
def test_mp2_with_recompute(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "random_control_unittest.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = (
[sys.executable, "-u"]
+ coverage_args
+ [
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"--log_dir",
tmp_dir.name,
launch_model_path,
]
)
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()