未验证 提交 c182e5dd 编写于 作者: Y Yulong Ao 提交者: GitHub

[Auto Parallel] Make Engine class callable (#46416)

* [Auto Parallel] Imporve the user-defined fetches and logging

* [Auto Parallel] Make Engine class callable

* [Auto Parallel] Update the data loading of tuner
上级 55accdfc
...@@ -34,7 +34,7 @@ class DistributedDataLoader(metaclass=abc.ABCMeta): ...@@ -34,7 +34,7 @@ class DistributedDataLoader(metaclass=abc.ABCMeta):
self.dataset = dataset self.dataset = dataset
self.epochs = epochs self.epochs = epochs
self.drop_lost = drop_last self.drop_last = drop_last
if batch_size is None: if batch_size is None:
self.batch_size = None self.batch_size = None
...@@ -105,7 +105,7 @@ class NonIterableGeneratorLoader(DistributedDataLoader): ...@@ -105,7 +105,7 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
self.collate_fn = collate_fn or default_convert_fn self.collate_fn = collate_fn or default_convert_fn
self.dataset_fetcher = _DatasetKind.create_fetcher( self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset, self.auto_collate_batch, self.dataset_kind, self.dataset, self.auto_collate_batch,
self.collate_fn, self.drop_lost) self.collate_fn, self.drop_last)
self._steps = self._infer_steps() self._steps = self._infer_steps()
self._inner_dataloader = self._create_inner_dataloader() self._inner_dataloader = self._create_inner_dataloader()
...@@ -153,7 +153,7 @@ class NonIterableGeneratorLoader(DistributedDataLoader): ...@@ -153,7 +153,7 @@ class NonIterableGeneratorLoader(DistributedDataLoader):
self.dataset_fetcher = _DatasetKind.create_fetcher( self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset, self.dataset_kind, self.dataset,
self.auto_collate_batch, self.collate_fn, self.auto_collate_batch, self.collate_fn,
self.drop_lost) self.drop_last)
break break
partial_data = [] partial_data = []
......
...@@ -49,7 +49,7 @@ from .utils import get_logger, get_dist_attr ...@@ -49,7 +49,7 @@ from .utils import get_logger, get_dist_attr
from .process_group import new_process_group, get_all_process_groups from .process_group import new_process_group, get_all_process_groups
from .dist_context import DistributedContext, get_default_distributed_context from .dist_context import DistributedContext, get_default_distributed_context
from .strategy import Strategy from .strategy import Strategy
from .interface import _get_fetches from .interface import CollectionNames, get_collection
class Engine: class Engine:
...@@ -197,7 +197,7 @@ class Engine: ...@@ -197,7 +197,7 @@ class Engine:
self._dygraph_mode = False self._dygraph_mode = False
self._tuning = self._strategy.tuning self._tuning = self._strategy.tuning
def _prepare_single_mode(self, mode): def _prepare_program(self, mode):
# Do the build process # Do the build process
self._build(mode) self._build(mode)
# Do the planning process # Do the planning process
...@@ -208,6 +208,62 @@ class Engine: ...@@ -208,6 +208,62 @@ class Engine:
self._initialize(mode) self._initialize(mode)
self._mode_init_states[mode] = True self._mode_init_states[mode] = True
def _prepare_feed(self, user_feeds=None, mode="train"):
if user_feeds is not None:
assert isinstance(user_feeds, dict), \
"user_feeds must be a dict, but receive {}".format(type(user_feeds).__name__)
feeds = {}
# TODO: add inputs and labels feed dict
for name, var in get_collection(CollectionNames.FEEDS):
assert name is not None, "No name defined for feed var"
feeds[name] = var
if user_feeds is not None:
for name, var in user_feeds.items():
feeds[name] = var
return feeds
def _prepare_fetch(self, user_fetches=None, mode="train"):
if user_fetches is not None:
assert isinstance(user_fetches, list), \
"user_fetches must be a list, but receive {}".format(type(user_fetches).__name__)
fetch_names = []
fetch_new_names = []
fetch_sections = {}
cnt = 0
def _process_section(section_name, var_list):
nonlocal cnt
section_start = cnt
for var in var_list:
new_name = None
# Rename the loss
if section_name == "loss":
new_name = "loss"
if isinstance(var, tuple):
assert len(var) == 2, "Length of tuple {} must be 2".format(
var)
new_name, var = var
if self._is_local_var(var) and var.name not in fetch_names:
fetch_names.append(var.name)
fetch_new_names.append(var.name)
cnt += 1
if self._is_local_var(var) and new_name is not None:
fetch_new_names[fetch_names.index(var.name)] = new_name
section_end = cnt
fetch_sections[section_name] = (section_start, section_end)
for name, var_list in self._fetch_vars[mode].items():
if name == "loss" and mode != "predict":
_process_section("loss", var_list)
if name == "metrics" and mode != "predict":
_process_section("metrics", var_list)
if name == "outputs" and mode == "predict":
_process_section("metrics", var_list)
var_list = (get_collection(CollectionNames.FETCHES)
or []) + (user_fetches or [])
_process_section("user_fetches", var_list)
return fetch_names, fetch_new_names, fetch_sections
def _build(self, mode): def _build(self, mode):
if _non_static_mode() or self._dygraph_mode: if _non_static_mode() or self._dygraph_mode:
paddle.disable_static() paddle.disable_static()
...@@ -427,30 +483,32 @@ class Engine: ...@@ -427,30 +483,32 @@ class Engine:
dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank] dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
self._executor.run(dist_startup_prog) self._executor.run(dist_startup_prog)
def _infer_sample_spec(self, data, batch_size, split): def _split_sample_item(self, data, split):
if isinstance(data, paddle.io.IterableDataset): if isinstance(data, paddle.io.IterableDataset):
if split is None: if split is None:
input, label = next(iter(data)) inputs, labels = next(iter(data))
else: else:
sample = next(iter(data)) sample = next(iter(data))
input = sample[:split] inputs = sample[:split]
label = sample[split:] labels = sample[split:]
elif isinstance(data, paddle.io.Dataset): elif isinstance(data, paddle.io.Dataset):
if split is None: if split is None:
input, label = data[0] inputs, labels = data[0]
else: else:
sample = data[0] sample = data[0]
input = sample[:split] inputs = sample[:split]
label = sample[split:] labels = sample[split:]
else: else:
raise ValueError( raise ValueError(
"Data should be a Dataset or IterableDatset, but received {}.". "Data should be a Dataset or IterableDatset, but received {}.".
format(type(data).__name__)) format(type(data).__name__))
inputs = to_list(inputs)
labels = to_list(labels)
return inputs, labels
def _infer_sample_spec(self, inputs, labels, batch_size):
self.inputs_spec = [] self.inputs_spec = []
self.labels_spec = [] self.labels_spec = []
input_list = to_list(input)
label_list = to_list(label)
def _infer_item_spec(item, name, batch_size, specs): def _infer_item_spec(item, name, batch_size, specs):
if isinstance(item, np.ndarray): if isinstance(item, np.ndarray):
...@@ -468,13 +526,13 @@ class Engine: ...@@ -468,13 +526,13 @@ class Engine:
else: else:
specs.append(InputSpec([batch_size], type(item), name)) specs.append(InputSpec([batch_size], type(item), name))
if input_list is not None: if inputs is not None:
for i, item in enumerate(input_list): for i, item in enumerate(inputs):
assert item is not None, "Receive None input." assert item is not None, "Receive None input."
name = "input" + str(i) name = "input" + str(i)
_infer_item_spec(item, name, batch_size, self.inputs_spec) _infer_item_spec(item, name, batch_size, self.inputs_spec)
if label_list is not None: if labels is not None:
for i, item in enumerate(label_list): for i, item in enumerate(labels):
assert item is not None, "Receive None input." assert item is not None, "Receive None input."
name = "label" + str(i) name = "label" + str(i)
_infer_item_spec(item, name, batch_size, self.labels_spec) _infer_item_spec(item, name, batch_size, self.labels_spec)
...@@ -482,6 +540,65 @@ class Engine: ...@@ -482,6 +540,65 @@ class Engine:
self.inputs_spec = self._validate_spec(self.inputs_spec) self.inputs_spec = self._validate_spec(self.inputs_spec)
self.labels_spec = self._validate_spec(self.labels_spec) self.labels_spec = self._validate_spec(self.labels_spec)
def __call__(self,
inputs=None,
labels=None,
feeds=None,
fetches=None,
mode="train"):
feed_dict = self._prepare_feed(feeds, mode)
fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch(
fetches, mode)
try:
outs = self._executor.run(
self.main_program,
feed=feed_dict,
fetch_list=fetch_list,
use_program_cache=self._strategy.use_cache,
return_numpy=self._strategy.return_numpy)
except core.EOFException:
pass
self._print_log(outs, self.mode, None, None, None, fetch_new_names,
fetch_sections)
return outs
# TODO: need a better to print the log
def _print_log(self,
outs,
mode="train",
epoch=None,
step=None,
lr=None,
fetch_new_names=None,
fetch_sections=None):
prefix = "[{}] ".format(mode)
logs = {}
if epoch is not None:
logs["epoch: {:d} "] = epoch
if step is not None:
logs["step: {:d} "] = step
if lr is not None:
logs["lr: {:5e} "] = lr
if fetch_sections is not None:
assert fetch_new_names is not None
for section_name, section in fetch_sections.items():
section_start, section_end = section
if section_name == "metrics" and section_start < section_end:
metric_out = outs[section_start:section_end]
for metric in self._metrics:
metric.update(*metric_out)
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
logs[metric.name()[i] + ": {:8f} "] = res
elif section_name == "loss" and section_start < section_end:
for i in range(section_start, section_end):
logs[fetch_new_names[i] + ": {:8f} "] = outs[i][0]
else:
for i in range(section_start, section_end):
logs[fetch_new_names[i] + ": {} "] = outs[i]
string = prefix + ''.join(list(logs.keys()))
self._logger.info(string.format(*list(logs.values())))
def fit(self, def fit(self,
train_data, train_data,
train_sample_split=None, train_sample_split=None,
...@@ -561,28 +678,24 @@ class Engine: ...@@ -561,28 +678,24 @@ class Engine:
batch_size=64) batch_size=64)
""" """
self.mode = 'train' self.mode = 'train'
self._infer_sample_spec(train_data, batch_size, train_sample_split) inputs, labels = self._split_sample_item(train_data, train_sample_split)
self._infer_sample_spec(inputs, labels, batch_size)
if not self._mode_init_states[self.mode]: if not self._mode_init_states[self.mode]:
self._prepare_single_mode(self.mode) self._prepare_program(self.mode)
else: else:
self._switch_mode("train") self._switch_mode("train")
assert self.mode in self._dist_main_progs, \ assert self.mode in self._dist_main_progs, \
"train model is not ready, please call `engine._prepare_single_mode('train')` first." "train model is not ready, please call `engine._prepare_program('train')` first."
train_dataloader = self._create_dataloader(train_data, batch_size, train_dataloader = self._prepare_dataloader(train_data, batch_size,
epochs, steps_per_epoch, epochs, steps_per_epoch,
collate_fn) collate_fn)
fetch_loss = self._validate_fetches(self.fetch_vars["loss"]) fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch(
fetch_metrics = self._validate_fetches(self.fetch_vars["metrics"]) mode=self.mode)
inner_fetch = dict(fetch_loss, **fetch_metrics)
usr_fetch = self._validate_fetches(_get_fetches())
fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch)
lr_scheduler = self._get_lr_scheduler(self.main_program) lr_scheduler = self._get_lr_scheduler(self.main_program)
outputs = defaultdict(list)
for epoch in range(epochs): for epoch in range(epochs):
train_logs = {"epoch: {:d} ": epoch}
for step, _ in enumerate(train_dataloader): for step, _ in enumerate(train_dataloader):
try: try:
outs = self._executor.run( outs = self._executor.run(
...@@ -592,32 +705,11 @@ class Engine: ...@@ -592,32 +705,11 @@ class Engine:
return_numpy=self._strategy.return_numpy) return_numpy=self._strategy.return_numpy)
except core.EOFException: except core.EOFException:
break break
train_logs["step: {:d} "] = step
# update lr
if lr_scheduler and step % self._k_steps == 0: if lr_scheduler and step % self._k_steps == 0:
lr_scheduler.step() lr_scheduler.step()
train_logs["lr: {:5e} "] = self._get_lr(self._lr_optimizer) lr = self._get_lr(self._lr_optimizer)
# inner fetches self._print_log(outs, self.mode, epoch, step, lr,
if fetch_loss: fetch_new_names, fetch_sections)
train_logs["loss: {:8f} "] = outs[0][0]
outputs["loss"].append(outs[0][0])
# Metric
if fetch_metrics:
metric_out = outs[len(fetch_loss):len(inner_fetch)]
for metric in self._metrics:
metric.update(*metric_out)
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
train_logs[metric.name()[i] + ": {:8f} "] = res
outputs[metric.name()[i]].append(outs[0][0])
# user fetches
user_outs = outs[len(inner_fetch):]
user_fetch_list = fetch_list[len(inner_fetch):]
for i, out in enumerate(user_outs):
train_logs[fetch_map[user_fetch_list[i]] + ": {}"] = out
# logger
string = '[train] ' + ''.join(list(train_logs.keys()))
self._logger.info(string.format(*list(train_logs.values())))
if valid_data and epoch % valid_freq == 0: if valid_data and epoch % valid_freq == 0:
self.evaluate(valid_data, valid_sample_split, batch_size, self.evaluate(valid_data, valid_sample_split, batch_size,
...@@ -625,7 +717,7 @@ class Engine: ...@@ -625,7 +717,7 @@ class Engine:
self._switch_mode("train") self._switch_mode("train")
else: else:
self._reset_metrics() self._reset_metrics()
return outputs return outs
def evaluate(self, def evaluate(self,
valid_data, valid_data,
...@@ -652,7 +744,7 @@ class Engine: ...@@ -652,7 +744,7 @@ class Engine:
the sample list, None for only stack each fields of sample in axis the sample list, None for only stack each fields of sample in axis
0. Default None. 0. Default None.
callbacks (Callback|None, optional): A list of `Callback` instances to apply callbacks (Callback|None, optional): A list of `Callback` instances to apply
during evaling. Default: None. (Unused for now) during evaluating. Default: None. (Unused for now)
Returns: Returns:
None None
...@@ -681,24 +773,22 @@ class Engine: ...@@ -681,24 +773,22 @@ class Engine:
""" """
self.mode = 'eval' self.mode = 'eval'
self._infer_sample_spec(valid_data, batch_size, valid_sample_split) inputs, labels = self._split_sample_item(valid_data, valid_sample_split)
self._infer_sample_spec(inputs, labels, batch_size)
if not self._mode_init_states[self.mode]: if not self._mode_init_states[self.mode]:
self._prepare_single_mode(self.mode) self._prepare_program(self.mode)
else: else:
self._switch_mode("eval") self._switch_mode("eval")
assert self.mode in self._dist_main_progs, \ assert self.mode in self._dist_main_progs, \
"eval model is not ready, please call `engine._prepare_single_mode('eval')` first." "eval model is not ready, please call `engine._prepare_program('eval')` first."
valid_dataloader = self._create_dataloader(valid_data, valid_dataloader = self._prepare_dataloader(valid_data,
batch_size, batch_size,
steps_per_epoch=steps, steps_per_epoch=steps,
collate_fn=collate_fn) collate_fn=collate_fn)
fetch_loss = self._validate_fetches(self.fetch_vars["loss"]) fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch(
fetch_metrics = self._validate_fetches(self.fetch_vars["metrics"]) mode=self.mode)
inner_fetch = dict(fetch_loss, **fetch_metrics)
usr_fetch = self._validate_fetches(_get_fetches())
fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch)
outputs = defaultdict(list) outputs = defaultdict(list)
for step, _ in enumerate(valid_dataloader): for step, _ in enumerate(valid_dataloader):
...@@ -710,28 +800,8 @@ class Engine: ...@@ -710,28 +800,8 @@ class Engine:
return_numpy=self._strategy.return_numpy) return_numpy=self._strategy.return_numpy)
except core.EOFException: except core.EOFException:
break break
eval_logs = {"step: {:d} ": step} self._print_log(outs, self.mode, None, step, None, fetch_new_names,
# inner fetches fetch_sections)
if fetch_loss:
eval_logs["loss: {:8f} "] = outs[0][0]
outputs["eval_loss"].append(outs[0][0])
# Metric
if fetch_metrics:
metric_out = outs[len(fetch_loss):len(inner_fetch)]
for metric in self._metrics:
metric.update(*metric_out)
results = metric.accumulate()
for i, res in enumerate(to_list(results)):
eval_logs[metric.name()[i] + ": {:8f} "] = res
outputs["eval_" + metric.name()[i]].append(res)
# user fetches
usr_outs = outs[len(inner_fetch):]
usr_fetch_list = fetch_list[len(inner_fetch):]
for i, out in enumerate(usr_outs):
eval_logs[fetch_map[usr_fetch_list[i]] + ": {}"] = out
# logger
string = '[eval] ' + ''.join(list(eval_logs.keys()))
self._logger.info(string.format(*list(eval_logs.values())))
self._reset_metrics() self._reset_metrics()
return outputs return outputs
...@@ -786,24 +856,23 @@ class Engine: ...@@ -786,24 +856,23 @@ class Engine:
engine.predict(valid_dataset, batch_size=64) engine.predict(valid_dataset, batch_size=64)
""" """
self.mode = 'predict' self.mode = 'predict'
self._infer_sample_spec(test_data, batch_size, test_sample_split) inputs, labels = self._split_sample_item(test_data, test_sample_split)
self._infer_sample_spec(inputs, labels, batch_size)
if not self._mode_init_states[self.mode]: if not self._mode_init_states[self.mode]:
self._prepare_single_mode(self.mode) self._prepare_program(self.mode)
else: else:
self._switch_mode("predict") self._switch_mode("predict")
assert self.mode in self._dist_main_progs, \ assert self.mode in self._dist_main_progs, \
"predict model is not ready, please call `engine._prepare_single_mode('predict')` first." "predict model is not ready, please call `engine._prepare_program('predict')` first."
test_dataloader = self._create_dataloader(test_data, test_dataloader = self._prepare_dataloader(test_data,
batch_size, batch_size,
steps_per_epoch=steps, steps_per_epoch=steps,
collate_fn=collate_fn) collate_fn=collate_fn)
fetch_outputs = self._validate_fetches(self.fetch_vars["outputs"]) fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch(
usr_fetch = self._validate_fetches(_get_fetches()) mode=self.mode)
fetch_list, fetch_map = self._fetch_map(fetch_outputs, usr_fetch)
outputs = []
for step, _ in enumerate(test_dataloader): for step, _ in enumerate(test_dataloader):
try: try:
outs = self._executor.run( outs = self._executor.run(
...@@ -813,22 +882,39 @@ class Engine: ...@@ -813,22 +882,39 @@ class Engine:
return_numpy=self._strategy.return_numpy) return_numpy=self._strategy.return_numpy)
except core.EOFException: except core.EOFException:
break break
predict_logs = {"step: {:d} ": step} self._print_log(outs, self.mode, None, step, None, fetch_new_names,
outputs.append(outs[:len(fetch_outputs)]) fetch_sections)
for i, out in enumerate(outs):
predict_logs[fetch_map[fetch_list[i]] + ": {}"] = out
# logger
string = '[pred] ' + ''.join(list(predict_logs.keys()))
self._logger.info(string.format(*list(predict_logs.values())))
return outputs return outs
def _tune(self, tune_data, tune_sample_split=None, batch_size=1): def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
self.mode = 'train' self.mode = 'train'
self._infer_sample_spec(tune_data, batch_size, tune_sample_split) inputs, labels = self._split_sample_item(tune_data, tune_sample_split)
self._infer_sample_spec(inputs, labels, batch_size)
self._optimization_tuning(self.mode, tune_data, batch_size) self._optimization_tuning(self.mode, tune_data, batch_size)
def _create_dataloader(self, def dataloader(self,
dataset,
sample_split=1,
batch_size=1,
epochs=1,
steps_per_epoch=None,
collate_fn=None,
mode="train",
from_generator=True):
assert from_generator, "Only support from_generator for now"
self.mode = mode
inputs, labels = self._split_sample_item(dataset, sample_split)
self._infer_sample_spec(inputs, labels, batch_size)
if not self._mode_init_states[self.mode]:
self._prepare_program(self.mode)
else:
self._switch_mode("train")
dataloader = self._prepare_dataloader(dataset, batch_size, epochs,
steps_per_epoch, collate_fn)
return dataloader
def _prepare_dataloader(self,
dataset, dataset,
batch_size, batch_size,
epochs=1, epochs=1,
...@@ -921,32 +1007,6 @@ class Engine: ...@@ -921,32 +1007,6 @@ class Engine:
var_name = _to_name_str(var) var_name = _to_name_str(var)
return var_name in self.main_program.global_block().vars return var_name in self.main_program.global_block().vars
def _validate_fetches(self, fetches):
# 1. Check user-defined fetches type
# 2. Prepare fetches_dict like {user_defined_name: var_name}
if not fetches:
return {}
if isinstance(fetches, dict):
fetch_var_names = list(map(_to_name_str, fetches.values()))
fetches_dict = dict(zip(fetch_var_names, list(fetches.keys())))
elif isinstance(fetches, list):
fetch_var_names = list(map(_to_name_str, fetches))
fetches_dict = dict(zip(fetch_var_names, fetch_var_names))
else:
raise TypeError("'fetches' only support 'dict' and 'list', "
"but got '{}'".format(str(type(fetches))))
return dict(
filter(lambda x: self._is_local_var(x[0]), fetches_dict.items()))
def _fetch_map(self, inner_fetch, usr_fetch):
# replace inner fetch name if usr set for it
for iname in inner_fetch:
if iname in usr_fetch:
inner_fetch[iname] = usr_fetch[iname]
usr_fetch.pop(iname)
fetches = dict(inner_fetch, **usr_fetch)
return list(fetches.keys()), fetches
def _get_input_split_info(self, var, dist_context): def _get_input_split_info(self, var, dist_context):
# deduce how the input data is split among the cluster # deduce how the input data is split among the cluster
from .utils import _get_comm_group, _get_corresponding_rank from .utils import _get_comm_group, _get_corresponding_rank
...@@ -1066,7 +1126,7 @@ class Engine: ...@@ -1066,7 +1126,7 @@ class Engine:
""" """
if training: if training:
assert 'train' in self._serial_main_progs, \ assert 'train' in self._serial_main_progs, \
"training model is not ready, please call `engine._prepare_single_mode('train')` first." "training model is not ready, please call `engine._prepare_program('train')` first."
serial_program = self._serial_main_progs["train"] serial_program = self._serial_main_progs["train"]
dist_main_prog = self._dist_main_progs["train"][self._cur_rank] dist_main_prog = self._dist_main_progs["train"][self._cur_rank]
dist_context = self._dist_contexts["train"] dist_context = self._dist_contexts["train"]
...@@ -1097,7 +1157,7 @@ class Engine: ...@@ -1097,7 +1157,7 @@ class Engine:
the parameter in file storing model states of or receives a the parameter in file storing model states of or receives a
mismatch shape). Default: False. mismatch shape). Default: False.
load_optimizer (bool, optional): If True, the stored optimizer load_optimizer (bool, optional): If True, the stored optimizer
states is restored. Otherwise, the optimizer states is intialized states is restored. Otherwise, the optimizer states is initialized
from scratch. Default: False. from scratch. Default: False.
Returns: Returns:
......
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from collections import defaultdict
import paddle import paddle
from paddle.fluid import core from paddle.fluid import core
from .process_mesh import ProcessMesh from .process_mesh import ProcessMesh
...@@ -196,15 +198,42 @@ def recompute(op): ...@@ -196,15 +198,42 @@ def recompute(op):
return RecomputeOperator(op) return RecomputeOperator(op)
_g_fetched_tensors = {} # _g_fetched_tensors = {}
# def fetch(tensor, name=None):
# if name is None:
# _g_fetched_tensors[tensor.name] = tensor
# else:
# _g_fetched_tensors[name] = tensor
def fetch(tensor, name=None): # def _get_fetches():
if name is None: # return _g_fetched_tensors
_g_fetched_tensors[tensor.name] = tensor
_g_collections = {}
class CollectionNames(object):
FEEDS = "feeds"
FETCHES = "fetches"
def get_collection(name):
collection = _g_collections.get(name, None)
if collection is None:
collection = []
_g_collections[name] = collection
return _g_collections[name]
def add_to_collection(collection_name, value, value_name=None):
if collection_name not in _g_collections:
_g_collections[collection_name] = []
else: else:
_g_fetched_tensors[name] = tensor if value_name is not None:
_g_collections[collection_name].append((value_name, value))
else:
_g_collections[collection_name].append((None, value))
def _get_fetches(): def fetch(tensor, name=None):
return _g_fetched_tensors add_to_collection(CollectionNames.FETCHES, tensor, name)
...@@ -97,7 +97,7 @@ class MLPLayer(nn.Layer): ...@@ -97,7 +97,7 @@ class MLPLayer(nn.Layer):
out = self.dropout(out) out = self.dropout(out)
out = self.linear2(out) out = self.linear2(out)
if is_fetch: if is_fetch:
auto.fetch(out, "out") auto.fetch(out, "my_out")
return out return out
...@@ -145,6 +145,57 @@ def train(fetch): ...@@ -145,6 +145,57 @@ def train(fetch):
temp_dir.cleanup() temp_dir.cleanup()
def train_callable():
mlp = MLPLayer(hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
loss = paddle.nn.CrossEntropyLoss()
optimizer = paddle.optimizer.Adam(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
metric = paddle.metric.Accuracy()
strategy = auto.Strategy()
strategy.auto_mode = "semi"
engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy)
# train
train_dataset = MyDataset(batch_num * batch_size)
train_dataloader = engine.dataloader(train_dataset,
batch_size=batch_size,
mode="train")
for _ in train_dataloader:
outs = engine(mode="train")
# eval
eval_dataset2 = MyDataset(batch_size)
eval_dataloader = engine.dataloader(eval_dataset2,
batch_size=batch_size,
mode="eval")
for _ in eval_dataloader:
outs = engine(mode="eval")
# predict
test_dataset = MyDataset(batch_size)
predict_dataloader = engine.dataloader(test_dataset,
batch_size=batch_size,
mode="predict")
for _ in predict_dataloader:
outs = engine(mode="predict")
# save
temp_dir = tempfile.TemporaryDirectory()
model_filename = os.path.join(temp_dir.name, 'mlp')
engine.save(model_filename, training=True)
engine.load(model_filename)
temp_dir.cleanup()
if __name__ == "__main__": if __name__ == "__main__":
train(fetch=True) train(fetch=True)
train(fetch=False) train(fetch=False)
train_callable()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册