提交 7db0dfe2 编写于 作者: F frankwhzhang

fix dataloader

上级 864d5312
......@@ -174,18 +174,20 @@ class RunnerBase(object):
fetch_list=metrics_varnames,
return_numpy=False)
metrics = [batch_id]
metrics_rets = [
as_numpy(metrics_tensor)
for metrics_tensor in metrics_tensors
]
metrics.extend(metrics_rets)
if batch_id % fetch_period == 0 and batch_id != 0:
metrics = [batch_id]
end_time = time.time()
seconds = end_time - begin_time
metrics.extend([seconds])
metrics_logging = metrics[:]
metrics_logging = metrics.insert(1, seconds)
begin_time = end_time
metrics_rets = [
as_numpy(metrics_tensor)
for metrics_tensor in metrics_tensors
]
metrics.extend(metrics_rets)
logging.info(metrics_format.format(*metrics))
batch_id += 1
except fluid.core.EOFException:
......
......@@ -17,12 +17,12 @@ workspace: "models/multitask/mmoe"
dataset:
- name: dataset_train
batch_size: 5
type: QueueDataset
type: DataLoader # or QueueDataset
data_path: "{workspace}/data/train"
data_converter: "{workspace}/census_reader.py"
- name: dataset_infer
batch_size: 5
type: QueueDataset
type: DataLoader # or QueueDataset
data_path: "{workspace}/data/train"
data_converter: "{workspace}/census_reader.py"
......@@ -48,7 +48,7 @@ runner:
save_inference_interval: 4
save_checkpoint_path: "increment"
save_inference_path: "inference"
print_interval: 10
print_interval: 1
- name: infer_runner
class: infer
init_model_path: "increment/1"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册