Unverified commit c182e5dd, authored by Yulong Ao, committed by GitHub

[Auto Parallel] Make Engine class callable (#46416)

* [Auto Parallel] Improve the user-defined fetches and logging

* [Auto Parallel] Make Engine class callable

* [Auto Parallel] Update the data loading of tuner
Parent 55accdfc
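In short: Engine gains a __call__ that runs a single step of the current mode's program, alongside a new public dataloader() method. A minimal sketch of the resulting usage, following the train_callable test added at the bottom of this diff:

    engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy)
    train_dataloader = engine.dataloader(train_dataset, batch_size=batch_size, mode="train")
    for _ in train_dataloader:
        outs = engine(mode="train")  # one executor step; returns the fetched results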
......@@ -34,7 +34,7 @@ class DistributedDataLoader(metaclass=abc.ABCMeta):
self.dataset = dataset
self.epochs = epochs
self.drop_lost = drop_last
self.drop_last = drop_last
if batch_size is None:
self.batch_size = None
......@@ -105,7 +105,7 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
self.collate_fn = collate_fn or default_convert_fn
self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset, self.auto_collate_batch,
self.collate_fn, self.drop_lost)
self.collate_fn, self.drop_last)
self._steps = self._infer_steps()
self._inner_dataloader = self._create_inner_dataloader()
......@@ -153,7 +153,7 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset,
self.auto_collate_batch, self.collate_fn,
self.drop_lost)
self.drop_last)
break
partial_data = []
......
......@@ -49,7 +49,7 @@ from .utils import get_logger, get_dist_attr
from .process_group import new_process_group, get_all_process_groups
from .dist_context import DistributedContext, get_default_distributed_context
from .strategy import Strategy
from .interface import _get_fetches
from .interface import CollectionNames, get_collection
class Engine:
......@@ -197,7 +197,7 @@ class Engine:
self._dygraph_mode = False
self._tuning = self._strategy.tuning
def _prepare_single_mode(self, mode):
def _prepare_program(self, mode):
# Do the build process
self._build(mode)
# Do the planning process
......@@ -208,6 +208,62 @@ class Engine:
self._initialize(mode)
self._mode_init_states[mode] = True
def _prepare_feed(self, user_feeds=None, mode="train"):
if user_feeds is not None:
assert isinstance(user_feeds, dict), \
"user_feeds must be a dict, but receive {}".format(type(user_feeds).__name__)
feeds = {}
# TODO: add inputs and labels feed dict
for name, var in get_collection(CollectionNames.FEEDS):
assert name is not None, "No name defined for feed var"
feeds[name] = var
if user_feeds is not None:
for name, var in user_feeds.items():
feeds[name] = var
return feeds
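The merge order in _prepare_feed matters: entries registered in the FEEDS collection are inserted first and user_feeds afterwards, so a user-supplied value overrides a collection entry with the same name. Roughly equivalent to this sketch (illustrative, not part of the diff):

    feeds = dict(get_collection(CollectionNames.FEEDS))  # [(name, var), ...] -> {name: var}
    feeds.update(user_feeds or {})                       # user feeds win on name clashes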
def _prepare_fetch(self, user_fetches=None, mode="train"):
if user_fetches is not None:
assert isinstance(user_fetches, list), \
"user_fetches must be a list, but receive {}".format(type(user_fetches).__name__)
fetch_names = []
fetch_new_names = []
fetch_sections = {}
cnt = 0
def _process_section(section_name, var_list):
nonlocal cnt
section_start = cnt
for var in var_list:
new_name = None
# Rename the loss
if section_name == "loss":
new_name = "loss"
if isinstance(var, tuple):
assert len(var) == 2, "Length of tuple {} must be 2".format(
var)
new_name, var = var
if self._is_local_var(var) and var.name not in fetch_names:
fetch_names.append(var.name)
fetch_new_names.append(var.name)
cnt += 1
if self._is_local_var(var) and new_name is not None:
fetch_new_names[fetch_names.index(var.name)] = new_name
section_end = cnt
fetch_sections[section_name] = (section_start, section_end)
for name, var_list in self._fetch_vars[mode].items():
if name == "loss" and mode != "predict":
_process_section("loss", var_list)
if name == "metrics" and mode != "predict":
_process_section("metrics", var_list)
if name == "outputs" and mode == "predict":
_process_section("metrics", var_list)
var_list = (get_collection(CollectionNames.FETCHES)
or []) + (user_fetches or [])
_process_section("user_fetches", var_list)
return fetch_names, fetch_new_names, fetch_sections
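The three return values are parallel structures: the executor fetch names, the display names (the loss variable is renamed to "loss", and user-supplied (name, var) tuples keep their chosen name), and per-section index ranges that _print_log later uses to slice the results. Assuming a training program with one loss, one metric output, and one user fetch registered as ("my_out", out), the layout would look like this (the raw variable names below are made up):

    fetch_names     = ["mean_0.tmp_0", "accuracy_0.tmp_0", "tmp_3"]
    fetch_new_names = ["loss",         "accuracy_0.tmp_0", "my_out"]
    fetch_sections  = {"loss": (0, 1), "metrics": (1, 2), "user_fetches": (2, 3)}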
def _build(self, mode):
if _non_static_mode() or self._dygraph_mode:
paddle.disable_static()
......@@ -427,30 +483,32 @@ class Engine:
dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
self._executor.run(dist_startup_prog)
def _infer_sample_spec(self, data, batch_size, split):
def _split_sample_item(self, data, split):
if isinstance(data, paddle.io.IterableDataset):
if split is None:
input, label = next(iter(data))
inputs, labels = next(iter(data))
else:
sample = next(iter(data))
input = sample[:split]
label = sample[split:]
inputs = sample[:split]
labels = sample[split:]
elif isinstance(data, paddle.io.Dataset):
if split is None:
input, label = data[0]
inputs, labels = data[0]
else:
sample = data[0]
input = sample[:split]
label = sample[split:]
inputs = sample[:split]
labels = sample[split:]
else:
raise ValueError(
"Data should be a Dataset or IterableDatset, but received {}.".
format(type(data).__name__))
inputs = to_list(inputs)
labels = to_list(labels)
return inputs, labels
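In other words, sample_split is an index into each flat dataset item: the first split fields become the inputs and the remainder the labels, while split=None expects each item to already be an (inputs, labels) pair. For a hypothetical three-field sample:

    sample = (tokens, position_ids, label)     # one dataset item
    inputs, labels = sample[:2], sample[2:]    # what _split_sample_item does for split=2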
def _infer_sample_spec(self, inputs, labels, batch_size):
self.inputs_spec = []
self.labels_spec = []
input_list = to_list(input)
label_list = to_list(label)
def _infer_item_spec(item, name, batch_size, specs):
if isinstance(item, np.ndarray):
......@@ -468,13 +526,13 @@ class Engine:
else:
specs.append(InputSpec([batch_size], type(item), name))
if input_list is not None:
for i, item in enumerate(input_list):
if inputs is not None:
for i, item in enumerate(inputs):
assert item is not None, "Received None input."
name = "input" + str(i)
_infer_item_spec(item, name, batch_size, self.inputs_spec)
if label_list is not None:
for i, item in enumerate(label_list):
if labels is not None:
for i, item in enumerate(labels):
assert item is not None, "Received None label."
name = "label" + str(i)
_infer_item_spec(item, name, batch_size, self.labels_spec)
......@@ -482,6 +540,65 @@ class Engine:
self.inputs_spec = self._validate_spec(self.inputs_spec)
self.labels_spec = self._validate_spec(self.labels_spec)
def __call__(self,
inputs=None,
labels=None,
feeds=None,
fetches=None,
mode="train"):
feed_dict = self._prepare_feed(feeds, mode)
fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch(
fetches, mode)
try:
outs = self._executor.run(
self.main_program,
feed=feed_dict,
fetch_list=fetch_list,
use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy)
except core.EOFException:
pass
self._print_log(outs, self.mode, None, None, None, fetch_new_names,
fetch_sections)
return outs
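Besides mode, the call accepts one-off feeds and fetches for that single step; a fetches entry may be a bare variable or a (display_name, variable) tuple. A hedged example, with hypothetical names that are not part of this diff's API surface:

    outs = engine(feeds={"extra_mask": mask_data},
                  fetches=[("my_out", out_var)],
                  mode="eval")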
# TODO: need a better way to print the log
def _print_log(self,
outs,
mode="train",
epoch=None,
step=None,
lr=None,
fetch_new_names=None,
fetch_sections=None):
prefix = "[{}] ".format(mode)
logs = {}
if epoch is not None:
logs["epoch: {:d} "] = epoch
if step is not None:
logs["step: {:d} "] = step
if lr is not None:
logs["lr: {:5e} "] = lr
if fetch_sections is not None:
assert fetch_new_names is not None
for section_name, section in fetch_sections.items():
section_start, section_end = section
if section_name == "metrics" and section_start < section_end:
metric_out = outs[section_start:section_end]
for metric in self._metrics:
metric.update(*metric_out)
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
logs[metric.name()[i] + ": {:8f} "] = res
elif section_name == "loss" and section_start < section_end:
for i in range(section_start, section_end):
logs[fetch_new_names[i] + ": {:8f} "] = outs[i][0]
else:
for i in range(section_start, section_end):
logs[fetch_new_names[i] + ": {} "] = outs[i]
string = prefix + ''.join(list(logs.keys()))
self._logger.info(string.format(*list(logs.values())))
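With the section layout produced by _prepare_fetch, each step now logs one consolidated line along these lines (values illustrative):

    [train] epoch: 0 step: 12 lr: 1.000000e-05 loss: 2.302585 acc: 0.500000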
def fit(self,
train_data,
train_sample_split=None,
......@@ -561,28 +678,24 @@ class Engine:
batch_size=64)
"""
self.mode = 'train'
self._infer_sample_spec(train_data, batch_size, train_sample_split)
inputs, labels = self._split_sample_item(train_data, train_sample_split)
self._infer_sample_spec(inputs, labels, batch_size)
if not self._mode_init_states[self.mode]:
self._prepare_single_mode(self.mode)
self._prepare_program(self.mode)
else:
self._switch_mode("train")
assert self.mode in self._dist_main_progs, \
"train model is not ready, please call `engine._prepare_single_mode('train')` first."
train_dataloader = self._create_dataloader(train_data, batch_size,
epochs, steps_per_epoch,
collate_fn)
fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
fetch_metrics = self._validate_fetches(self.fetch_vars["metrics"])
inner_fetch = dict(fetch_loss, **fetch_metrics)
usr_fetch = self._validate_fetches(_get_fetches())
fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch)
"train model is not ready, please call `engine._prepare_program('train')` first."
train_dataloader = self._prepare_dataloader(train_data, batch_size,
epochs, steps_per_epoch,
collate_fn)
fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch(
mode=self.mode)
lr_scheduler = self._get_lr_scheduler(self.main_program)
outputs = defaultdict(list)
for epoch in range(epochs):
train_logs = {"epoch: {:d} ": epoch}
for step, _ in enumerate(train_dataloader):
try:
outs = self._executor.run(
......@@ -592,32 +705,11 @@ class Engine:
return_numpy=self._strategy.return_numpy)
except core.EOFException:
break
train_logs["step: {:d} "] = step
# update lr
if lr_scheduler and step % self._k_steps == 0:
lr_scheduler.step()
train_logs["lr: {:5e} "] = self._get_lr(self._lr_optimizer)
# inner fetches
if fetch_loss:
train_logs["loss: {:8f} "] = outs[0][0]
outputs["loss"].append(outs[0][0])
# Metric
if fetch_metrics:
metric_out = outs[len(fetch_loss):len(inner_fetch)]
for metric in self._metrics:
metric.update(*metric_out)
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
train_logs[metric.name()[i] + ": {:8f} "] = res
outputs[metric.name()[i]].append(outs[0][0])
# user fetches
user_outs = outs[len(inner_fetch):]
user_fetch_list = fetch_list[len(inner_fetch):]
for i, out in enumerate(user_outs):
train_logs[fetch_map[user_fetch_list[i]] + ": {}"] = out
# logger
string = '[train] ' + ''.join(list(train_logs.keys()))
self._logger.info(string.format(*list(train_logs.values())))
lr = self._get_lr(self._lr_optimizer)
self._print_log(outs, self.mode, epoch, step, lr,
fetch_new_names, fetch_sections)
if valid_data and epoch % valid_freq == 0:
self.evaluate(valid_data, valid_sample_split, batch_size,
......@@ -625,7 +717,7 @@ class Engine:
self._switch_mode("train")
else:
self._reset_metrics()
return outputs
return outs
def evaluate(self,
valid_data,
......@@ -652,7 +744,7 @@ class Engine:
the sample list; None to only stack each field of the sample along axis
0. Default None.
callbacks (Callback|None, optional): A list of `Callback` instances to apply
during evaling. Default: None. (Unused for now)
during evaluating. Default: None. (Unused for now)
Returns:
None
......@@ -681,24 +773,22 @@ class Engine:
"""
self.mode = 'eval'
self._infer_sample_spec(valid_data, batch_size, valid_sample_split)
inputs, labels = self._split_sample_item(valid_data, valid_sample_split)
self._infer_sample_spec(inputs, labels, batch_size)
if not self._mode_init_states[self.mode]:
self._prepare_single_mode(self.mode)
self._prepare_program(self.mode)
else:
self._switch_mode("eval")
assert self.mode in self._dist_main_progs, \
"eval model is not ready, please call `engine._prepare_single_mode('eval')` first."
valid_dataloader = self._create_dataloader(valid_data,
batch_size,
steps_per_epoch=steps,
collate_fn=collate_fn)
"eval model is not ready, please call `engine._prepare_program('eval')` first."
valid_dataloader = self._prepare_dataloader(valid_data,
batch_size,
steps_per_epoch=steps,
collate_fn=collate_fn)
fetch_loss = self._validate_fetches(self.fetch_vars["loss"])
fetch_metrics = self._validate_fetches(self.fetch_vars["metrics"])
inner_fetch = dict(fetch_loss, **fetch_metrics)
usr_fetch = self._validate_fetches(_get_fetches())
fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch)
fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch(
mode=self.mode)
outputs = defaultdict(list)
for step, _ in enumerate(valid_dataloader):
......@@ -710,28 +800,8 @@ class Engine:
return_numpy=self._strategy.return_numpy)
except core.EOFException:
break
eval_logs = {"step: {:d} ": step}
# inner fetches
if fetch_loss:
eval_logs["loss: {:8f} "] = outs[0][0]
outputs["eval_loss"].append(outs[0][0])
# Metric
if fetch_metrics:
metric_out = outs[len(fetch_loss):len(inner_fetch)]
for metric in self._metrics:
metric.update(*metric_out)
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
eval_logs[metric.name()[i] + ": {:8f} "] = res
outputs["eval_" + metric.name()[i]].append(res)
# user fetches
usr_outs = outs[len(inner_fetch):]
usr_fetch_list = fetch_list[len(inner_fetch):]
for i, out in enumerate(usr_outs):
eval_logs[fetch_map[usr_fetch_list[i]] + ": {}"] = out
# logger
string = '[eval] ' + ''.join(list(eval_logs.keys()))
self._logger.info(string.format(*list(eval_logs.values())))
self._print_log(outs, self.mode, None, step, None, fetch_new_names,
fetch_sections)
self._reset_metrics()
return outputs
......@@ -786,24 +856,23 @@ class Engine:
engine.predict(valid_dataset, batch_size=64)
"""
self.mode = 'predict'
self._infer_sample_spec(test_data, batch_size, test_sample_split)
inputs, labels = self._split_sample_item(test_data, test_sample_split)
self._infer_sample_spec(inputs, labels, batch_size)
if not self._mode_init_states[self.mode]:
self._prepare_single_mode(self.mode)
self._prepare_program(self.mode)
else:
self._switch_mode("predict")
assert self.mode in self._dist_main_progs, \
"predict model is not ready, please call `engine._prepare_single_mode('predict')` first."
test_dataloader = self._create_dataloader(test_data,
batch_size,
steps_per_epoch=steps,
collate_fn=collate_fn)
"predict model is not ready, please call `engine._prepare_program('predict')` first."
test_dataloader = self._prepare_dataloader(test_data,
batch_size,
steps_per_epoch=steps,
collate_fn=collate_fn)
fetch_outputs = self._validate_fetches(self.fetch_vars["outputs"])
usr_fetch = self._validate_fetches(_get_fetches())
fetch_list, fetch_map = self._fetch_map(fetch_outputs, usr_fetch)
fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch(
mode=self.mode)
outputs = []
for step, _ in enumerate(test_dataloader):
try:
outs = self._executor.run(
......@@ -813,27 +882,44 @@ class Engine:
return_numpy=self._strategy.return_numpy)
except core.EOFException:
break
predict_logs = {"step: {:d} ": step}
outputs.append(outs[:len(fetch_outputs)])
for i, out in enumerate(outs):
predict_logs[fetch_map[fetch_list[i]] + ": {}"] = out
# logger
string = '[pred] ' + ''.join(list(predict_logs.keys()))
self._logger.info(string.format(*list(predict_logs.values())))
self._print_log(outs, self.mode, None, step, None, fetch_new_names,
fetch_sections)
return outputs
return outs
def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
self.mode = 'train'
self._infer_sample_spec(tune_data, batch_size, tune_sample_split)
inputs, labels = self._split_sample_item(tune_data, tune_sample_split)
self._infer_sample_spec(inputs, labels, batch_size)
self._optimization_tuning(self.mode, tune_data, batch_size)
def _create_dataloader(self,
dataset,
batch_size,
epochs=1,
steps_per_epoch=None,
collate_fn=None):
def dataloader(self,
dataset,
sample_split=1,
batch_size=1,
epochs=1,
steps_per_epoch=None,
collate_fn=None,
mode="train",
from_generator=True):
assert from_generator, "Only support from_generator for now"
self.mode = mode
inputs, labels = self._split_sample_item(dataset, sample_split)
self._infer_sample_spec(inputs, labels, batch_size)
if not self._mode_init_states[self.mode]:
self._prepare_program(self.mode)
else:
self._switch_mode("train")
dataloader = self._prepare_dataloader(dataset, batch_size, epochs,
steps_per_epoch, collate_fn)
return dataloader
def _prepare_dataloader(self,
dataset,
batch_size,
epochs=1,
steps_per_epoch=None,
collate_fn=None):
if self._strategy.gradient_merge and batch_size is not None:
assert batch_size % self._k_steps == 0, \
......@@ -921,32 +1007,6 @@ class Engine:
var_name = _to_name_str(var)
return var_name in self.main_program.global_block().vars
def _validate_fetches(self, fetches):
# 1. Check user-defined fetches type
# 2. Prepare fetches_dict like {user_defined_name: var_name}
if not fetches:
return {}
if isinstance(fetches, dict):
fetch_var_names = list(map(_to_name_str, fetches.values()))
fetches_dict = dict(zip(fetch_var_names, list(fetches.keys())))
elif isinstance(fetches, list):
fetch_var_names = list(map(_to_name_str, fetches))
fetches_dict = dict(zip(fetch_var_names, fetch_var_names))
else:
raise TypeError("'fetches' only support 'dict' and 'list', "
"but got '{}'".format(str(type(fetches))))
return dict(
filter(lambda x: self._is_local_var(x[0]), fetches_dict.items()))
def _fetch_map(self, inner_fetch, usr_fetch):
# replace inner fetch name if usr set for it
for iname in inner_fetch:
if iname in usr_fetch:
inner_fetch[iname] = usr_fetch[iname]
usr_fetch.pop(iname)
fetches = dict(inner_fetch, **usr_fetch)
return list(fetches.keys()), fetches
def _get_input_split_info(self, var, dist_context):
# deduce how the input data is split among the cluster
from .utils import _get_comm_group, _get_corresponding_rank
......@@ -1066,7 +1126,7 @@ class Engine:
"""
if training:
assert 'train' in self._serial_main_progs, \
"training model is not ready, please call `engine._prepare_single_mode('train')` first."
"training model is not ready, please call `engine._prepare_program('train')` first."
serial_program = self._serial_main_progs["train"]
dist_main_prog = self._dist_main_progs["train"][self._cur_rank]
dist_context = self._dist_contexts["train"]
......@@ -1097,7 +1157,7 @@ class Engine:
the parameter in the file storing model states, or receives a
mismatched shape). Default: False.
load_optimizer (bool, optional): If True, the stored optimizer
states is restored. Otherwise, the optimizer states is intialized
states is restored. Otherwise, the optimizer states is initialized
from scratch. Default: False.
Returns:
......
......@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
import paddle
from paddle.fluid import core
from .process_mesh import ProcessMesh
......@@ -196,15 +198,42 @@ def recompute(op):
return RecomputeOperator(op)
_g_fetched_tensors = {}

def fetch(tensor, name=None):
    if name is None:
        _g_fetched_tensors[tensor.name] = tensor
    else:
        _g_fetched_tensors[name] = tensor

def _get_fetches():
    return _g_fetched_tensors
# _g_fetched_tensors = {}

# def fetch(tensor, name=None):
#     if name is None:
#         _g_fetched_tensors[tensor.name] = tensor
#     else:
#         _g_fetched_tensors[name] = tensor

# def _get_fetches():
#     return _g_fetched_tensors

_g_collections = {}

class CollectionNames(object):
    FEEDS = "feeds"
    FETCHES = "fetches"

def get_collection(name):
    collection = _g_collections.get(name, None)
    if collection is None:
        collection = []
        _g_collections[name] = collection
    return _g_collections[name]

def add_to_collection(collection_name, value, value_name=None):
    if collection_name not in _g_collections:
        _g_collections[collection_name] = []
    if value_name is not None:
        _g_collections[collection_name].append((value_name, value))
    else:
        _g_collections[collection_name].append((None, value))

def fetch(tensor, name=None):
    add_to_collection(CollectionNames.FETCHES, tensor, name)
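fetch() is thus a thin wrapper over a generic named-collection registry, and entries are stored as (name, value) tuples so the Engine can recover a user-chosen display name later. A behavior sketch:

    fetch(loss_var)               # stored under "fetches" as (None, loss_var)
    fetch(out_var, "my_out")      # stored as ("my_out", out_var)
    for name, var in get_collection(CollectionNames.FETCHES):
        print(name)               # the display name travels with the variable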
......@@ -97,7 +97,7 @@ class MLPLayer(nn.Layer):
out = self.dropout(out)
out = self.linear2(out)
if is_fetch:
auto.fetch(out, "out")
auto.fetch(out, "my_out")
return out
......@@ -145,6 +145,57 @@ def train(fetch):
temp_dir.cleanup()
def train_callable():
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
metric = paddle.metric.Accuracy()
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy)
# train
train_dataset = MyDataset(batch_num * batch_size)
train_dataloader = engine.dataloader(train_dataset,
batch_size=batch_size,
mode="train")
for _ in train_dataloader:
outs = engine(mode="train")
# eval
eval_dataset2 = MyDataset(batch_size)
eval_dataloader = engine.dataloader(eval_dataset2,
batch_size=batch_size,
mode="eval")
for _ in eval_dataloader:
outs = engine(mode="eval")
# predict
test_dataset = MyDataset(batch_size)
predict_dataloader = engine.dataloader(test_dataset,
batch_size=batch_size,
mode="predict")
for _ in predict_dataloader:
outs = engine(mode="predict")
# save
temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp')
engine.save(model_filename, training=True)
engine.load(model_filename)
temp_dir.cleanup()
if __name__ == "__main__":
train(fetch=True)
train(fetch=False)
train_callable()