Unverified commit 3649099f, authored by zhaoyingli, committed by GitHub

[AutoParallel] add collate_fn for dist_loader (#45053)

* add collate_fn

* fix number of inputs
Parent: 8788513b
@@ -1300,6 +1300,10 @@ class Completer:
     def complete_update_annotation(self, serial_main_program):
         """Complete the annotation of vars and ops in the update phase for parallel program."""
+        # Copy the dist tensors and dist ops annotated by users from the default context
+        # global mesh
+        from paddle.distributed.auto_parallel.process_group import get_world_process_group
+        world_ranks = get_world_process_group().ranks
         # Notice: serial_main_program is actually a dist_main_program of current rank,
         # and must be passed into this function.
@@ -1371,7 +1375,7 @@ class Completer:
                 if not learning_rate_completed:
                     learning_rate_completed = True
                     var_dist_attr = TensorDistributedAttribute()
-                    var_dist_attr.process_mesh = ref_process_mesh
+                    var_dist_attr.process_mesh = world_ranks
                     var_dist_attr.dims_mapping = [-1]
                     self._dist_context.set_tensor_dist_attr_for_program(
                         learning_var, var_dist_attr)
......
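The learning-rate variable is read by optimizer ops on every rank, so the completer now annotates it with the global world mesh instead of one op's reference mesh. A minimal sketch of that annotation, assuming the auto-parallel APIs named in the diff (the import path for `TensorDistributedAttribute` is assumed, and `dist_context`/`learning_var` are supplied by the caller):

```python
from paddle.distributed.auto_parallel.process_group import get_world_process_group
# Assumed import path for the attribute class used in the hunk above.
from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute

def annotate_lr_on_world_mesh(dist_context, learning_var):
    # Replicate the LR tensor across the whole world mesh, e.g. [0, 1] on 2 GPUs.
    world_ranks = get_world_process_group().ranks
    attr = TensorDistributedAttribute()
    attr.process_mesh = world_ranks
    attr.dims_mapping = [-1]  # -1 means "not sharded" along the only axis
    dist_context.set_tensor_dist_attr_for_program(learning_var, attr)
```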
@@ -17,7 +17,8 @@ import numpy as np
 import paddle
 from .utils import to_list
 from paddle.fluid.layers.utils import flatten
-from paddle.io import DataLoader, DistributedBatchSampler
+from paddle.io import DataLoader, BatchSampler, IterableDataset
+from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn

 class DistributedDataLoader(metaclass=abc.ABCMeta):
@@ -29,14 +30,32 @@ class DistributedDataLoader(metaclass=abc.ABCMeta):
                  data_parallel_world_size=None,
                  data_parallel_rank=None,
                  drop_last=False):
+        if isinstance(dataset, IterableDataset):
+            raise TypeError("IterableDataset is not supported.")
+        else:
+            self.dataset_kind = _DatasetKind.MAP
         self.dataset = dataset
-        self.batch_size = batch_size
         self.epochs = epochs
-        self.data_parallel_world_size = data_parallel_world_size
-        self.data_parallel_rank = data_parallel_rank
         self.drop_lost = drop_last
-        if data_parallel_world_size is not None and batch_size is not None:
-            assert batch_size % data_parallel_world_size == 0
+        if batch_size is None:
+            self.batch_size = None
+            self.batch_sampler = None
+        else:
+            if data_parallel_world_size is not None:
+                assert batch_size % data_parallel_world_size == 0, \
+                    "'batch_size' must be divisible by data parallel size"
+            self.batch_size = batch_size
+            self.batch_sampler = BatchSampler(dataset,
+                                              batch_size=batch_size,
+                                              shuffle=False,
+                                              drop_last=drop_last)
+        self.auto_collate_batch = self.batch_sampler is not None
+        self.sampler_iter = iter(self.index_sampler)
+        self.dp_world_size = 1 if data_parallel_world_size is None else data_parallel_world_size
+        self.dp_rank = 0 if data_parallel_rank is None else data_parallel_rank

     @abc.abstractmethod
     def __iter__(self):
@@ -46,6 +65,16 @@ class DistributedDataLoader(metaclass=abc.ABCMeta):
     def __next__(self):
         raise NotImplementedError

+    @property
+    def index_sampler(self):
+        if self.auto_collate_batch:
+            return self.batch_sampler
+        else:
+            if self.dataset_kind == _DatasetKind.MAP:
+                return list(range(len(self.dataset)))
+            else:
+                raise TypeError("Only map-style datasets are supported.")

 class NonIterableGeneratorLoader(DistributedDataLoader):
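When `batch_size` is given, `index_sampler` is a `paddle.io.BatchSampler` that yields lists of sample indices; otherwise it degenerates to a flat index list over the map-style dataset. A small demo of the sampler side (the `RangeDataset` below is a stand-in for illustration, not part of the patch):

```python
import numpy as np
from paddle.io import Dataset, BatchSampler

class RangeDataset(Dataset):
    """Hypothetical 10-sample map-style dataset, used only for this demo."""
    def __getitem__(self, idx):
        return np.array([idx], dtype='float32')
    def __len__(self):
        return 10

sampler = BatchSampler(RangeDataset(), batch_size=4, shuffle=False, drop_last=False)
print(list(sampler))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```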
@@ -56,21 +85,29 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
                  batch_size=1,
                  epochs=1,
                  steps_per_epoch=None,
+                 collate_fn=None,
                  data_parallel_world_size=None,
                  data_parallel_rank=None,
                  drop_last=False):
         self.feed_list = feed_list
         self.places = places
         self.steps_per_epoch = steps_per_epoch
-        self.dp_world_size = 1 if data_parallel_world_size is None else data_parallel_world_size
-        self.dp_rank = 0 if data_parallel_rank is None else data_parallel_rank
         super(NonIterableGeneratorLoader,
               self).__init__(dataset, batch_size, epochs,
                              data_parallel_world_size, data_parallel_rank,
                              drop_last)
-        self._inner_dataloader = self._create_inner_dataloader()
+
+        if self.auto_collate_batch:
+            self.collate_fn = collate_fn or default_collate_fn
+        else:
+            self.collate_fn = collate_fn or default_convert_fn
+        self.dataset_fetcher = _DatasetKind.create_fetcher(
+            self.dataset_kind, self.dataset, self.auto_collate_batch,
+            self.collate_fn, self.drop_lost)
+
         self._steps = self._infer_steps()
+        self._inner_dataloader = self._create_inner_dataloader()

     def __iter__(self):
         self._cur_step = 0
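With auto batching on, the fetcher merges samples through `default_collate_fn`; without it, each sample only passes through `default_convert_fn`. A user-supplied `collate_fn` overrides either. A hedged sketch of a custom collate function (the padding scheme is made up for illustration):

```python
import numpy as np

def pad_and_stack_collate(samples):
    """Hypothetical collate_fn: right-pads variable-length 1-D samples
    to a common length, then stacks them into one batch array."""
    max_len = max(s.shape[0] for s in samples)
    padded = [np.pad(s, (0, max_len - s.shape[0])) for s in samples]
    return np.stack(padded)  # shape: [batch_size, max_len]

batch = pad_and_stack_collate(
    [np.array([1., 2.]), np.array([3.]), np.array([4., 5., 6.])])
print(batch.shape)  # (3, 3)
```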
@@ -101,31 +138,25 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
     def _create_inner_dataloader(self):

         def sample_data_generator():
-            batch_data = None
-            for step, data in enumerate(self.dataset):
-                data = flatten(data)
-                if batch_data is None:
-                    batch_data = [[] for i in range(len(data))]
-                for idx in range(len(data)):
-                    batch_data[idx].append(data[idx])
-                if (step + 1) % self.batch_size == 0:
-                    partial_data = []
-                    for d in batch_data:
-                        array = np.array(d)
-                        partial_data.append(
-                            np.split(array, self.dp_world_size)[self.dp_rank])
-                    yield partial_data[:len(self.feed_list)]
-                    batch_data = None
+            for indices in self.sampler_iter:
+                assert len(indices) % self.dp_world_size == 0, \
+                    "Please set batch_size to be divisible by data parallel size"
+                n = len(indices) // self.dp_world_size
+                cur_indices = [
+                    indices[i:i + n] for i in range(0, len(indices), n)
+                ]
+                batch = self.dataset_fetcher.fetch(cur_indices[self.dp_rank])
+                yield batch[:len(self.feed_list)]

         def batch_data_generator():
-            for data in self.dataset:
-                data = flatten(data)
+            for indices in self.sampler_iter:
                 partial_data = []
-                for d in data:
-                    assert d.shape[0] % self.dp_world_size == 0, \
-                        "Please padding dataset with data parallel size"
+                batch = self.dataset_fetcher.fetch(indices)
+                for data in batch:
+                    assert data.shape[0] % self.dp_world_size == 0, \
+                        "Please pad the dataset's batch_size to be divisible by the data parallel size"
                     partial_data.append(
-                        np.split(d, self.dp_world_size)[self.dp_rank])
+                        np.split(data, self.dp_world_size)[self.dp_rank])
                 yield partial_data[:len(self.feed_list)]

         dataloader = paddle.fluid.io.DataLoader.from_generator(
......
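The new `sample_data_generator` splits each global batch of indices evenly across data-parallel ranks, so every rank fetches only its own shard instead of materializing the whole batch and slicing arrays with `np.split`. The slicing logic in isolation (plain Python, same arithmetic as above; world size and rank are assumed values):

```python
indices = [0, 1, 2, 3, 4, 5, 6, 7]   # one global batch from the sampler
dp_world_size, dp_rank = 2, 1        # assumed 2-way data parallelism, rank 1

assert len(indices) % dp_world_size == 0
n = len(indices) // dp_world_size
cur_indices = [indices[i:i + n] for i in range(0, len(indices), n)]
print(cur_indices[dp_rank])  # [4, 5, 6, 7] -> the shard this rank fetches
```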
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import time
 import copy
 import logging
 from collections import defaultdict
@@ -306,6 +307,7 @@ class Engine:
                 mode].dist_startup_programs
             self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars
             self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars
+            self._optimizer = self._dist_contexts[mode].serial_optimizer

         if self._nranks > 1:
             # Traverse different rank programs and traverse each op of them,
@@ -403,7 +405,8 @@ class Engine:
             epochs=1,
             fetches=None,
             steps_per_epoch=None,
-            use_program_cache=False,
+            collate_fn=None,
+            use_cache=False,
             return_numpy=True):
         # TODO: callbacks
         # TODO: evaluate after training
@@ -417,18 +420,24 @@ class Engine:
         assert self.mode in self._dist_main_progs, \
             "train model is not ready, please call `engine.prepare()` first."
         train_dataloader = self._create_dataloader(train_data, batch_size,
-                                                   epochs, steps_per_epoch)
+                                                   epochs, steps_per_epoch,
+                                                   collate_fn)

         usr_fetch = self._validate_fetches(fetches)
         fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
         fetch_list, fetch_map = self._fetch_map(fetch_loss, usr_fetch)
+        lr_scheduler = self.get_lr_scheduler(self.main_program)

         for epoch in range(epochs):
             train_logs = {"epoch": epoch}
             for step, _ in enumerate(train_dataloader):
                 outs = self._executor.run(self.main_program,
                                           fetch_list=fetch_list,
-                                          use_program_cache=use_program_cache,
+                                          use_program_cache=use_cache,
                                           return_numpy=return_numpy)
+                if lr_scheduler is not None:
+                    lr_scheduler.step()
+                    train_logs["lr"] = self._optimizer.get_lr()
                 train_logs["step"] = step
                 # inner fetches
                 if fetch_loss:
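`fit`, `evaluate`, and `predict` now accept a `collate_fn` that is threaded through `_create_dataloader`, and `use_program_cache` is exposed as `use_cache`. A hedged usage sketch (not standalone: `engine` is a prepared auto-parallel Engine and `train_dataset` a map-style dataset, as in the updated tests below; the collate function is hypothetical):

```python
import numpy as np

def my_collate(samples):
    # Hypothetical collate_fn: merge a list of (x, label) samples field-wise.
    return [np.stack(fields) for fields in zip(*samples)]

engine.fit(train_dataset,
           batch_size=64,
           epochs=2,
           collate_fn=my_collate,  # new in this patch
           use_cache=True)         # renamed from use_program_cache
```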
@@ -444,7 +453,8 @@ class Engine:
                  eval_data,
                  batch_size=1,
                  fetches=None,
-                 use_program_cache=False,
+                 collate_fn=None,
+                 use_cache=False,
                  return_numpy=True):
         self.mode = 'eval'
         if not self._mode_init_states[self.mode]:
@@ -452,7 +462,9 @@ class Engine:
         assert self.mode in self._dist_main_progs, \
             "eval model is not ready, please call `engine.prepare()` first."
-        eval_dataloader = self._create_dataloader(eval_data, batch_size)
+        eval_dataloader = self._create_dataloader(eval_data,
+                                                  batch_size,
+                                                  collate_fn=collate_fn)

         usr_fetch = self._validate_fetches(fetches)
         fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
@@ -464,7 +476,7 @@ class Engine:
             eval_logs = {"step": step}
             outs = self._executor.run(self.main_program,
                                       fetch_list=fetch_list,
-                                      use_program_cache=use_program_cache,
+                                      use_program_cache=use_cache,
                                       return_numpy=return_numpy)
             # inner fetches
             if fetch_loss:
@@ -489,7 +501,8 @@ class Engine:
                 test_data,
                 batch_size=1,
                 fetches=None,
-                use_program_cache=False,
+                collate_fn=None,
+                use_cache=False,
                 return_numpy=True):
         self.mode = 'predict'
         if not self._mode_init_states[self.mode]:
@@ -497,7 +510,9 @@ class Engine:
         assert self.mode in self._dist_main_progs, \
             "predict model is not ready, please call `engine.prepare()` first."
-        test_dataloader = self._create_dataloader(test_data, batch_size)
+        test_dataloader = self._create_dataloader(test_data,
+                                                  batch_size,
+                                                  collate_fn=collate_fn)

         usr_fetch = self._validate_fetches(fetches)
         fetch_outputs = self._validate_fetches(self.fetch_vars["outputs"])
@@ -508,7 +523,7 @@ class Engine:
             predict_logs = {"step": step}
             outs = self._executor.run(self.main_program,
                                       fetch_list=fetch_list,
-                                      use_program_cache=use_program_cache,
+                                      use_program_cache=use_cache,
                                       return_numpy=return_numpy)
             outputs.append(outs[:len(fetch_outputs)])
             for i, out in enumerate(outs):
@@ -521,7 +536,8 @@ class Engine:
                            dataset,
                            batch_size,
                            epochs=1,
-                           steps_per_epoch=None):
+                           steps_per_epoch=None,
+                           collate_fn=None):
         dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
         dist_startup_prog = self._dist_startup_progs[self.mode][self._cur_rank]
         dist_context = self._dist_contexts[self.mode]
@@ -554,6 +570,7 @@ class Engine:
                 batch_size,
                 epochs,
                 steps_per_epoch,
+                collate_fn,
                 data_parallel_world_size=self._input_split_size,
                 data_parallel_rank=self._input_split_rank)
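`_create_dataloader` forwards `collate_fn` positionally to `NonIterableGeneratorLoader`, which resolves it against the defaults. The resolution rule in isolation (a sketch; the two defaults here are stand-ins for `default_collate_fn`/`default_convert_fn` from `paddle.fluid.dataloader`):

```python
def resolve_collate_fn(collate_fn, auto_collate_batch):
    """Mirror of the loader's choice: batch-merge when a BatchSampler is in
    play, plain per-sample conversion otherwise; a user fn always wins."""
    default_collate_fn = lambda batch: batch    # stand-in
    default_convert_fn = lambda sample: sample  # stand-in
    if auto_collate_batch:
        return collate_fn or default_collate_fn
    return collate_fn or default_convert_fn

print(resolve_collate_fn(None, True))  # -> the batch-merging default
```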
@@ -645,12 +662,11 @@ class Engine:
         config = self.strategy.recompute_configs

         # extract ckpts by specific model
-        self.model
         if isinstance(self.model, paddle.nn.Layer):
             if hasattr(
-                    self.model, "model"
-            ) and self.model.model.__class__.__name__ == 'GPTForPretraining':
-                exact_ckpts = self.model.model.gpt.checkpoints
+                    self.model, "gpt"
+            ) and self.model.__class__.__name__ == 'GPTForPretraining':
+                exact_ckpts = self.model.gpt.checkpoints
             else:
                 exact_ckpts = config["checkpoints"]
@@ -659,7 +675,7 @@ class Engine:
             config["checkpoints"] = exact_ckpts[:]
             self.strategy.recompute_configs = config
             logs = {
-                'Model Class': self.model.model.__class__.__name__,
+                'Model Class': self.model.__class__.__name__,
                 'Applied Recompute ckpts': exact_ckpts
             }
             self._logger.info(logs)
@@ -699,6 +715,15 @@ class Engine:
             self._saver.load(path, dist_main_prog, dist_context, strict,
                              load_optimizer)

+    @staticmethod
+    def get_lr_scheduler(program):
+        lr_sheduler = None
+        if hasattr(program, 'lr_sheduler'):
+            from paddle.optimizer.lr import LRScheduler
+            lr_sheduler = program.lr_sheduler
+            assert isinstance(lr_sheduler, LRScheduler), "must be LRScheduler"
+        return lr_sheduler
+
     @property
     def mode(self):
         return self._mode
......
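`get_lr_scheduler` reads the scheduler that the static-graph optimizer attaches to the `Program` (under the historically misspelled attribute `lr_sheduler`, kept as-is here since it is the real attribute name), and `fit` now steps it once per batch. The scheduler API itself is easy to check in isolation:

```python
import paddle

# CosineAnnealingDecay is the scheduler adopted by the updated test below.
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=1e-5, T_max=10)
for _ in range(3):
    scheduler.step()           # in fit(), called once after each executor.run()
    print(scheduler.get_lr())  # decays along a cosine curve
```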
@@ -149,6 +149,7 @@ class Parallelizer:
             paddle.enable_static()
         else:
             optimizer = copy.deepcopy(optimizer)
+            self._dist_context._serial_optimizer = optimizer
         with program_guard(main_program, startup_program):
             optimizer_ops = optimizer.apply_gradients(params_grads)
         self._completer.complete_update_annotation(main_program)
......
@@ -363,11 +363,15 @@ class OptimizationTuner:
         profile_args = " ".join([
             "--rank",
-            str(self.rank), "--device_id",
-            str(self.device_id), "--ctx_filename", ctx_path,
+            str(self.rank),
+            "--device_id",
+            str(self.device_id),
+            "--ctx_filename",
+            ctx_path,
             "--profile_start_step",
-            str(self._config.profile_start_step), "--profile_end_step",
-            str(self._config.profile_end_step)
+            str(self._config.profile_start_step),
+            "--profile_end_step",
+            str(self._config.profile_end_step),
         ])
         cmd_args = "-m paddle.distributed.auto_parallel.tuner.profiler" + " " + profile_args
         cmd = [sys.executable, "-u"] + coverage_args + shlex.split(cmd_args)
......
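This hunk only reflows the flag list to one item per line; the resulting command is unchanged. The assembly can be verified standalone (the flag values below are stand-ins for `self.rank`, `self.device_id`, etc.):

```python
import shlex
import sys

profile_args = " ".join([
    "--rank", "0",
    "--device_id", "0",
    "--ctx_filename", "/tmp/ctx.pkl",
    "--profile_start_step", "10",
    "--profile_end_step", "30",
])
cmd_args = "-m paddle.distributed.auto_parallel.tuner.profiler" + " " + profile_args
cmd = [sys.executable, "-u"] + shlex.split(cmd_args)
print(cmd)  # [python, '-u', '-m', 'paddle...tuner.profiler', '--rank', '0', ...]
```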
@@ -31,6 +31,8 @@ from paddle.static import InputSpec
 from paddle.distributed import fleet
 import paddle.distributed.auto_parallel as auto
 from paddle.distributed.auto_parallel.engine import Engine
+from paddle.optimizer.lr import CosineAnnealingDecay
+from paddle.fluid.dataloader.collate import default_collate_fn

 paddle.enable_static()
 global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
@@ -106,19 +108,18 @@ def train(fetch):
         dropout_ratio=0.1,
         initializer_range=0.02)
     loss = paddle.nn.CrossEntropyLoss()
-    optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001,
-                                                     beta1=0.9,
-                                                     beta2=0.999,
-                                                     epsilon=1e-08,
-                                                     grad_clip=None)
+    scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.00001,
+                                                         T_max=10)
+    optimizer = paddle.optimizer.Adam(learning_rate=scheduler,
+                                      beta1=0.9,
+                                      beta2=0.999,
+                                      epsilon=1e-08,
+                                      grad_clip=None)

     inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
     labels_spec = InputSpec([batch_size], 'int64', 'label')

     dist_strategy = fleet.DistributedStrategy()
-    dist_strategy.amp = False
-    dist_strategy.pipeline = False
-    dist_strategy.recompute = False
     dist_strategy.semi_auto = True
     fleet.init(is_collective=True, strategy=dist_strategy)
......
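The test moves from the legacy `paddle.fluid.optimizer.AdamOptimizer` with a fixed float LR to `paddle.optimizer.Adam` driven by an `LRScheduler`, which is what makes the new `train_logs["lr"]` entry in `fit` meaningful. The pairing in isolation (a dygraph sketch; the throwaway `Linear` layer only exists so Adam has parameters):

```python
import paddle

linear = paddle.nn.Linear(4, 4)
scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=1e-5, T_max=10)
optimizer = paddle.optimizer.Adam(learning_rate=scheduler,
                                  parameters=linear.parameters())
print(optimizer.get_lr())  # reads the current value from the scheduler
```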
@@ -145,7 +145,7 @@ def main():
         labels_spec=labels_spec,
         strategy=dist_strategy)
     engine.prepare(optimizer=optimizer, loss=loss_func)
-    res = engine.fit(train_dataset, batch_size=None)
+    engine.fit(train_dataset, batch_size=None)

     dist_context = engine.dist_context
     block = engine.main_program.global_block()
......
@@ -282,13 +282,16 @@ class TestMLPReshard(unittest.TestCase):
             if op.type == "gelu_grad":
                 op_need_check = op
                 break
-        # print_program_with_dist_attr(dist_main_prog, dist_context)

         # grad op should have dist attr
         self.assertTrue(
             check_backward_dist_attr(dist_context, dist_main_prog,
                                      op_need_check))

+        # clear _g_process_group_map
+        _g_process_group_map.clear()
+        _g_process_group_map[0] = ProcessGroup(0, [])
+
     def test_mlp_pp(self):
         global _global_parallel_strategy
         _global_parallel_strategy = "pp"
@@ -305,29 +308,35 @@ class TestMLPReshard(unittest.TestCase):
         rank_id = 1
         dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
             train_program, startup_program, dist_context, rank_id)
-        for key in list(_g_process_group_map.keys()):
-            del _g_process_group_map[key]
-        _g_process_group_map[0] = ProcessGroup(0, [])
         resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                               dist_context, dist_params_grads)
         resharder.reshard()

         # check send and recv result
         self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))

         # parameter initialization of every rank should be different in the pipeline scene
         self.assertTrue(check_initialization(dist_startup_prog, rank_id))

+        # clear _g_process_group_map
+        _g_process_group_map.clear()
+        _g_process_group_map[0] = ProcessGroup(0, [])
+
     def test_mlp_pp_diff_process_mesh(self):
+        global _global_parallel_strategy
+        _global_parallel_strategy = "pp"
+        global _global_process_mesh
+        _global_process_mesh = auto.ProcessMesh(mesh=[0, 1])
+        global PP_MESH_0
+        PP_MESH_0 = auto.ProcessMesh(mesh=[0])
+        global PP_MESH_1
+        PP_MESH_1 = auto.ProcessMesh(mesh=[1])
         train_program = paddle.static.Program()
         startup_program = paddle.static.Program()
         dist_context = DistributedContext()
         rank_id = 1
         dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog(
             train_program, startup_program, dist_context, rank_id, True)
-        for key in list(_g_process_group_map.keys()):
-            del _g_process_group_map[key]
-        _g_process_group_map[0] = ProcessGroup(0, [])
         resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                               dist_context, dist_params_grads)
         resharder.reshard()
@@ -335,6 +344,10 @@ class TestMLPReshard(unittest.TestCase):
         self.assertTrue(check_send_recv_result(dist_main_prog, rank_id))
         self.assertTrue(check_initialization(dist_startup_prog, rank_id))

+        # clear _g_process_group_map
+        _g_process_group_map.clear()
+        _g_process_group_map[0] = ProcessGroup(0, [])
+
     def test_mlp_dp(self):
         global _global_parallel_strategy
         _global_parallel_strategy = "dp"
@@ -350,12 +363,16 @@ class TestMLPReshard(unittest.TestCase):
         resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id,
                               dist_context, dist_params_grads)
         resharder.reshard()

         # send and recv should not exist in dp scene.
         self.assertFalse(check_send_recv_result(dist_main_prog, rank_id))

         # all parameters should be initialized in dp scene
         self.assertTrue(check_initialization_for_dp(dist_startup_prog))

+        # clear _g_process_group_map
+        _g_process_group_map.clear()
+        _g_process_group_map[0] = ProcessGroup(0, [])
+

 if __name__ == "__main__":
     unittest.main()
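Each test now ends with the same two-line reset of the global process-group registry, so state does not leak into the next test. A hypothetical helper (not in the patch) that the repeated teardown could be factored into:

```python
def reset_process_group_map(group_map, process_group_cls):
    """Hypothetical refactor of the teardown above. `group_map` and
    `process_group_cls` stand for _g_process_group_map and ProcessGroup
    from paddle.distributed.auto_parallel.process_group."""
    group_map.clear()
    group_map[0] = process_group_cls(0, [])  # re-seed the default (empty) group
```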