diff --git a/python/paddle/distributed/auto_parallel/dist_loader.py b/python/paddle/distributed/auto_parallel/dist_loader.py index 44f720ade7f80c9771869a91dcaa0a0c6131d396..229d1b24fbd996e212324d017461322f42f8c5b1 100644 --- a/python/paddle/distributed/auto_parallel/dist_loader.py +++ b/python/paddle/distributed/auto_parallel/dist_loader.py @@ -34,7 +34,7 @@ class DistributedDataLoader(metaclass=abc.ABCMeta): self.dataset = dataset self.epochs = epochs - self.drop_lost = drop_last + self.drop_last = drop_last if batch_size is None: self.batch_size = None @@ -105,7 +105,7 @@ class NonIterableGeneratorLoader(DistributedDataLoader): self.collate_fn = collate_fn or default_convert_fn self.dataset_fetcher = _DatasetKind.create_fetcher( self.dataset_kind, self.dataset, self.auto_collate_batch, - self.collate_fn, self.drop_lost) + self.collate_fn, self.drop_last) self._steps = self._infer_steps() self._inner_dataloader = self._create_inner_dataloader() @@ -153,7 +153,7 @@ class NonIterableGeneratorLoader(DistributedDataLoader): self.dataset_fetcher = _DatasetKind.create_fetcher( self.dataset_kind, self.dataset, self.auto_collate_batch, self.collate_fn, - self.drop_lost) + self.drop_last) break partial_data = [] diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index e329f775590f5a45b4c7ec9809422b8c8e994ced..8c2480b67d845b1ead66fdaf8e4b8ea064a24e1f 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -49,7 +49,7 @@ from .utils import get_logger, get_dist_attr from .process_group import new_process_group, get_all_process_groups from .dist_context import DistributedContext, get_default_distributed_context from .strategy import Strategy -from .interface import _get_fetches +from .interface import CollectionNames, get_collection class Engine: @@ -197,7 +197,7 @@ class Engine: self._dygraph_mode = False self._tuning = self._strategy.tuning - def _prepare_single_mode(self, mode): + def _prepare_program(self, mode): # Do the build process self._build(mode) # Do the planning process @@ -208,6 +208,62 @@ class Engine: self._initialize(mode) self._mode_init_states[mode] = True + def _prepare_feed(self, user_feeds=None, mode="train"): + if user_feeds is not None: + assert isinstance(user_feeds, dict), \ + "user_feeds must be a dict, but receive {}".format(type(user_feeds).__name__) + feeds = {} + # TODO: add inputs and labels feed dict + for name, var in get_collection(CollectionNames.FEEDS): + assert name is not None, "No name defined for feed var" + feeds[name] = var + if user_feeds is not None: + for name, var in user_feeds.items(): + feeds[name] = var + return feeds + + def _prepare_fetch(self, user_fetches=None, mode="train"): + if user_fetches is not None: + assert isinstance(user_fetches, list), \ + "user_fetches must be a list, but receive {}".format(type(user_fetches).__name__) + fetch_names = [] + fetch_new_names = [] + fetch_sections = {} + cnt = 0 + + def _process_section(section_name, var_list): + nonlocal cnt + section_start = cnt + for var in var_list: + new_name = None + # Rename the loss + if section_name == "loss": + new_name = "loss" + if isinstance(var, tuple): + assert len(var) == 2, "Length of tuple {} must be 2".format( + var) + new_name, var = var + if self._is_local_var(var) and var.name not in fetch_names: + fetch_names.append(var.name) + fetch_new_names.append(var.name) + cnt += 1 + if self._is_local_var(var) and new_name is not None: + fetch_new_names[fetch_names.index(var.name)] = new_name + section_end = cnt + fetch_sections[section_name] = (section_start, section_end) + + for name, var_list in self._fetch_vars[mode].items(): + if name == "loss" and mode != "predict": + _process_section("loss", var_list) + if name == "metrics" and mode != "predict": + _process_section("metrics", var_list) + if name == "outputs" and mode == "predict": + _process_section("metrics", var_list) + var_list = (get_collection(CollectionNames.FETCHES) + or []) + (user_fetches or []) + _process_section("user_fetches", var_list) + return fetch_names, fetch_new_names, fetch_sections + def _build(self, mode): if _non_static_mode() or self._dygraph_mode: paddle.disable_static() @@ -427,30 +483,32 @@ class Engine: dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank] self._executor.run(dist_startup_prog) - def _infer_sample_spec(self, data, batch_size, split): + def _split_sample_item(self, data, split): if isinstance(data, paddle.io.IterableDataset): if split is None: - input, label = next(iter(data)) + inputs, labels = next(iter(data)) else: sample = next(iter(data)) - input = sample[:split] - label = sample[split:] + inputs = sample[:split] + labels = sample[split:] elif isinstance(data, paddle.io.Dataset): if split is None: - input, label = data[0] + inputs, labels = data[0] else: sample = data[0] - input = sample[:split] - label = sample[split:] + inputs = sample[:split] + labels = sample[split:] else: raise ValueError( "Data should be a Dataset or IterableDatset, but received {}.". format(type(data).__name__)) + inputs = to_list(inputs) + labels = to_list(labels) + return inputs, labels + def _infer_sample_spec(self, inputs, labels, batch_size): self.inputs_spec = [] self.labels_spec = [] - input_list = to_list(input) - label_list = to_list(label) def _infer_item_spec(item, name, batch_size, specs): if isinstance(item, np.ndarray): @@ -468,13 +526,13 @@ class Engine: else: specs.append(InputSpec([batch_size], type(item), name)) - if input_list is not None: - for i, item in enumerate(input_list): + if inputs is not None: + for i, item in enumerate(inputs): assert item is not None, "Receive None input." name = "input" + str(i) _infer_item_spec(item, name, batch_size, self.inputs_spec) - if label_list is not None: - for i, item in enumerate(label_list): + if labels is not None: + for i, item in enumerate(labels): assert item is not None, "Receive None input." name = "label" + str(i) _infer_item_spec(item, name, batch_size, self.labels_spec) @@ -482,6 +540,65 @@ class Engine: self.inputs_spec = self._validate_spec(self.inputs_spec) self.labels_spec = self._validate_spec(self.labels_spec) + def __call__(self, + inputs=None, + labels=None, + feeds=None, + fetches=None, + mode="train"): + feed_dict = self._prepare_feed(feeds, mode) + fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch( + fetches, mode) + try: + outs = self._executor.run( + self.main_program, + feed=feed_dict, + fetch_list=fetch_list, + use_program_cache=self._strategy.use_cache, + return_numpy=self._strategy.return_numpy) + except core.EOFException: + pass + self._print_log(outs, self.mode, None, None, None, fetch_new_names, + fetch_sections) + return outs + + # TODO: need a better to print the log + def _print_log(self, + outs, + mode="train", + epoch=None, + step=None, + lr=None, + fetch_new_names=None, + fetch_sections=None): + prefix = "[{}] ".format(mode) + logs = {} + if epoch is not None: + logs["epoch: {:d} "] = epoch + if step is not None: + logs["step: {:d} "] = step + if lr is not None: + logs["lr: {:5e} "] = lr + if fetch_sections is not None: + assert fetch_new_names is not None + for section_name, section in fetch_sections.items(): + section_start, section_end = section + if section_name == "metrics" and section_start < section_end: + metric_out = outs[section_start:section_end] + for metric in self._metrics: + metric.update(*metric_out) + results = metric.accumulate() + for i, res in enumerate(to_list(results)): + logs[metric.name()[i] + ": {:8f} "] = res + elif section_name == "loss" and section_start < section_end: + for i in range(section_start, section_end): + logs[fetch_new_names[i] + ": {:8f} "] = outs[i][0] + else: + for i in range(section_start, section_end): + logs[fetch_new_names[i] + ": {} "] = outs[i] + string = prefix + ''.join(list(logs.keys())) + self._logger.info(string.format(*list(logs.values()))) + def fit(self, train_data, train_sample_split=None, @@ -561,28 +678,24 @@ class Engine: batch_size=64) """ self.mode = 'train' - self._infer_sample_spec(train_data, batch_size, train_sample_split) + inputs, labels = self._split_sample_item(train_data, train_sample_split) + self._infer_sample_spec(inputs, labels, batch_size) if not self._mode_init_states[self.mode]: - self._prepare_single_mode(self.mode) + self._prepare_program(self.mode) else: self._switch_mode("train") assert self.mode in self._dist_main_progs, \ - "train model is not ready, please call `engine._prepare_single_mode('train')` first." - train_dataloader = self._create_dataloader(train_data, batch_size, - epochs, steps_per_epoch, - collate_fn) - - fetch_loss = self._validate_fetches(self.fetch_vars["loss"]) - fetch_metrics = self._validate_fetches(self.fetch_vars["metrics"]) - inner_fetch = dict(fetch_loss, **fetch_metrics) - usr_fetch = self._validate_fetches(_get_fetches()) - fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch) + "train model is not ready, please call `engine._prepare_program('train')` first." + train_dataloader = self._prepare_dataloader(train_data, batch_size, + epochs, steps_per_epoch, + collate_fn) + + fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch( + mode=self.mode) lr_scheduler = self._get_lr_scheduler(self.main_program) - outputs = defaultdict(list) for epoch in range(epochs): - train_logs = {"epoch: {:d} ": epoch} for step, _ in enumerate(train_dataloader): try: outs = self._executor.run( @@ -592,32 +705,11 @@ class Engine: return_numpy=self._strategy.return_numpy) except core.EOFException: break - train_logs["step: {:d} "] = step - # update lr if lr_scheduler and step % self._k_steps == 0: lr_scheduler.step() - train_logs["lr: {:5e} "] = self._get_lr(self._lr_optimizer) - # inner fetches - if fetch_loss: - train_logs["loss: {:8f} "] = outs[0][0] - outputs["loss"].append(outs[0][0]) - # Metric - if fetch_metrics: - metric_out = outs[len(fetch_loss):len(inner_fetch)] - for metric in self._metrics: - metric.update(*metric_out) - results = metric.accumulate() - for i, res in enumerate(to_list(results)): - train_logs[metric.name()[i] + ": {:8f} "] = res - outputs[metric.name()[i]].append(outs[0][0]) - # user fetches - user_outs = outs[len(inner_fetch):] - user_fetch_list = fetch_list[len(inner_fetch):] - for i, out in enumerate(user_outs): - train_logs[fetch_map[user_fetch_list[i]] + ": {}"] = out - # logger - string = '[train] ' + ''.join(list(train_logs.keys())) - self._logger.info(string.format(*list(train_logs.values()))) + lr = self._get_lr(self._lr_optimizer) + self._print_log(outs, self.mode, epoch, step, lr, + fetch_new_names, fetch_sections) if valid_data and epoch % valid_freq == 0: self.evaluate(valid_data, valid_sample_split, batch_size, @@ -625,7 +717,7 @@ class Engine: self._switch_mode("train") else: self._reset_metrics() - return outputs + return outs def evaluate(self, valid_data, @@ -652,7 +744,7 @@ class Engine: the sample list, None for only stack each fields of sample in axis 0. Default None. callbacks (Callback|None, optional): A list of `Callback` instances to apply - during evaling. Default: None. (Unused for now) + during evaluating. Default: None. (Unused for now) Returns: None @@ -681,24 +773,22 @@ class Engine: """ self.mode = 'eval' - self._infer_sample_spec(valid_data, batch_size, valid_sample_split) + inputs, labels = self._split_sample_item(valid_data, valid_sample_split) + self._infer_sample_spec(inputs, labels, batch_size) if not self._mode_init_states[self.mode]: - self._prepare_single_mode(self.mode) + self._prepare_program(self.mode) else: self._switch_mode("eval") assert self.mode in self._dist_main_progs, \ - "eval model is not ready, please call `engine._prepare_single_mode('eval')` first." - valid_dataloader = self._create_dataloader(valid_data, - batch_size, - steps_per_epoch=steps, - collate_fn=collate_fn) + "eval model is not ready, please call `engine._prepare_program('eval')` first." + valid_dataloader = self._prepare_dataloader(valid_data, + batch_size, + steps_per_epoch=steps, + collate_fn=collate_fn) - fetch_loss = self._validate_fetches(self.fetch_vars["loss"]) - fetch_metrics = self._validate_fetches(self.fetch_vars["metrics"]) - inner_fetch = dict(fetch_loss, **fetch_metrics) - usr_fetch = self._validate_fetches(_get_fetches()) - fetch_list, fetch_map = self._fetch_map(inner_fetch, usr_fetch) + fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch( + mode=self.mode) outputs = defaultdict(list) for step, _ in enumerate(valid_dataloader): @@ -710,28 +800,8 @@ class Engine: return_numpy=self._strategy.return_numpy) except core.EOFException: break - eval_logs = {"step: {:d} ": step} - # inner fetches - if fetch_loss: - eval_logs["loss: {:8f} "] = outs[0][0] - outputs["eval_loss"].append(outs[0][0]) - # Metric - if fetch_metrics: - metric_out = outs[len(fetch_loss):len(inner_fetch)] - for metric in self._metrics: - metric.update(*metric_out) - results = metric.accumulate() - for i, res in enumerate(to_list(results)): - eval_logs[metric.name()[i] + ": {:8f} "] = res - outputs["eval_" + metric.name()[i]].append(res) - # user fetches - usr_outs = outs[len(inner_fetch):] - usr_fetch_list = fetch_list[len(inner_fetch):] - for i, out in enumerate(usr_outs): - eval_logs[fetch_map[usr_fetch_list[i]] + ": {}"] = out - # logger - string = '[eval] ' + ''.join(list(eval_logs.keys())) - self._logger.info(string.format(*list(eval_logs.values()))) + self._print_log(outs, self.mode, None, step, None, fetch_new_names, + fetch_sections) self._reset_metrics() return outputs @@ -786,24 +856,23 @@ class Engine: engine.predict(valid_dataset, batch_size=64) """ self.mode = 'predict' - self._infer_sample_spec(test_data, batch_size, test_sample_split) + inputs, labels = self._split_sample_item(test_data, test_sample_split) + self._infer_sample_spec(inputs, labels, batch_size) if not self._mode_init_states[self.mode]: - self._prepare_single_mode(self.mode) + self._prepare_program(self.mode) else: self._switch_mode("predict") assert self.mode in self._dist_main_progs, \ - "predict model is not ready, please call `engine._prepare_single_mode('predict')` first." - test_dataloader = self._create_dataloader(test_data, - batch_size, - steps_per_epoch=steps, - collate_fn=collate_fn) + "predict model is not ready, please call `engine._prepare_program('predict')` first." + test_dataloader = self._prepare_dataloader(test_data, + batch_size, + steps_per_epoch=steps, + collate_fn=collate_fn) - fetch_outputs = self._validate_fetches(self.fetch_vars["outputs"]) - usr_fetch = self._validate_fetches(_get_fetches()) - fetch_list, fetch_map = self._fetch_map(fetch_outputs, usr_fetch) + fetch_list, fetch_new_names, fetch_sections = self._prepare_fetch( + mode=self.mode) - outputs = [] for step, _ in enumerate(test_dataloader): try: outs = self._executor.run( @@ -813,27 +882,44 @@ class Engine: return_numpy=self._strategy.return_numpy) except core.EOFException: break - predict_logs = {"step: {:d} ": step} - outputs.append(outs[:len(fetch_outputs)]) - for i, out in enumerate(outs): - predict_logs[fetch_map[fetch_list[i]] + ": {}"] = out - # logger - string = '[pred] ' + ''.join(list(predict_logs.keys())) - self._logger.info(string.format(*list(predict_logs.values()))) + self._print_log(outs, self.mode, None, step, None, fetch_new_names, + fetch_sections) - return outputs + return outs def _tune(self, tune_data, tune_sample_split=None, batch_size=1): self.mode = 'train' - self._infer_sample_spec(tune_data, batch_size, tune_sample_split) + inputs, labels = self._split_sample_item(tune_data, tune_sample_split) + self._infer_sample_spec(inputs, labels, batch_size) self._optimization_tuning(self.mode, tune_data, batch_size) - def _create_dataloader(self, - dataset, - batch_size, - epochs=1, - steps_per_epoch=None, - collate_fn=None): + def dataloader(self, + dataset, + sample_split=1, + batch_size=1, + epochs=1, + steps_per_epoch=None, + collate_fn=None, + mode="train", + from_generator=True): + assert from_generator, "Only support from_generator for now" + self.mode = mode + inputs, labels = self._split_sample_item(dataset, sample_split) + self._infer_sample_spec(inputs, labels, batch_size) + if not self._mode_init_states[self.mode]: + self._prepare_program(self.mode) + else: + self._switch_mode("train") + dataloader = self._prepare_dataloader(dataset, batch_size, epochs, + steps_per_epoch, collate_fn) + return dataloader + + def _prepare_dataloader(self, + dataset, + batch_size, + epochs=1, + steps_per_epoch=None, + collate_fn=None): if self._strategy.gradient_merge and batch_size is not None: assert batch_size % self._k_steps == 0, \ @@ -921,32 +1007,6 @@ class Engine: var_name = _to_name_str(var) return var_name in self.main_program.global_block().vars - def _validate_fetches(self, fetches): - # 1. Check user-defined fetches type - # 2. Prepare fetches_dict like {user_defined_name: var_name} - if not fetches: - return {} - if isinstance(fetches, dict): - fetch_var_names = list(map(_to_name_str, fetches.values())) - fetches_dict = dict(zip(fetch_var_names, list(fetches.keys()))) - elif isinstance(fetches, list): - fetch_var_names = list(map(_to_name_str, fetches)) - fetches_dict = dict(zip(fetch_var_names, fetch_var_names)) - else: - raise TypeError("'fetches' only support 'dict' and 'list', " - "but got '{}'".format(str(type(fetches)))) - return dict( - filter(lambda x: self._is_local_var(x[0]), fetches_dict.items())) - - def _fetch_map(self, inner_fetch, usr_fetch): - # replace inner fetch name if usr set for it - for iname in inner_fetch: - if iname in usr_fetch: - inner_fetch[iname] = usr_fetch[iname] - usr_fetch.pop(iname) - fetches = dict(inner_fetch, **usr_fetch) - return list(fetches.keys()), fetches - def _get_input_split_info(self, var, dist_context): # deduce how the input data is split among the cluster from .utils import _get_comm_group, _get_corresponding_rank @@ -1066,7 +1126,7 @@ class Engine: """ if training: assert 'train' in self._serial_main_progs, \ - "training model is not ready, please call `engine._prepare_single_mode('train')` first." + "training model is not ready, please call `engine._prepare_program('train')` first." serial_program = self._serial_main_progs["train"] dist_main_prog = self._dist_main_progs["train"][self._cur_rank] dist_context = self._dist_contexts["train"] @@ -1097,7 +1157,7 @@ class Engine: the parameter in file storing model states of or receives a mismatch shape). Default: False. load_optimizer (bool, optional): If True, the stored optimizer - states is restored. Otherwise, the optimizer states is intialized + states is restored. Otherwise, the optimizer states is initialized from scratch. Default: False. Returns: diff --git a/python/paddle/distributed/auto_parallel/interface.py b/python/paddle/distributed/auto_parallel/interface.py index dae8cb41e66e59ed53364fdb7341d7aa3e6a51a5..b39f5e8adc5d59cdd22527c34b5632129e328413 100644 --- a/python/paddle/distributed/auto_parallel/interface.py +++ b/python/paddle/distributed/auto_parallel/interface.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import defaultdict + import paddle from paddle.fluid import core from .process_mesh import ProcessMesh @@ -196,15 +198,42 @@ def recompute(op): return RecomputeOperator(op) -_g_fetched_tensors = {} +# _g_fetched_tensors = {} +# def fetch(tensor, name=None): +# if name is None: +# _g_fetched_tensors[tensor.name] = tensor +# else: +# _g_fetched_tensors[name] = tensor -def fetch(tensor, name=None): - if name is None: - _g_fetched_tensors[tensor.name] = tensor +# def _get_fetches(): +# return _g_fetched_tensors + +_g_collections = {} + + +class CollectionNames(object): + FEEDS = "feeds" + FETCHES = "fetches" + + +def get_collection(name): + collection = _g_collections.get(name, None) + if collection is None: + collection = [] + _g_collections[name] = collection + return _g_collections[name] + + +def add_to_collection(collection_name, value, value_name=None): + if collection_name not in _g_collections: + _g_collections[collection_name] = [] else: - _g_fetched_tensors[name] = tensor + if value_name is not None: + _g_collections[collection_name].append((value_name, value)) + else: + _g_collections[collection_name].append((None, value)) -def _get_fetches(): - return _g_fetched_tensors +def fetch(tensor, name=None): + add_to_collection(CollectionNames.FETCHES, tensor, name) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py index bad90667df1c0da7ec2a370b7c0599a9f4f7a6f4..3691ec153923cbae0bcdad6daf8e2773d105b29d 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/engine_api.py @@ -97,7 +97,7 @@ class MLPLayer(nn.Layer): out = self.dropout(out) out = self.linear2(out) if is_fetch: - auto.fetch(out, "out") + auto.fetch(out, "my_out") return out @@ -145,6 +145,57 @@ def train(fetch): temp_dir.cleanup() +def train_callable(): + mlp = MLPLayer(hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02) + loss = paddle.nn.CrossEntropyLoss() + optimizer = paddle.optimizer.Adam(learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None) + metric = paddle.metric.Accuracy() + + strategy = auto.Strategy() + strategy.auto_mode = "semi" + + engine = auto.Engine(mlp, loss, optimizer, metric, strategy=strategy) + + # train + train_dataset = MyDataset(batch_num * batch_size) + train_dataloader = engine.dataloader(train_dataset, + batch_size=batch_size, + mode="train") + for _ in train_dataloader: + outs = engine(mode="train") + + # eval + eval_dataset2 = MyDataset(batch_size) + eval_dataloader = engine.dataloader(eval_dataset2, + batch_size=batch_size, + mode="eval") + for _ in eval_dataloader: + outs = engine(mode="eval") + + # predict + test_dataset = MyDataset(batch_size) + predict_dataloader = engine.dataloader(test_dataset, + batch_size=batch_size, + mode="predict") + for _ in predict_dataloader: + outs = engine(mode="predict") + + # save + temp_dir = tempfile.TemporaryDirectory() + model_filename = os.path.join(temp_dir.name, 'mlp') + engine.save(model_filename, training=True) + engine.load(model_filename) + temp_dir.cleanup() + + if __name__ == "__main__": train(fetch=True) train(fetch=False) + train_callable()