提交 1d6cfd9c，作者：liuyuhui

Merge branch 'fix_collective_files_partition' of...

Merge branch 'fix_collective_files_partition' of https://github.com/vslyu/PaddleRec into fix_collective_files_partition
...@@ -18,11 +18,15 @@ import os ...@@ -18,11 +18,15 @@ import os
import time import time
import warnings import warnings
import numpy as np import numpy as np
import logging
import paddle.fluid as fluid import paddle.fluid as fluid
from paddlerec.core.utils import envs from paddlerec.core.utils import envs
from paddlerec.core.metric import Metric from paddlerec.core.metric import Metric
logging.basicConfig(
format='%(asctime)s - %(levelname)s: %(message)s', level=logging.INFO)
__all__ = [ __all__ = [
"RunnerBase", "SingleRunner", "PSRunner", "CollectiveRunner", "PslibRunner" "RunnerBase", "SingleRunner", "PSRunner", "CollectiveRunner", "PslibRunner"
] ]
...@@ -140,8 +144,16 @@ class RunnerBase(object): ...@@ -140,8 +144,16 @@ class RunnerBase(object):
metrics_varnames = [] metrics_varnames = []
metrics_format = [] metrics_format = []
if context["is_infer"]:
metrics_format.append("\t[Infer]\t{}: {{}}".format("batch"))
else:
metrics_format.append("\t[Train]\t{}: {{}}".format("batch"))
metrics_format.append("{}: {{:.2f}}s".format("time_each_interval"))
metrics_names = ["total_batch"] metrics_names = ["total_batch"]
metrics_format.append("{}: {{}}".format("batch"))
for name, var in metrics.items(): for name, var in metrics.items():
metrics_names.append(name) metrics_names.append(name)
metrics_varnames.append(var.name) metrics_varnames.append(var.name)
...@@ -151,6 +163,7 @@ class RunnerBase(object): ...@@ -151,6 +163,7 @@ class RunnerBase(object):
reader = context["model"][model_dict["name"]]["model"]._data_loader reader = context["model"][model_dict["name"]]["model"]._data_loader
reader.start() reader.start()
batch_id = 0 batch_id = 0
begin_time = time.time()
scope = context["model"][model_name]["scope"] scope = context["model"][model_name]["scope"]
result = None result = None
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
...@@ -160,8 +173,8 @@ class RunnerBase(object): ...@@ -160,8 +173,8 @@ class RunnerBase(object):
program=program, program=program,
fetch_list=metrics_varnames, fetch_list=metrics_varnames,
return_numpy=False) return_numpy=False)
metrics = [batch_id]
metrics = [batch_id]
metrics_rets = [ metrics_rets = [
as_numpy(metrics_tensor) as_numpy(metrics_tensor)
for metrics_tensor in metrics_tensors for metrics_tensor in metrics_tensors
...@@ -169,7 +182,13 @@ class RunnerBase(object): ...@@ -169,7 +182,13 @@ class RunnerBase(object):
metrics.extend(metrics_rets) metrics.extend(metrics_rets)
if batch_id % fetch_period == 0 and batch_id != 0: if batch_id % fetch_period == 0 and batch_id != 0:
print(metrics_format.format(*metrics)) end_time = time.time()
seconds = end_time - begin_time
metrics_logging = metrics[:]
metrics_logging = metrics.insert(1, seconds)
begin_time = end_time
logging.info(metrics_format.format(*metrics))
batch_id += 1 batch_id += 1
except fluid.core.EOFException: except fluid.core.EOFException:
reader.reset() reader.reset()
......
...@@ -4,11 +4,11 @@ ...@@ -4,11 +4,11 @@
``` ```
├── data #样例数据 ├── data #样例数据
├── train ├── train
├── train.txt #训练数据样例 ├── train.txt #训练数据样例
├── test ├── test
├── test.txt #测试数据样例 ├── test.txt #测试数据样例
├── preprocess.py #数据处理程序 ├── preprocess.py #数据处理程序
├── __init__.py ├── __init__.py
├── README.md #文档 ├── README.md #文档
├── model.py #模型文件 ├── model.py #模型文件
...@@ -44,7 +44,7 @@ Yoon Kim在论文[EMNLP 2014][Convolutional neural networks for sentence classic ...@@ -44,7 +44,7 @@ Yoon Kim在论文[EMNLP 2014][Convolutional neural networks for sentence classic
| 模型 | dev | test | | 模型 | dev | test |
| :------| :------ | :------ | :------| :------ | :------
| TextCNN | 90.75% | 92.19% | | TextCNN | 90.75% | 91.27% |
您可以直接执行以下命令下载我们分词完毕后的数据集,文件解压之后,senta_data目录下会存在训练数据(train.tsv)、开发集数据(dev.tsv)、测试集数据(test.tsv)以及对应的词典(word_dict.txt): 您可以直接执行以下命令下载我们分词完毕后的数据集,文件解压之后,senta_data目录下会存在训练数据(train.tsv)、开发集数据(dev.tsv)、测试集数据(test.tsv)以及对应的词典(word_dict.txt):
......
...@@ -17,12 +17,12 @@ workspace: "models/multitask/mmoe" ...@@ -17,12 +17,12 @@ workspace: "models/multitask/mmoe"
dataset: dataset:
- name: dataset_train - name: dataset_train
batch_size: 5 batch_size: 5
type: QueueDataset type: DataLoader # or QueueDataset
data_path: "{workspace}/data/train" data_path: "{workspace}/data/train"
data_converter: "{workspace}/census_reader.py" data_converter: "{workspace}/census_reader.py"
- name: dataset_infer - name: dataset_infer
batch_size: 5 batch_size: 5
type: QueueDataset type: DataLoader # or QueueDataset
data_path: "{workspace}/data/train" data_path: "{workspace}/data/train"
data_converter: "{workspace}/census_reader.py" data_converter: "{workspace}/census_reader.py"
...@@ -37,7 +37,6 @@ hyper_parameters: ...@@ -37,7 +37,6 @@ hyper_parameters:
learning_rate: 0.001 learning_rate: 0.001
strategy: async strategy: async
#use infer_runner mode and modify 'phase' below if infer
mode: [train_runner, infer_runner] mode: [train_runner, infer_runner]
runner: runner:
...@@ -49,10 +48,10 @@ runner: ...@@ -49,10 +48,10 @@ runner:
save_inference_interval: 4 save_inference_interval: 4
save_checkpoint_path: "increment" save_checkpoint_path: "increment"
save_inference_path: "inference" save_inference_path: "inference"
print_interval: 10 print_interval: 1
- name: infer_runner - name: infer_runner
class: infer class: infer
init_model_path: "increment/0" init_model_path: "increment/1"
device: cpu device: cpu
phase: phase:
......
...@@ -102,9 +102,9 @@ phase: ...@@ -102,9 +102,9 @@ phase:
- name: phase1 - name: phase1
model: "{workspace}/model.py" # user-defined model model: "{workspace}/model.py" # user-defined model
dataset_name: dataloader_train # select dataset by name dataset_name: dataloader_train # select dataset by name
thread_num: 8 thread_num: 1
- name: phase2 - name: phase2
model: "{workspace}/model.py" # user-defined model model: "{workspace}/model.py" # user-defined model
dataset_name: dataset_infer # select dataset by name dataset_name: dataset_infer # select dataset by name
thread_num: 8 thread_num: 1
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册