From 7db0dfe2f217717e3f90fc19d6815a9ec9860dc7 Mon Sep 17 00:00:00 2001 From: frankwhzhang Date: Fri, 28 Aug 2020 14:02:43 +0800 Subject: [PATCH] fix dataloader --- core/trainers/framework/runner.py | 16 +++++++++------- models/multitask/mmoe/config.yaml | 6 +++--- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/core/trainers/framework/runner.py b/core/trainers/framework/runner.py index b7c67cd1..4375b726 100644 --- a/core/trainers/framework/runner.py +++ b/core/trainers/framework/runner.py @@ -174,18 +174,20 @@ class RunnerBase(object): fetch_list=metrics_varnames, return_numpy=False) + metrics = [batch_id] + metrics_rets = [ + as_numpy(metrics_tensor) + for metrics_tensor in metrics_tensors + ] + metrics.extend(metrics_rets) + if batch_id % fetch_period == 0 and batch_id != 0: - metrics = [batch_id] end_time = time.time() seconds = end_time - begin_time - metrics.extend([seconds]) + metrics_logging = metrics[:] + metrics_logging = metrics.insert(1, seconds) begin_time = end_time - metrics_rets = [ - as_numpy(metrics_tensor) - for metrics_tensor in metrics_tensors - ] - metrics.extend(metrics_rets) logging.info(metrics_format.format(*metrics)) batch_id += 1 except fluid.core.EOFException: diff --git a/models/multitask/mmoe/config.yaml b/models/multitask/mmoe/config.yaml index 0c79f9e6..354bd218 100644 --- a/models/multitask/mmoe/config.yaml +++ b/models/multitask/mmoe/config.yaml @@ -17,12 +17,12 @@ workspace: "models/multitask/mmoe" dataset: - name: dataset_train batch_size: 5 - type: QueueDataset + type: DataLoader # or QueueDataset data_path: "{workspace}/data/train" data_converter: "{workspace}/census_reader.py" - name: dataset_infer batch_size: 5 - type: QueueDataset + type: DataLoader # or QueueDataset data_path: "{workspace}/data/train" data_converter: "{workspace}/census_reader.py" @@ -48,7 +48,7 @@ runner: save_inference_interval: 4 save_checkpoint_path: "increment" save_inference_path: "inference" - print_interval: 10 + print_interval: 1 - name: infer_runner class: infer init_model_path: "increment/1" -- GitLab