提交 7db0dfe2 编写于 作者: F frankwhzhang

fix dataloader

上级 864d5312
...@@ -174,18 +174,20 @@ class RunnerBase(object): ...@@ -174,18 +174,20 @@ class RunnerBase(object):
fetch_list=metrics_varnames, fetch_list=metrics_varnames,
return_numpy=False) return_numpy=False)
if batch_id % fetch_period == 0 and batch_id != 0:
metrics = [batch_id] metrics = [batch_id]
end_time = time.time()
seconds = end_time - begin_time
metrics.extend([seconds])
begin_time = end_time
metrics_rets = [ metrics_rets = [
as_numpy(metrics_tensor) as_numpy(metrics_tensor)
for metrics_tensor in metrics_tensors for metrics_tensor in metrics_tensors
] ]
metrics.extend(metrics_rets) metrics.extend(metrics_rets)
if batch_id % fetch_period == 0 and batch_id != 0:
end_time = time.time()
seconds = end_time - begin_time
metrics_logging = metrics[:]
metrics_logging = metrics.insert(1, seconds)
begin_time = end_time
logging.info(metrics_format.format(*metrics)) logging.info(metrics_format.format(*metrics))
batch_id += 1 batch_id += 1
except fluid.core.EOFException: except fluid.core.EOFException:
......
...@@ -17,12 +17,12 @@ workspace: "models/multitask/mmoe" ...@@ -17,12 +17,12 @@ workspace: "models/multitask/mmoe"
dataset: dataset:
- name: dataset_train - name: dataset_train
batch_size: 5 batch_size: 5
type: QueueDataset type: DataLoader # or QueueDataset
data_path: "{workspace}/data/train" data_path: "{workspace}/data/train"
data_converter: "{workspace}/census_reader.py" data_converter: "{workspace}/census_reader.py"
- name: dataset_infer - name: dataset_infer
batch_size: 5 batch_size: 5
type: QueueDataset type: DataLoader # or QueueDataset
data_path: "{workspace}/data/train" data_path: "{workspace}/data/train"
data_converter: "{workspace}/census_reader.py" data_converter: "{workspace}/census_reader.py"
...@@ -48,7 +48,7 @@ runner: ...@@ -48,7 +48,7 @@ runner:
save_inference_interval: 4 save_inference_interval: 4
save_checkpoint_path: "increment" save_checkpoint_path: "increment"
save_inference_path: "inference" save_inference_path: "inference"
print_interval: 10 print_interval: 1
- name: infer_runner - name: infer_runner
class: infer class: infer
init_model_path: "increment/1" init_model_path: "increment/1"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册