Commit 1d6cfd9c, authored by liuyuhui

Merge branch 'fix_collective_files_partition' of https://github.com/vslyu/PaddleRec into fix_collective_files_partition
@@ -18,11 +18,15 @@ import os
import time
import warnings
import numpy as np
import logging
import paddle.fluid as fluid
from paddlerec.core.utils import envs
from paddlerec.core.metric import Metric
logging.basicConfig(
    format='%(asctime)s - %(levelname)s: %(message)s', level=logging.INFO)

__all__ = [
    "RunnerBase", "SingleRunner", "PSRunner", "CollectiveRunner", "PslibRunner"
]
@@ -140,8 +144,16 @@ class RunnerBase(object):
        metrics_varnames = []
        metrics_format = []

        if context["is_infer"]:
            metrics_format.append("\t[Infer]\t{}: {{}}".format("batch"))
        else:
            metrics_format.append("\t[Train]\t{}: {{}}".format("batch"))
        metrics_format.append("{}: {{:.2f}}s".format("time_each_interval"))
        metrics_names = ["total_batch"]
        metrics_format.append("{}: {{}}".format("batch"))
        for name, var in metrics.items():
            metrics_names.append(name)
            metrics_varnames.append(var.name)
@@ -151,6 +163,7 @@ class RunnerBase(object):
        reader = context["model"][model_dict["name"]]["model"]._data_loader
        reader.start()
        batch_id = 0
        begin_time = time.time()
        scope = context["model"][model_name]["scope"]
        result = None
        with fluid.scope_guard(scope):
@@ -160,8 +173,8 @@ class RunnerBase(object):
                        program=program,
                        fetch_list=metrics_varnames,
                        return_numpy=False)
                    metrics = [batch_id]
                    metrics_rets = [
                        as_numpy(metrics_tensor)
                        for metrics_tensor in metrics_tensors
@@ -169,7 +182,13 @@ class RunnerBase(object):
                    metrics.extend(metrics_rets)

                    if batch_id % fetch_period == 0 and batch_id != 0:
                        print(metrics_format.format(*metrics))
                        end_time = time.time()
                        seconds = end_time - begin_time
                        metrics_logging = metrics[:]
                        metrics_logging.insert(1, seconds)
                        begin_time = end_time
                        logging.info(metrics_format.format(*metrics_logging))
                    batch_id += 1
            except fluid.core.EOFException:
                reader.reset()
......
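The runner.py change above boils down to: time each logging interval, copy the fetched metrics, slot the elapsed seconds in right after the batch id, and emit the line through `logging` instead of `print`. A minimal, self-contained sketch of that pattern in plain Python; `fetch_period`, the dummy loop, and the `loss` value are stand-ins for what the real runner fetches from the executor:

```
import logging
import time

logging.basicConfig(
    format='%(asctime)s - %(levelname)s: %(message)s', level=logging.INFO)

fetch_period = 10                       # stand-in for the runner's print_interval
metrics_format = ", ".join(
    ["\t[Train]\tbatch: {}", "time_each_interval: {:.2f}s", "loss: {}"])

begin_time = time.time()
for batch_id in range(50):              # placeholder for the data loader loop
    loss = 0.5                          # placeholder for a fetched metric value
    if batch_id % fetch_period == 0 and batch_id != 0:
        end_time = time.time()
        seconds = end_time - begin_time           # wall time of the last interval
        metrics = [batch_id, loss]
        metrics_logging = metrics[:]              # copy, so the raw metrics stay untouched
        metrics_logging.insert(1, seconds)        # [batch_id, seconds, loss]
        begin_time = end_time
        logging.info(metrics_format.format(*metrics_logging))
```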
@@ -4,11 +4,11 @@
```
├── data                  # sample data
    ├── train
        ├── train.txt     # sample training data
    ├── test
        ├── test.txt      # sample test data
    ├── preprocess.py     # data preprocessing script
├── __init__.py
├── README.md             # documentation
├── model.py              # model definition
@@ -44,7 +44,7 @@ Yoon Kim proposed TextCNN in the paper [EMNLP 2014][Convolutional neural networks for sentence classification]
| Model | dev | test |
| :------ | :------ | :------ |
| TextCNN | 90.75% | 92.19% |
| TextCNN | 90.75% | 91.27% |
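For context on the model the table above evaluates: TextCNN runs several 1-D convolutions with different window sizes over the word embeddings, max-pools each branch, concatenates the pooled features, and classifies with a softmax layer. A rough sketch using the paddle.fluid 1.x API this repo already imports; the vocabulary size, embedding width, filter count, and window sizes below are illustrative, not the values this model actually uses:

```
import paddle.fluid as fluid

dict_dim, emb_dim, num_filters, class_num = 100000, 128, 128, 2

# A batch of variable-length token-id sequences (LoD level 1) plus labels.
data = fluid.layers.data(name="words", shape=[1], dtype="int64", lod_level=1)
label = fluid.layers.data(name="label", shape=[1], dtype="int64")

emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim], is_sparse=True)

# One sequence_conv + max-pool branch per window size, then concatenate.
pooled = [
    fluid.nets.sequence_conv_pool(
        input=emb,
        num_filters=num_filters,
        filter_size=win,
        act="tanh",
        pool_type="max") for win in (3, 4, 5)
]
feature = fluid.layers.concat(input=pooled, axis=1)

prediction = fluid.layers.fc(input=feature, size=class_num, act="softmax")
avg_cost = fluid.layers.mean(
    fluid.layers.cross_entropy(input=prediction, label=label))
acc = fluid.layers.accuracy(input=prediction, label=label)
```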
You can run the following command to download the dataset we have already tokenized. After the archive is extracted, the senta_data directory will contain the training data (train.tsv), the development set (dev.tsv), the test set (test.tsv), and the corresponding vocabulary (word_dict.txt):
......
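The download command itself is collapsed in this diff. Purely as an illustration of the files described above, here is a sketch that loads the vocabulary and counts the rows of each split; the assumption that the .tsv files are tab-separated and that word_dict.txt holds one token per line is mine, not taken from the repo:

```
import csv
import os

data_dir = "senta_data"  # directory produced by extracting the downloaded archive

# Assumed format: one token per line; the line number becomes the token id.
with open(os.path.join(data_dir, "word_dict.txt"), encoding="utf-8") as f:
    word_dict = {line.rstrip("\n"): idx for idx, line in enumerate(f)}
print("vocabulary size:", len(word_dict))

for split in ("train.tsv", "dev.tsv", "test.tsv"):
    with open(os.path.join(data_dir, split), encoding="utf-8") as f:
        rows = list(csv.reader(f, delimiter="\t"))
    print("{}: {} rows".format(split, len(rows)))
```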
@@ -17,12 +17,12 @@ workspace: "models/multitask/mmoe"
dataset:
- name: dataset_train
  batch_size: 5
  type: QueueDataset
  type: DataLoader # or QueueDataset
  data_path: "{workspace}/data/train"
  data_converter: "{workspace}/census_reader.py"
- name: dataset_infer
  batch_size: 5
  type: QueueDataset
  type: DataLoader # or QueueDataset
  data_path: "{workspace}/data/train"
  data_converter: "{workspace}/census_reader.py"
@@ -37,7 +37,6 @@ hyper_parameters:
    learning_rate: 0.001
    strategy: async

# use infer_runner mode and modify 'phase' below if infer
mode: [train_runner, infer_runner]
runner:
@@ -49,10 +48,10 @@ runner:
  save_inference_interval: 4
  save_checkpoint_path: "increment"
  save_inference_path: "inference"
  print_interval: 10
  print_interval: 1
- name: infer_runner
  class: infer
  init_model_path: "increment/0"
  init_model_path: "increment/1"
  device: cpu
phase:
......
@@ -102,9 +102,9 @@ phase:
- name: phase1
  model: "{workspace}/model.py" # user-defined model
  dataset_name: dataloader_train # select dataset by name
  thread_num: 8
  thread_num: 1
- name: phase2
  model: "{workspace}/model.py" # user-defined model
  dataset_name: dataset_infer # select dataset by name
  thread_num: 8
  thread_num: 1
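The config hunks above switch the dataset `type` to DataLoader, tighten `print_interval`, point `init_model_path` at the checkpoint saved under increment/1, and drop `thread_num` to 1. As a rough sketch of how such a config hangs together, not PaddleRec's actual loader: the helper below only shows how `mode` picks runners and how each `phase` resolves its dataset by name, and the name-matching rule plus the config path are my assumptions:

```
import yaml  # PyYAML, assumed available

def summarize_config(path):
    """Print which runners a config's `mode` activates and each phase's dataset."""
    with open(path) as f:
        config = yaml.safe_load(f)

    modes = config.get("mode", [])
    if isinstance(modes, str):  # `mode` may be a single name or a list of names
        modes = [modes]

    runners = {r["name"]: r for r in config.get("runner", [])}
    datasets = {d["name"]: d for d in config.get("dataset", [])}

    for name in modes:
        runner = runners.get(name, {})
        print("runner {} (class={}, print_interval={})".format(
            name, runner.get("class"), runner.get("print_interval")))

    for phase in config.get("phase", []):
        ds = datasets.get(phase.get("dataset_name"), {})
        print("phase {} -> dataset {} (type={}, batch_size={}, thread_num={})".format(
            phase.get("name"), ds.get("name"), ds.get("type"),
            ds.get("batch_size"), phase.get("thread_num")))

summarize_config("models/multitask/mmoe/config.yaml")  # hypothetical local path
```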