提交 4a534f48 编写于 作者: L Lizhengo 提交者: Yibing Liu

simnet加入热启动训练功能,修正window10打不开utf8文件的bug (#3084)

* fix pred in simnet

* add init checkpoint in train and fix file open bug in window10

* Update README.md
上级 02abbdaf
......@@ -63,8 +63,9 @@ sh run.sh infer
#### 训练与验证
用户可以基于示例数据构建训练集和开发集,可以运行下面的命令,进行模型训练和开发集验证。
```shell
sh run.sh train
sh run.sh train
```
用户也可以指定./run.sh中train()函数里的INIT_CHECKPOINT的值,载入训练好的模型进行热启动训练。
## 进阶使用
### 任务定义与建模
......
......@@ -4,6 +4,7 @@ SimNet reader
import logging
import numpy as np
import codecs
class SimNetProcessor(object):
......@@ -24,7 +25,7 @@ class SimNetProcessor(object):
Reader with Pairwise
"""
if mode == "valid":
with open(self.args.valid_data_dir) as file:
with codecs.open(self.args.valid_data_dir, "r", "utf-8") as file:
for line in file:
query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len(label) == 0 or not label.isdigit() or int(
......@@ -39,7 +40,7 @@ class SimNetProcessor(object):
title = [0]
yield [query, title]
elif mode == "test":
with open(self.args.test_data_dir) as file:
with codecs.open(self.args.test_data_dir, "r", "utf-8") as file:
for line in file:
query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len(label) == 0 or not label.isdigit() or int(
......@@ -54,7 +55,7 @@ class SimNetProcessor(object):
title = [0]
yield [query, title]
else:
with open(self.args.train_data_dir) as file:
with codecs.open(self.args.train_data_dir, "r", "utf-8") as file:
for line in file:
query, pos_title, neg_title = line.strip().split("\t")
if len(query) == 0 or len(pos_title) == 0 or len(neg_title) == 0:
......@@ -76,7 +77,7 @@ class SimNetProcessor(object):
Reader with Pointwise
"""
if mode == "valid":
with open(self.args.valid_data_dir) as file:
with codecs.open(self.args.valid_data_dir, "r", "utf-8") as file:
for line in file:
query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len(label) == 0 or not label.isdigit() or int(
......@@ -91,7 +92,7 @@ class SimNetProcessor(object):
title = [0]
yield [query, title]
elif mode == "test":
with open(self.args.test_data_dir) as file:
with codecs.open(self.args.test_data_dir, "r", "utf-8") as file:
for line in file:
query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len(label) == 0 or not label.isdigit() or int(
......@@ -106,7 +107,7 @@ class SimNetProcessor(object):
title = [0]
yield [query, title]
else:
with open(self.args.train_data_dir) as file:
with codecs.open(self.args.train_data_dir, "r", "utf-8") as file:
for line in file:
query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len(label) == 0 or not label.isdigit() or int(
......@@ -131,7 +132,7 @@ class SimNetProcessor(object):
"""
get infer reader
"""
with open(self.args.infer_data_dir, "r") as file:
with codecs.open(self.args.infer_data_dir, "r", "utf-8") as file:
for line in file:
query, title = line.strip().split("\t")
if len(query) == 0 or len(title) == 0:
......@@ -149,7 +150,7 @@ class SimNetProcessor(object):
"""
get infer data
"""
with open(self.args.infer_data_dir, "r") as file:
with codecs.open(self.args.infer_data_dir, "r", "utf-8") as file:
for line in file:
query, title = line.strip().split("\t")
if len(query) == 0 or len(title) == 0:
......@@ -163,7 +164,7 @@ class SimNetProcessor(object):
"""
if self.valid_label.size == 0:
labels = []
with open(self.args.valid_data_dir, "r") as f:
with codecs.open(self.args.valid_data_dir, "r", "utf-8") as f:
for line in f:
labels.append([int(line.strip().split("\t")[-1])])
self.valid_label = np.array(labels)
......@@ -175,7 +176,7 @@ class SimNetProcessor(object):
"""
if self.test_label.size == 0:
labels = []
with open(self.args.test_data_dir, "r") as f:
with codecs.open(self.args.test_data_dir, "r", "utf-8") as f:
for line in f:
labels.append([int(line.strip().split("\t")[-1])])
self.test_label = np.array(labels)
......
......@@ -28,7 +28,7 @@ parser = argparse.ArgumentParser(__doc__)
model_g = utils.ArgumentGroup(parser, "model", "model configuration and paths.")
model_g.add_arg("config_path", str, None,
"Path to the json file for EmoTect model config.")
model_g.add_arg("init_checkpoint", str, "examples/cnn_pointwise.json",
model_g.add_arg("init_checkpoint", str, None,
"Init checkpoint to resume training from.")
model_g.add_arg("output_dir", str, None, "Directory path to save checkpoints")
model_g.add_arg("task_mode", str, None, "task mode: pairwise or pointwise")
......@@ -163,11 +163,16 @@ def train(conf_dict, args):
infer_program = fluid.default_main_program().clone(for_test=True)
avg_cost = loss.compute(pred, label)
avg_cost.persistable = True
# operate Optimization
optimizer.ops(avg_cost)
executor = fluid.Executor(place)
executor.run(fluid.default_startup_program())
if args.init_checkpoint is not None:
utils.init_checkpoint(executor, args.init_checkpoint,
fluid.default_startup_program())
# Get and run executor
parallel_executor = fluid.ParallelExecutor(
use_cuda=args.use_cuda,
......
......@@ -8,26 +8,28 @@ import sys
import re
import os
import six
import codecs
import numpy as np
import logging
import logging.handlers
import paddle.fluid as fluid
"""
******functions for file processing******
"""
def load_vocab(file_path):
    """
    Load the given vocabulary file into a dict.

    Each line is expected to be "token<TAB>id". The first occurrence of a
    token wins; later duplicates are ignored. The special token "<unk>" is
    always (re)mapped to id 0 after loading.

    Args:
        file_path: path to a UTF-8 encoded vocabulary file.

    Returns:
        dict mapping token (str) -> id (int).

    Raises:
        ValueError: if file_path is not an existing regular file.
    """
    vocab = {}
    if not os.path.isfile(file_path):
        raise ValueError("vocabulary does not exist under %s" % file_path)
    # codecs.open decodes UTF-8 identically on Python 2 and 3 (replacing the
    # previous six.PY3 branching), and the context manager guarantees the
    # handle is closed -- the earlier version opened the file and never
    # closed it.
    with codecs.open(file_path, "r", "utf-8") as f:
        for line in f:
            items = line.strip("\n").split("\t")
            if items[0] not in vocab:
                vocab[items[0]] = int(items[1])
    vocab["<unk>"] = 0
    return vocab
......@@ -43,9 +45,9 @@ def get_result_file(args):
result_file: merge sample and predict result
"""
with open(args.test_data_dir, "r") as test_file:
with open("predictions.txt", "r") as predictions_file:
with open(args.test_result_path, "w") as test_result_file:
with codecs.open(args.test_data_dir, "r", "utf-8") as test_file:
with codecs.open("predictions.txt", "r", "utf-8") as predictions_file:
with codecs.open(args.test_result_path, "w", "utf-8") as test_result_file:
test_datas = [line.strip("\n") for line in test_file]
predictions = [line.strip("\n") for line in predictions_file]
for test_data, prediction in zip(test_datas, predictions):
......@@ -277,3 +279,24 @@ def deal_preds_of_mmdnn(conf_dict, preds):
return get_sigmoid(preds)
else:
return get_softmax(preds)
def init_checkpoint(exe, init_checkpoint_path, main_program):
    """
    Warm-start training by loading persistable variables from a checkpoint.

    Only variables that are both persistable in `main_program` and actually
    present as files under `init_checkpoint_path` are loaded, so a partially
    matching checkpoint directory still works.

    Args:
        exe: fluid.Executor used to run the load ops.
        init_checkpoint_path: directory containing the saved variables.
        main_program: fluid.Program whose variables should be initialized.
    """
    assert os.path.exists(
        init_checkpoint_path), "[%s] cannot be found." % init_checkpoint_path

    def existed_persitables(var):
        # Skip non-persistable variables (e.g. temporaries) and variables
        # that have no corresponding file in the checkpoint directory.
        if not fluid.io.is_persistable(var):
            return False
        return os.path.exists(os.path.join(init_checkpoint_path, var.name))

    fluid.io.load_vars(
        exe,
        init_checkpoint_path,
        main_program=main_program,
        predicate=existed_persitables)
    print("Load model from {}".format(init_checkpoint_path))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册