Unverified · Commit 8c7cb3d6 authored by zhaoyingli, committed by GitHub

[AutoParallel] engine.prepare only once (#43093)

* prepare only once
Parent 7ba843e6
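With this change, `Engine.prepare` no longer takes a `mode` argument: a single call builds, plans, parallelizes, and initializes the programs for all of the `train`, `eval`, and `predict` modes, so the per-mode `prepare(..., mode=...)` calls are gone. A minimal usage sketch, mirroring the updated test at the bottom of this diff (the `mlp` model, `MyDataset`, `optimizer`, `loss`, and the size constants are assumed to be defined as in that test):

```python
import paddle
from paddle.distributed import fleet
from paddle.static import InputSpec
from paddle.distributed.auto_parallel.engine import Engine

inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
labels_spec = InputSpec([batch_size], 'int64', 'label')

dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)

engine = Engine(mlp,
                inputs_spec=inputs_spec,
                labels_spec=labels_spec,
                strategy=dist_strategy)
# prepare once: the train/eval/predict programs are all built, planned,
# parallelized and initialized here
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())

engine.fit(MyDataset(batch_num * batch_size), batch_size=batch_size)
engine.evaluate(MyDataset(batch_size), batch_size)
engine.predict(MyDataset(batch_size), batch_size)
engine.save('./mlp_inf', training=False, mode='predict')
```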
@@ -343,8 +343,12 @@ class DistributedContext:
self._serial_startup_program = self._original_serial_startup_program
if not self._serial_loss:
if isinstance(self._original_serial_loss, list):
assert len(self._original_serial_loss) == 1
self._serial_loss = self._original_serial_loss[0]
if len(self._original_serial_loss) == 1:
self._serial_loss = self._original_serial_loss[0]
elif len(self._original_serial_loss) == 0:
self._serial_loss = self._original_serial_loss
else:
raise ValueError("multi loss vars are not supported.")
else:
self._serial_loss = self._original_serial_loss
if not self._serial_optimizer:
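The branch above relaxes the old `assert len(...) == 1`: the serial loss handed to `DistributedContext` may now be a single variable, a one-element list (which is unwrapped), or an empty list (the predict case, where no loss is built); only a list with more than one element is rejected. A plain-Python mirror of that normalization (hypothetical helper name, not Paddle code):

```python
def normalize_serial_loss(original_serial_loss):
    """Mirror of the new loss-normalization branch in DistributedContext."""
    if isinstance(original_serial_loss, list):
        if len(original_serial_loss) == 1:
            return original_serial_loss[0]   # unwrap the single loss var
        elif len(original_serial_loss) == 0:
            return original_serial_loss      # predict mode: keep the empty list
        else:
            raise ValueError("multi loss vars are not supported.")
    return original_serial_loss              # already a single var

assert normalize_serial_loss(["loss_var"]) == "loss_var"
assert normalize_serial_loss([]) == []
```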
@@ -19,7 +19,7 @@ from collections import defaultdict
import paddle
import paddle.distributed.auto_parallel as auto
from paddle import fluid
from paddle import fluid, static
from paddle.io import Dataset
from paddle.metric import Metric
from paddle.static import InputSpec
@@ -71,8 +71,8 @@ class Engine:
self._logger = get_logger(logging.INFO)
self._default_strategy = None
self._orig_main_prog = fluid.default_main_program()
self._orig_startup_prog = fluid.default_startup_program()
self._orig_main_prog = static.default_main_program()
self._orig_startup_prog = static.default_startup_program()
self._orig_dist_context = get_default_distributed_context()
self._dist_contexts = {}
self._serial_main_progs = {}
@@ -87,28 +87,131 @@ class Engine:
loss=None,
gradient_scale=True,
metrics=None,
mode='train',
all_ranks=False):
if optimizer and not isinstance(optimizer, (
paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)):
raise TypeError(
"'optimizer' must be object of class `paddle.optimizer.Optimizer`" \
" or `paddle.fluid.optimizer.Optimizer`."
)
self._optimizer = optimizer
# TODO: check loss type
if loss and not isinstance(loss,
paddle.nn.Layer) and not callable(loss):
raise TypeError(
"'loss' must be sub classes of `paddle.nn.Layer` or any callable function."
)
self._loss = loss
metrics = metrics or []
for metric in to_list(metrics):
assert isinstance(metric, Metric), \
"{} is not a subclass of Metric".format(
metric.__class__.__name__)
self._metrics = to_list(metrics)
self._mode = mode
self._gradient_scale = gradient_scale
self._planned_mode = None
self._modes = ['train', 'eval', 'predict']
# Build forward program
self._build(mode)
# Do the planning process
planner = Planner(mode, self._dist_contexts[mode])
planner.plan()
self._build()
# Do auto parallel process
for mode in self._modes:
# Do the planning process
self._plan(mode)
# Do the parallel process
self._parallel(mode, all_ranks)
# Init comm and startup program
self._initialize(mode)
def _build(self):
for mode in self._modes:
serial_main_prog = self._serial_main_progs.get(mode, None)
if serial_main_prog is not None:
return
losses = []
metrics = []
serial_main_prog = self._orig_main_prog.clone()
serial_startup_prog = self._orig_startup_prog.clone()
with static.program_guard(serial_main_prog, serial_startup_prog):
inputs_spec = self.inputs_spec
labels_spec = self.labels_spec if self.labels_spec else []
inputs = [s._create_feed_layer() for s in inputs_spec]
labels = [s._create_feed_layer() for s in labels_spec]
outputs = to_list(self.model(*inputs))
if mode != "predict" and self._loss:
losses = to_list(self._loss(*(outputs + labels)))
if mode != "predict":
for metric in self._metrics:
metrics.extend(
to_list(metric.compute(*(outputs + labels))))
default_ctx = get_default_distributed_context()
if not default_ctx.has_annotation or self._default_strategy:
inputs = [self._set_data_parallel(var) for var in inputs]
labels = [self._set_data_parallel(var) for var in labels]
# self._feed_vars[mode] = {"inputs": inputs, "labels": labels}
feed_vars = {"inputs": inputs, "labels": labels}
# self._fetch_vars[mode] = {
# "outputs": flatten(outputs),
# "loss": losses,
# "metrics": metrics
# }
fetch_vars = {
"outputs": flatten(outputs),
"loss": losses,
"metrics": metrics
}
self._dist_contexts[mode] = DistributedContext(
serial_main_prog, serial_startup_prog, self._optimizer, losses,
feed_vars, fetch_vars, self.cluster, self.strategy)
self._dist_contexts[mode].gradient_scale = self._gradient_scale
def _plan(self, mode):
if self._planned_mode is None:
self._planned_mode = mode
else:
self._init_dist_context(mode)
self.planner = Planner(mode, self._dist_contexts[mode])
self.planner.plan()
def _parallel(self, mode, all_ranks):
# Parallelize program based on the planner's results
# For now, the completer has to be passed to the planner,
# because we may use it to complete the annotations of the backward and update passes.
parallelizer = Parallelizer(mode, planner.completer,
parallelizer = Parallelizer(mode, self.planner.completer,
self._dist_contexts[mode])
if not all_ranks:
parallelizer.parallel(self._cur_rank)
else:
parallelizer.parallel_all()
def _init_dist_context(self, mode):
# Init dist_context[mode] with the first planned dist_context
# to guarantee that the train/eval/predict modes share the same parallel strategy
dist_context = self._dist_contexts[mode]
origin_main_prog = dist_context._original_serial_main_program
ref_mode = self._planned_mode
ref_dist_context = self._dist_contexts[ref_mode]
ref_origin_main_prog = ref_dist_context._original_serial_main_program
ref_blocks = ref_origin_main_prog.blocks
for ib, block in enumerate(origin_main_prog.blocks):
for iop, op in enumerate(block.ops):
ref_op = ref_blocks[ib].ops[iop]
assert op.type == ref_op.type, \
"'{}' mode op '{}' is different from '{}' mode op '{}'.".format(mode, op.type, ref_mode, ref_op.type)
ref_op_dist_attr = ref_dist_context.get_op_dist_attr_for_program(
ref_op)
dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)
def _initialize(self, mode):
# Get the current content from the distributed context
self._serial_main_progs[mode] = self._dist_contexts[
mode].serial_main_program
@@ -120,52 +223,7 @@ class Engine:
mode].dist_startup_programs
self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars
self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars
# Init comm and startup program
self._initialize(mode)
def _build(self, mode):
serial_main_prog = self._serial_main_progs.get(mode, None)
if serial_main_prog is not None:
return
losses = []
metrics = []
serial_main_prog = self._orig_main_prog.clone()
serial_startup_prog = self._orig_startup_prog.clone()
with fluid.program_guard(serial_main_prog, serial_startup_prog):
inputs_spec = self.inputs_spec
labels_spec = self.labels_spec if self.labels_spec else []
inputs = [s._create_feed_layer() for s in inputs_spec]
labels = [s._create_feed_layer() for s in labels_spec]
outputs = to_list(self.model(*inputs))
if mode != "predict" and self._loss:
losses = to_list(self._loss(*(outputs + labels)))
default_ctx = get_default_distributed_context()
if not default_ctx.has_annotation or self._default_strategy:
inputs = [self._set_data_parallel(var) for var in inputs]
labels = [self._set_data_parallel(var) for var in labels]
# self._feed_vars[mode] = {"inputs": inputs, "labels": labels}
feed_vars = {"inputs": inputs, "labels": labels}
# self._fetch_vars[mode] = {
# "outputs": flatten(outputs),
# "loss": losses,
# "metrics": metrics
# }
fetch_vars = {
"outputs": flatten(outputs),
"loss": losses,
"metrics": metrics
}
self._dist_contexts[mode] = DistributedContext(
serial_main_prog, serial_startup_prog, self._optimizer, losses,
feed_vars, fetch_vars, self.cluster, self.strategy)
self._dist_contexts[mode].gradient_scale = self._gradient_scale
def _initialize(self, mode):
if self._nranks > 1:
# Traverse the programs of each rank and every op in them,
# instantiating communication according to the process mapping.
@@ -203,7 +261,7 @@ class Engine:
# TODO: evaluate after training
self.mode = 'train'
assert self.mode in self._dist_main_progs, \
"train model is not ready, please call `engine.prepare(mode='train')` first."
"train model is not ready, please call `engine.prepare()` first."
train_dataloader = self._create_dataloader(train_data, batch_size,
epochs, steps_per_epoch)
@@ -227,16 +285,19 @@ class Engine:
return_numpy=True):
self.mode = 'eval'
assert self.mode in self._dist_main_progs, \
"eval model is not ready, please call `engine.prepare(mode='eval')` first."
"eval model is not ready, please call `engine.prepare()` first."
eval_dataloader = self._create_dataloader(eval_data, batch_size)
outputs = []
for step, data in enumerate(eval_dataloader):
logs, outs = self._eval_step(data, use_program_cache, return_numpy)
outputs.append(outs)
predict_logs = {"eval_" + name: val for name, val in logs.items()}
self._logger.info(predict_logs)
return outputs
eval_logs = dict()
outs = self._eval_step(data, use_program_cache, return_numpy)
eval_logs["eval_loss"] = outs[0] if len(outs) > 0 else []
for metric in self._metrics:
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
eval_logs["eval_" + metric.name()[i]] = res
self._logger.info(eval_logs)
return eval_logs
def predict(self,
test_data,
@@ -245,7 +306,7 @@ class Engine:
return_numpy=True):
self.mode = 'predict'
assert self.mode in self._dist_main_progs, \
"predict model is not ready, please call `engine.prepare(mode='predict')` first."
"predict model is not ready, please call `engine.prepare()` first."
test_dataloader = self._create_dataloader(test_data, batch_size)
outputs = []
@@ -262,57 +323,53 @@ class Engine:
def _train_step(self, data, use_program_cache=False, return_numpy=True):
logs = {}
dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
fetch_var = self._fetch_vars[self.mode]["loss"][0]
if fetch_var.name not in dist_main_prog.global_block().vars:
loss = self._executor.run(dist_main_prog,
use_program_cache=use_program_cache)
logs["loss"] = None
else:
loss = self._executor.run(dist_main_prog,
fetch_list=to_list(fetch_var),
use_program_cache=use_program_cache,
return_numpy=return_numpy)
logs["loss"] = loss
fetch_vars = self._fetch_vars[self.mode]["loss"]
fetch_list = self._fetch_list(fetch_vars)
loss = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=use_program_cache,
return_numpy=return_numpy)
logs["loss"] = loss
return logs, loss
def _eval_step(self, data, use_program_cache=False, return_numpy=True):
logs = {}
dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
fetch_var = self._fetch_vars[self.mode]["loss"][0]
if fetch_var.name not in dist_main_prog.global_block().vars:
outs = self._executor.run(dist_main_prog,
use_program_cache=use_program_cache)
logs["loss"] = outs
else:
outs = self._executor.run(dist_main_prog,
fetch_list=fetch_var,
use_program_cache=use_program_cache,
return_numpy=return_numpy)
logs["loss"] = outs
return logs, outs
metrics = self._fetch_vars[self.mode]["metrics"]
losses = self._fetch_vars[self.mode]["loss"]
fetch_loss = self._fetch_list(losses)
fetch_metrics = self._fetch_list(metrics)
fetch_list = fetch_loss + fetch_metrics
res = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=use_program_cache,
return_numpy=return_numpy)
if not res[len(fetch_loss):]:
return res[:len(fetch_loss)]
for metric in self._metrics:
metric.update(*res[len(fetch_loss):])
return res[:len(fetch_loss)]
def _predict_step(self, data, use_program_cache=False, return_numpy=True):
logs = {}
dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
fetch_var = []
for var in self._fetch_vars[self.mode]["outputs"]:
if var.name in dist_main_prog.global_block().vars:
fetch_var.append(var)
if fetch_var is []:
outs = self._executor.run(dist_main_prog,
use_program_cache=use_program_cache)
logs["pred"] = outs
else:
outs = self._executor.run(dist_main_prog,
fetch_list=fetch_var,
use_program_cache=use_program_cache,
return_numpy=return_numpy)
logs["pred"] = outs
fetch_vars = self._fetch_vars[self.mode]["outputs"]
fetch_list = self._fetch_list(fetch_vars)
outs = self._executor.run(self.main_program,
fetch_list=fetch_list,
use_program_cache=use_program_cache,
return_numpy=return_numpy)
logs["pred"] = outs
return logs, outs
def _fetch_list(self, fetch_vars):
fetch_list = []
for var in fetch_vars:
if var.name in self.main_program.global_block().vars:
fetch_list.append(var.name)
return fetch_list
def _create_dataloader(self,
dataset,
batch_size,
@@ -323,7 +380,9 @@ class Engine:
dist_context = self._dist_contexts[self.mode]
dist_main_block = dist_main_prog.global_block()
# get feed_list from dist_program
# NOTE: Get feed_list from the dist_program, then insert the dataloader op
# with the sharded var shapes. Because the predict program does not contain
# the labels var, we filter the dataset's values by the length of feed_list.
inputs_var = self._feed_vars[self.mode]["inputs"]
labels_var = self._feed_vars[self.mode]["labels"]
feed_list = []
@@ -342,7 +401,7 @@ class Engine:
# insert read op at the end of program
places = paddle.static.cuda_places()
with fluid.program_guard(dist_main_prog, dist_startup_prog):
with static.program_guard(dist_main_prog, dist_startup_prog):
dataloader = NonIterableGeneratorLoader(
dataset,
feed_list,
@@ -468,10 +527,6 @@ class Engine:
def mode(self, mode):
self._mode = mode
@property
def metrics(self):
return self._metrics
@property
def main_program(self):
return self._dist_main_progs[self.mode][self._cur_rank]
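One consequence of preparing every mode up front is that a given rank's partitioned program may not contain every serial fetch target (for example, the loss or metric vars may live on other ranks after parallelization). The new `_fetch_list` helper above therefore filters the fetch targets by the names actually present in `self.main_program`, and `_train_step`, `_eval_step`, and `_predict_step` all go through it instead of probing a single fetch var. A standalone illustration of that filtering idea (names are made up for the example):

```python
def filter_fetch_list(fetch_var_names, program_var_names):
    """Keep only the fetch targets that exist in this rank's program."""
    return [name for name in fetch_var_names if name in program_var_names]

rank0_vars = {"x", "label", "mean_0.tmp_0"}   # this rank holds the loss var
rank1_vars = {"x", "label"}                   # another rank does not
print(filter_fetch_list(["mean_0.tmp_0"], rank0_vars))  # ['mean_0.tmp_0']
print(filter_fetch_list(["mean_0.tmp_0"], rank1_vars))  # []
```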
@@ -107,7 +107,6 @@ def train():
epsilon=1e-08,
grad_clip=None)
dataset = MyDataset(batch_num * batch_size)
inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
labels_spec = InputSpec([batch_size], 'int64', 'label')
@@ -119,23 +118,29 @@ def train():
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
# init engine
engine = Engine(
mlp,
inputs_spec=inputs_spec,
labels_spec=labels_spec,
strategy=dist_strategy)
engine.prepare(optimizer, loss)
engine.fit(dataset,
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
# train
train_dataset = MyDataset(batch_num * batch_size)
engine.fit(train_dataset,
batch_size=batch_size,
steps_per_epoch=batch_num * batch_size)
# eval
eval_dataset = MyDataset(batch_size)
engine.prepare(optimizer, loss, mode='eval')
engine.evaluate(eval_dataset, batch_size)
# predict
test_dataset = MyDataset(batch_size)
engine.prepare(mode='predict')
engine.predict(test_dataset, batch_size)
# save
engine.save('./mlp_inf', training=False, mode='predict')
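Note that the updated `evaluate` no longer returns the raw per-step outputs: `_eval_step` feeds the fetched metric tensors into `Metric.update`, and `evaluate` assembles a log dict from the loss plus the accumulated metric results, keyed by `Metric.name()`. A hedged sketch of the expected return shape for the `Accuracy` metric used above (whose default name is `acc`, so the key is presumably `eval_acc`):

```python
eval_logs = engine.evaluate(MyDataset(batch_size), batch_size)
# roughly: {'eval_loss': [array(...)], 'eval_acc': 0.5}
```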