Unverified commit 8c7cb3d6, authored by zhaoyingli, committed by GitHub

[AutoParallel] engine.prepare only once (#43093)

* prepare only once

Parent: 7ba843e6
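In short: `Engine.prepare` drops its `mode` argument and is now called exactly once; internally it builds, plans, parallelizes, and initializes the programs for all of 'train', 'eval', and 'predict'. A condensed before/after sketch of the user-facing change, assuming `engine`, `optimizer`, `loss`, and the datasets are constructed as in the updated test at the bottom of this diff:

```python
# Before: one prepare() call per execution mode.
# engine.prepare(optimizer, loss)                # implicitly mode='train'
# engine.prepare(optimizer, loss, mode='eval')
# engine.prepare(mode='predict')

# After: a single prepare() covers train, eval, and predict.
engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
engine.fit(train_dataset, batch_size=batch_size)
engine.evaluate(eval_dataset, batch_size)
engine.predict(test_dataset, batch_size)
```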
@@ -343,8 +343,12 @@ class DistributedContext:
         self._serial_startup_program = self._original_serial_startup_program
         if not self._serial_loss:
             if isinstance(self._original_serial_loss, list):
-                assert len(self._original_serial_loss) == 1
-                self._serial_loss = self._original_serial_loss[0]
+                if len(self._original_serial_loss) == 1:
+                    self._serial_loss = self._original_serial_loss[0]
+                elif len(self._original_serial_loss) == 0:
+                    self._serial_loss = self._original_serial_loss
+                else:
+                    raise ValueError("multi loss vars are not supported.")
             else:
                 self._serial_loss = self._original_serial_loss
         if not self._serial_optimizer:
...
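For reference, the loss-normalization branch added above, restated as a standalone sketch (the function name is hypothetical; `DistributedContext` applies the same branches to `self._original_serial_loss` in place):

```python
def normalize_serial_loss(original_serial_loss):
    # Hypothetical standalone version of the branch added above.
    if isinstance(original_serial_loss, list):
        if len(original_serial_loss) == 1:
            return original_serial_loss[0]  # unwrap the single loss var
        elif len(original_serial_loss) == 0:
            return original_serial_loss     # no loss, e.g. a predict-only program
        else:
            raise ValueError("multi loss vars are not supported.")
    return original_serial_loss
```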
@@ -19,7 +19,7 @@ from collections import defaultdict
 import paddle
 import paddle.distributed.auto_parallel as auto
-from paddle import fluid
+from paddle import fluid, static
 from paddle.io import Dataset
 from paddle.metric import Metric
 from paddle.static import InputSpec
@@ -71,8 +71,8 @@ class Engine:
         self._logger = get_logger(logging.INFO)
         self._default_strategy = None
-        self._orig_main_prog = fluid.default_main_program()
-        self._orig_startup_prog = fluid.default_startup_program()
+        self._orig_main_prog = static.default_main_program()
+        self._orig_startup_prog = static.default_startup_program()
         self._orig_dist_context = get_default_distributed_context()
         self._dist_contexts = {}
         self._serial_main_progs = {}
@@ -87,28 +87,131 @@ class Engine:
                 loss=None,
                 gradient_scale=True,
                 metrics=None,
-                mode='train',
                 all_ranks=False):
+        if optimizer and not isinstance(optimizer, (
+                paddle.optimizer.Optimizer, paddle.fluid.optimizer.Optimizer)):
+            raise TypeError(
+                "'optimizer' must be an object of class `paddle.optimizer.Optimizer`"
+                " or `paddle.fluid.optimizer.Optimizer`.")
         self._optimizer = optimizer
-        # TODO: check loss type
+
+        if loss and not isinstance(loss,
+                                   paddle.nn.Layer) and not callable(loss):
+            raise TypeError(
+                "'loss' must be a subclass of `paddle.nn.Layer` or a callable function."
+            )
         self._loss = loss
+
+        metrics = metrics or []
+        for metric in to_list(metrics):
+            assert isinstance(metric, Metric), \
+                "{} is not a subclass of Metric".format(
+                    metric.__class__.__name__)
         self._metrics = to_list(metrics)
-        self._mode = mode
         self._gradient_scale = gradient_scale
+
+        self._planned_mode = None
+        self._modes = ['train', 'eval', 'predict']
         # Build forward program
-        self._build(mode)
-        # Do the planning process
-        planner = Planner(mode, self._dist_contexts[mode])
-        planner.plan()
+        self._build()
+
+        # Do the auto-parallel process for each mode
+        for mode in self._modes:
+            # Do the planning process
+            self._plan(mode)
+            # Do the parallel process
+            self._parallel(mode, all_ranks)
+            # Init comm and startup program
+            self._initialize(mode)
+
+    def _build(self):
+        for mode in self._modes:
+            serial_main_prog = self._serial_main_progs.get(mode, None)
+            if serial_main_prog is not None:
+                return
+
+            losses = []
+            metrics = []
+            serial_main_prog = self._orig_main_prog.clone()
+            serial_startup_prog = self._orig_startup_prog.clone()
+            with static.program_guard(serial_main_prog, serial_startup_prog):
+                inputs_spec = self.inputs_spec
+                labels_spec = self.labels_spec if self.labels_spec else []
+                inputs = [s._create_feed_layer() for s in inputs_spec]
+                labels = [s._create_feed_layer() for s in labels_spec]
+                outputs = to_list(self.model(*inputs))
+                if mode != "predict" and self._loss:
+                    losses = to_list(self._loss(*(outputs + labels)))
+
+                if mode != "predict":
+                    for metric in self._metrics:
+                        metrics.extend(
+                            to_list(metric.compute(*(outputs + labels))))
+
+            default_ctx = get_default_distributed_context()
+            if not default_ctx.has_annotation or self._default_strategy:
+                inputs = [self._set_data_parallel(var) for var in inputs]
+                labels = [self._set_data_parallel(var) for var in labels]
+
+            feed_vars = {"inputs": inputs, "labels": labels}
+
+            fetch_vars = {
+                "outputs": flatten(outputs),
+                "loss": losses,
+                "metrics": metrics
+            }
+
+            self._dist_contexts[mode] = DistributedContext(
+                serial_main_prog, serial_startup_prog, self._optimizer, losses,
+                feed_vars, fetch_vars, self.cluster, self.strategy)
+            self._dist_contexts[mode].gradient_scale = self._gradient_scale
+
+    def _plan(self, mode):
+        if self._planned_mode is None:
+            self._planned_mode = mode
+        else:
+            self._init_dist_context(mode)
+
+        self.planner = Planner(mode, self._dist_contexts[mode])
+        self.planner.plan()
+
+    def _parallel(self, mode, all_ranks):
         # Parallelize program based on the planner's results
         # For now, the completer has to be passed to the planner,
         # because we may use it to complete the annotation of the backward and update.
-        parallelizer = Parallelizer(mode, planner.completer,
+        parallelizer = Parallelizer(mode, self.planner.completer,
                                     self._dist_contexts[mode])
         if not all_ranks:
             parallelizer.parallel(self._cur_rank)
         else:
             parallelizer.parallel_all()
+
+    def _init_dist_context(self, mode):
+        # Init the dist_context of this mode from the first planned dist_context
+        # to guarantee that the train/eval/predict modes share one parallel strategy.
+        dist_context = self._dist_contexts[mode]
+        origin_main_prog = dist_context._original_serial_main_program
+        ref_mode = self._planned_mode
+        ref_dist_context = self._dist_contexts[ref_mode]
+        ref_origin_main_prog = ref_dist_context._original_serial_main_program
+        ref_blocks = ref_origin_main_prog.blocks
+        for ib, block in enumerate(origin_main_prog.blocks):
+            for iop, op in enumerate(block.ops):
+                ref_op = ref_blocks[ib].ops[iop]
+                assert op.type == ref_op.type, \
+                    "'{}' mode op '{}' is different from '{}' mode op '{}'.".format(
+                        mode, op.type, ref_mode, ref_op.type)
+                ref_op_dist_attr = ref_dist_context.get_op_dist_attr_for_program(
+                    ref_op)
+                dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)
+
+    def _initialize(self, mode):
         # Get the current content from the distributed context
         self._serial_main_progs[mode] = self._dist_contexts[
             mode].serial_main_program
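The restructured `prepare` now builds the serial programs for every mode up front, then plans, parallelizes, and initializes each mode in turn. Only the first mode is actually planned; later modes reuse its decisions through `_init_dist_context`, which copies the per-op distributed attributes from the first planned mode so that train/eval/predict share one parallel strategy. Schematically (a sketch of the control flow, not the literal code):

```python
# Sketch of the new single-call prepare() flow.
def prepare_once(engine, all_ranks=False):
    engine._build()                        # serial programs for all modes
    for mode in ['train', 'eval', 'predict']:
        engine._plan(mode)                 # first mode plans; later modes copy
                                           # its op dist_attrs in _init_dist_context
        engine._parallel(mode, all_ranks)  # rewrite the program for this rank
        engine._initialize(mode)           # set up comm and run startup program
```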
@@ -120,52 +223,7 @@ class Engine:
             mode].dist_startup_programs
         self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars
         self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars
-        # Init comm and startup program
-        self._initialize(mode)
-
-    def _build(self, mode):
-        serial_main_prog = self._serial_main_progs.get(mode, None)
-        if serial_main_prog is not None:
-            return
-
-        losses = []
-        metrics = []
-        serial_main_prog = self._orig_main_prog.clone()
-        serial_startup_prog = self._orig_startup_prog.clone()
-        with fluid.program_guard(serial_main_prog, serial_startup_prog):
-            inputs_spec = self.inputs_spec
-            labels_spec = self.labels_spec if self.labels_spec else []
-            inputs = [s._create_feed_layer() for s in inputs_spec]
-            labels = [s._create_feed_layer() for s in labels_spec]
-            outputs = to_list(self.model(*inputs))
-            if mode != "predict" and self._loss:
-                losses = to_list(self._loss(*(outputs + labels)))
-
-        default_ctx = get_default_distributed_context()
-        if not default_ctx.has_annotation or self._default_strategy:
-            inputs = [self._set_data_parallel(var) for var in inputs]
-            labels = [self._set_data_parallel(var) for var in labels]
-
-        feed_vars = {"inputs": inputs, "labels": labels}
-
-        fetch_vars = {
-            "outputs": flatten(outputs),
-            "loss": losses,
-            "metrics": metrics
-        }
-
-        self._dist_contexts[mode] = DistributedContext(
-            serial_main_prog, serial_startup_prog, self._optimizer, losses,
-            feed_vars, fetch_vars, self.cluster, self.strategy)
-        self._dist_contexts[mode].gradient_scale = self._gradient_scale
-
-    def _initialize(self, mode):
         if self._nranks > 1:
             # Traverse different rank programs and traverse each op of them,
             # instantiate communication by process_mapping.
@@ -203,7 +261,7 @@ class Engine:
         # TODO: evaluate after training
         self.mode = 'train'
         assert self.mode in self._dist_main_progs, \
-            "train model is not ready, please call `engine.prepare(mode='train')` first."
+            "train model is not ready, please call `engine.prepare()` first."
         train_dataloader = self._create_dataloader(train_data, batch_size,
                                                    epochs, steps_per_epoch)
@@ -227,16 +285,19 @@ class Engine:
                  return_numpy=True):
         self.mode = 'eval'
         assert self.mode in self._dist_main_progs, \
-            "eval model is not ready, please call `engine.prepare(mode='eval')` first."
+            "eval model is not ready, please call `engine.prepare()` first."
         eval_dataloader = self._create_dataloader(eval_data, batch_size)
 
-        outputs = []
         for step, data in enumerate(eval_dataloader):
-            logs, outs = self._eval_step(data, use_program_cache, return_numpy)
-            outputs.append(outs)
-            predict_logs = {"eval_" + name: val for name, val in logs.items()}
-            self._logger.info(predict_logs)
-        return outputs
+            eval_logs = dict()
+            outs = self._eval_step(data, use_program_cache, return_numpy)
+            eval_logs["eval_loss"] = outs[0] if len(outs) > 0 else []
+            for metric in self._metrics:
+                results = metric.accumulate()
+                for i, res in enumerate(to_list(results)):
+                    eval_logs["eval_" + metric.name()[i]] = res
+            self._logger.info(eval_logs)
+        return eval_logs
 
     def predict(self,
                 test_data,
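Evaluation now goes through Paddle's `Metric` protocol end to end: `_build` records `metric.compute(*(outputs + labels))` as fetch targets, `_eval_step` feeds the fetched results to `metric.update(...)`, and `evaluate` reads `metric.accumulate()` and logs each value under an `eval_`-prefixed name taken from `metric.name()`. A minimal eager-mode sketch of that protocol with `paddle.metric.Accuracy` (the data is illustrative; the engine runs `compute` inside the static program instead):

```python
import paddle

acc = paddle.metric.Accuracy()  # acc.name() -> ['acc']

# One step: compute per-sample correctness, then update the metric state.
preds = paddle.to_tensor([[0.1, 0.9], [0.8, 0.2]])  # [batch, num_classes]
labels = paddle.to_tensor([[1], [1]])               # [batch, 1]
acc.update(acc.compute(preds, labels).numpy())

# evaluate() side: accumulate and log under an "eval_" prefix.
eval_logs = {}
results = acc.accumulate()  # a float here; a list when topk has several entries
for i, res in enumerate(results if isinstance(results, list) else [results]):
    eval_logs["eval_" + acc.name()[i]] = res
print(eval_logs)  # {'eval_acc': 0.5}
```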
@@ -245,7 +306,7 @@ class Engine:
                 return_numpy=True):
         self.mode = 'predict'
         assert self.mode in self._dist_main_progs, \
-            "predict model is not ready, please call `engine.prepare(mode='predict')` first."
+            "predict model is not ready, please call `engine.prepare()` first."
         test_dataloader = self._create_dataloader(test_data, batch_size)
 
         outputs = []
@@ -262,57 +323,53 @@ class Engine:
     def _train_step(self, data, use_program_cache=False, return_numpy=True):
         logs = {}
-        dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
-        fetch_var = self._fetch_vars[self.mode]["loss"][0]
-        if fetch_var.name not in dist_main_prog.global_block().vars:
-            loss = self._executor.run(dist_main_prog,
-                                      use_program_cache=use_program_cache)
-            logs["loss"] = None
-        else:
-            loss = self._executor.run(dist_main_prog,
-                                      fetch_list=to_list(fetch_var),
-                                      use_program_cache=use_program_cache,
-                                      return_numpy=return_numpy)
-            logs["loss"] = loss
+        fetch_vars = self._fetch_vars[self.mode]["loss"]
+        fetch_list = self._fetch_list(fetch_vars)
+
+        loss = self._executor.run(self.main_program,
+                                  fetch_list=fetch_list,
+                                  use_program_cache=use_program_cache,
+                                  return_numpy=return_numpy)
+        logs["loss"] = loss
         return logs, loss
 
     def _eval_step(self, data, use_program_cache=False, return_numpy=True):
         logs = {}
-        dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
-        fetch_var = self._fetch_vars[self.mode]["loss"][0]
-
-        if fetch_var.name not in dist_main_prog.global_block().vars:
-            outs = self._executor.run(dist_main_prog,
-                                      use_program_cache=use_program_cache)
-            logs["loss"] = outs
-        else:
-            outs = self._executor.run(dist_main_prog,
-                                      fetch_list=fetch_var,
-                                      use_program_cache=use_program_cache,
-                                      return_numpy=return_numpy)
-            logs["loss"] = outs
-        return logs, outs
+        metrics = self._fetch_vars[self.mode]["metrics"]
+        losses = self._fetch_vars[self.mode]["loss"]
+        fetch_loss = self._fetch_list(losses)
+        fetch_metrics = self._fetch_list(metrics)
+        fetch_list = fetch_loss + fetch_metrics
+
+        res = self._executor.run(self.main_program,
+                                 fetch_list=fetch_list,
+                                 use_program_cache=use_program_cache,
+                                 return_numpy=return_numpy)
+        if not res[len(fetch_loss):]:
+            return res[:len(fetch_loss)]
+        for metric in self._metrics:
+            metric.update(*res[len(fetch_loss):])
+        return res[:len(fetch_loss)]
 
     def _predict_step(self, data, use_program_cache=False, return_numpy=True):
         logs = {}
-        dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
-        fetch_var = []
-        for var in self._fetch_vars[self.mode]["outputs"]:
-            if var.name in dist_main_prog.global_block().vars:
-                fetch_var.append(var)
-
-        if fetch_var is []:
-            outs = self._executor.run(dist_main_prog,
-                                      use_program_cache=use_program_cache)
-            logs["pred"] = outs
-        else:
-            outs = self._executor.run(dist_main_prog,
-                                      fetch_list=fetch_var,
-                                      use_program_cache=use_program_cache,
-                                      return_numpy=return_numpy)
-            logs["pred"] = outs
+        fetch_vars = self._fetch_vars[self.mode]["outputs"]
+        fetch_list = self._fetch_list(fetch_vars)
+
+        outs = self._executor.run(self.main_program,
+                                  fetch_list=fetch_list,
+                                  use_program_cache=use_program_cache,
+                                  return_numpy=return_numpy)
+        logs["pred"] = outs
         return logs, outs
 
+    def _fetch_list(self, fetch_vars):
+        fetch_list = []
+        for var in fetch_vars:
+            if var.name in self.main_program.global_block().vars:
+                fetch_list.append(var.name)
+        return fetch_list
+
     def _create_dataloader(self,
                            dataset,
                            batch_size,
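The new `_fetch_list` helper centralizes what the old step methods each checked inline (`fetch_var.name not in dist_main_prog.global_block().vars`): after parallelization, the program held by a given rank may not contain every serial variable, since some live only on other ranks, so only the names present in this rank's global block are fetched. A standalone sketch of the same filtering (the function name is hypothetical):

```python
def fetch_list_for(program, fetch_vars):
    # Keep only the vars that exist in this rank's program; a sharded
    # program does not necessarily hold every serial variable.
    block_vars = program.global_block().vars
    return [var.name for var in fetch_vars if var.name in block_vars]
```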
@@ -323,7 +380,9 @@ class Engine:
         dist_context = self._dist_contexts[self.mode]
         dist_main_block = dist_main_prog.global_block()
-        # get feed_list from dist_program
+        # NOTE: Get feed_list from the dist_program, then insert the dataloader op
+        # with the sharded var shape. Because the predict_program does not contain
+        # the labels var, the dataset's values are filtered by the length of feed_list.
         inputs_var = self._feed_vars[self.mode]["inputs"]
         labels_var = self._feed_vars[self.mode]["labels"]
         feed_list = []
@@ -342,7 +401,7 @@ class Engine:
         # insert read op at the end of program
         places = paddle.static.cuda_places()
-        with fluid.program_guard(dist_main_prog, dist_startup_prog):
+        with static.program_guard(dist_main_prog, dist_startup_prog):
             dataloader = NonIterableGeneratorLoader(
                 dataset,
                 feed_list,
@@ -468,10 +527,6 @@ class Engine:
     def mode(self, mode):
         self._mode = mode
 
-    @property
-    def metrics(self):
-        return self._metrics
-
     @property
     def main_program(self):
         return self._dist_main_progs[self.mode][self._cur_rank]
...
@@ -107,7 +107,6 @@ def train():
         epsilon=1e-08,
         grad_clip=None)
 
-    dataset = MyDataset(batch_num * batch_size)
     inputs_spec = InputSpec([batch_size, hidden_size], 'float32', 'x')
     labels_spec = InputSpec([batch_size], 'int64', 'label')
@@ -119,23 +118,29 @@ def train():
     dist_strategy.semi_auto = True
     fleet.init(is_collective=True, strategy=dist_strategy)
 
+    # init engine
     engine = Engine(
         mlp,
         inputs_spec=inputs_spec,
         labels_spec=labels_spec,
         strategy=dist_strategy)
-    engine.prepare(optimizer, loss)
-    engine.fit(dataset,
+    engine.prepare(optimizer, loss, metrics=paddle.metric.Accuracy())
+
+    # train
+    train_dataset = MyDataset(batch_num * batch_size)
+    engine.fit(train_dataset,
                batch_size=batch_size,
                steps_per_epoch=batch_num * batch_size)
 
+    # eval
     eval_dataset = MyDataset(batch_size)
-    engine.prepare(optimizer, loss, mode='eval')
     engine.evaluate(eval_dataset, batch_size)
 
+    # predict
     test_dataset = MyDataset(batch_size)
-    engine.prepare(mode='predict')
     engine.predict(test_dataset, batch_size)
 
+    # save
     engine.save('./mlp_inf', training=False, mode='predict')
...