Commit 4a534f48 authored by Lizhengo and committed by Yibing Liu

Add warm-start training to SimNet and fix the bug that UTF-8 files could not be opened on Windows 10 (#3084)

* fix pred in simnet

* add init checkpoint in train and fix file open bug in window10

* Update README.md
Parent 02abbdaf
@@ -63,8 +63,9 @@ sh run.sh infer
 #### Training and validation
 Users can build a training set and a development set from the sample data, then run the command below to train the model and evaluate it on the development set.
 ```shell
 sh run.sh train
 ```
+Users can also set INIT_CHECKPOINT inside the train() function of ./run.sh to load a previously trained model and warm-start training from it.
 ## Advanced usage
 ### Task definition and modeling
......
@@ -4,6 +4,7 @@ SimNet reader
 import logging
 import numpy as np
+import codecs


 class SimNetProcessor(object):
@@ -24,7 +25,7 @@ class SimNetProcessor(object):
             Reader with Pairwise
             """
             if mode == "valid":
-                with open(self.args.valid_data_dir) as file:
+                with codecs.open(self.args.valid_data_dir, "r", "utf-8") as file:
                     for line in file:
                         query, title, label = line.strip().split("\t")
                         if len(query) == 0 or len(title) == 0 or len(label) == 0 or not label.isdigit() or int(
@@ -39,7 +40,7 @@ class SimNetProcessor(object):
                             title = [0]
                         yield [query, title]
             elif mode == "test":
-                with open(self.args.test_data_dir) as file:
+                with codecs.open(self.args.test_data_dir, "r", "utf-8") as file:
                     for line in file:
                         query, title, label = line.strip().split("\t")
                         if len(query) == 0 or len(title) == 0 or len(label) == 0 or not label.isdigit() or int(
@@ -54,7 +55,7 @@ class SimNetProcessor(object):
                             title = [0]
                         yield [query, title]
             else:
-                with open(self.args.train_data_dir) as file:
+                with codecs.open(self.args.train_data_dir, "r", "utf-8") as file:
                     for line in file:
                         query, pos_title, neg_title = line.strip().split("\t")
                         if len(query) == 0 or len(pos_title) == 0 or len(neg_title) == 0:
@@ -76,7 +77,7 @@ class SimNetProcessor(object):
             Reader with Pointwise
             """
             if mode == "valid":
-                with open(self.args.valid_data_dir) as file:
+                with codecs.open(self.args.valid_data_dir, "r", "utf-8") as file:
                     for line in file:
                         query, title, label = line.strip().split("\t")
                         if len(query) == 0 or len(title) == 0 or len(label) == 0 or not label.isdigit() or int(
@@ -91,7 +92,7 @@ class SimNetProcessor(object):
                             title = [0]
                         yield [query, title]
             elif mode == "test":
-                with open(self.args.test_data_dir) as file:
+                with codecs.open(self.args.test_data_dir, "r", "utf-8") as file:
                     for line in file:
                         query, title, label = line.strip().split("\t")
                         if len(query) == 0 or len(title) == 0 or len(label) == 0 or not label.isdigit() or int(
@@ -106,7 +107,7 @@ class SimNetProcessor(object):
                             title = [0]
                         yield [query, title]
             else:
-                with open(self.args.train_data_dir) as file:
+                with codecs.open(self.args.train_data_dir, "r", "utf-8") as file:
                     for line in file:
                         query, title, label = line.strip().split("\t")
                         if len(query) == 0 or len(title) == 0 or len(label) == 0 or not label.isdigit() or int(
@@ -131,7 +132,7 @@ class SimNetProcessor(object):
         """
         get infer reader
         """
-        with open(self.args.infer_data_dir, "r") as file:
+        with codecs.open(self.args.infer_data_dir, "r", "utf-8") as file:
             for line in file:
                 query, title = line.strip().split("\t")
                 if len(query) == 0 or len(title) == 0:
@@ -149,7 +150,7 @@ class SimNetProcessor(object):
         """
        get infer data
         """
-        with open(self.args.infer_data_dir, "r") as file:
+        with codecs.open(self.args.infer_data_dir, "r", "utf-8") as file:
             for line in file:
                 query, title = line.strip().split("\t")
                 if len(query) == 0 or len(title) == 0:
@@ -163,7 +164,7 @@ class SimNetProcessor(object):
         """
         if self.valid_label.size == 0:
             labels = []
-            with open(self.args.valid_data_dir, "r") as f:
+            with codecs.open(self.args.valid_data_dir, "r", "utf-8") as f:
                 for line in f:
                     labels.append([int(line.strip().split("\t")[-1])])
             self.valid_label = np.array(labels)
@@ -175,7 +176,7 @@ class SimNetProcessor(object):
         """
         if self.test_label.size == 0:
             labels = []
-            with open(self.args.test_data_dir, "r") as f:
+            with codecs.open(self.args.test_data_dir, "r", "utf-8") as f:
                 for line in f:
                     labels.append([int(line.strip().split("\t")[-1])])
             self.test_label = np.array(labels)
......
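The open-to-codecs.open changes above are what resolve the Windows 10 problem from the commit title: the built-in open() decodes text with the OS locale encoding (for example cp936 on a Chinese-locale Windows 10 install), so reading the UTF-8 data files can raise UnicodeDecodeError, while an explicit "utf-8" encoding behaves the same on Linux and Windows. Below is a minimal sketch of the pattern the readers now follow; the data path in the usage comment is an assumption, not a file shipped with SimNet.

```python
# Sketch of the UTF-8-safe reading pattern used in reader.py: codecs.open with an
# explicit encoding, instead of open(), which falls back to the OS locale encoding.
import codecs


def read_tsv(path):
    """Yield non-empty tab-separated records from a UTF-8 file, regardless of OS locale."""
    with codecs.open(path, "r", "utf-8") as f:
        for line in f:
            fields = line.rstrip("\n").split("\t")
            if all(len(field) > 0 for field in fields):
                yield fields


# Illustrative usage; "./data/valid_pointwise_data" is an assumed path.
# for query, title, label in read_tsv("./data/valid_pointwise_data"):
#     print(query, title, label)
```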
@@ -28,7 +28,7 @@ parser = argparse.ArgumentParser(__doc__)
 model_g = utils.ArgumentGroup(parser, "model", "model configuration and paths.")
 model_g.add_arg("config_path", str, None,
                 "Path to the json file for EmoTect model config.")
-model_g.add_arg("init_checkpoint", str, "examples/cnn_pointwise.json",
+model_g.add_arg("init_checkpoint", str, None,
                 "Init checkpoint to resume training from.")
 model_g.add_arg("output_dir", str, None, "Directory path to save checkpoints")
 model_g.add_arg("task_mode", str, None, "task mode: pairwise or pointwise")
@@ -163,11 +163,16 @@ def train(conf_dict, args):
     infer_program = fluid.default_main_program().clone(for_test=True)
     avg_cost = loss.compute(pred, label)
     avg_cost.persistable = True
     # operate Optimization
     optimizer.ops(avg_cost)
     executor = fluid.Executor(place)
     executor.run(fluid.default_startup_program())
+
+    if args.init_checkpoint is not None:
+        utils.init_checkpoint(executor, args.init_checkpoint,
+                              fluid.default_startup_program())
+
     # Get and run executor
     parallel_executor = fluid.ParallelExecutor(
         use_cuda=args.use_cuda,
......
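Together with the README note above, the two hunks in this file add the warm-start path: init_checkpoint now defaults to None (train from scratch), and when a path is supplied the parameters created by the startup program are overwritten from that checkpoint before the ParallelExecutor is built. The following is a condensed sketch of that control flow, not a drop-in replacement for train(); utils refers to the repository's utils module, whose new init_checkpoint helper appears in the last hunk below.

```python
# Condensed sketch of the warm-start flow: initialize all parameters via the startup
# program, then selectively overwrite them from a checkpoint directory when
# --init_checkpoint is supplied on the command line.
import paddle.fluid as fluid

import utils  # the repository's utils module, assumed importable from this script


def prepare_executor(args, place):
    executor = fluid.Executor(place)
    executor.run(fluid.default_startup_program())
    if args.init_checkpoint is not None:  # default is now None, i.e. a cold start
        utils.init_checkpoint(executor, args.init_checkpoint,
                              fluid.default_startup_program())
    return executor
```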
@@ -8,26 +8,28 @@ import sys
 import re
 import os
 import six
+import codecs
 import numpy as np
 import logging
 import logging.handlers
+import paddle.fluid as fluid

 """
 ******functions for file processing******
 """


 def load_vocab(file_path):
     """
     load the given vocabulary
     """
     vocab = {}
-    if not os.path.isfile(file_path):
-        raise ValueError("vocabulary dose not exist under %s" % file_path)
-    with open(file_path, 'r') as f:
-        for line in f:
-            items = line.strip('\n').split("\t")
-            if items[0] not in vocab:
-                vocab[items[0]] = int(items[1])
+    if six.PY3:
+        f = open(file_path, "r", encoding="utf-8")
+    else:
+        f = open(file_path, "r")
+    for line in f:
+        items = line.strip("\n").split("\t")
+        if items[0] not in vocab:
+            vocab[items[0]] = int(items[1])
     vocab["<unk>"] = 0
     return vocab
@@ -43,9 +45,9 @@ def get_result_file(args):
     result_file: merge sample and predict result
     """
-    with open(args.test_data_dir, "r") as test_file:
-        with open("predictions.txt", "r") as predictions_file:
-            with open(args.test_result_path, "w") as test_result_file:
+    with codecs.open(args.test_data_dir, "r", "utf-8") as test_file:
+        with codecs.open("predictions.txt", "r", "utf-8") as predictions_file:
+            with codecs.open(args.test_result_path, "w", "utf-8") as test_result_file:
                 test_datas = [line.strip("\n") for line in test_file]
                 predictions = [line.strip("\n") for line in predictions_file]
                 for test_data, prediction in zip(test_datas, predictions):
@@ -277,3 +279,24 @@ def deal_preds_of_mmdnn(conf_dict, preds):
         return get_sigmoid(preds)
     else:
         return get_softmax(preds)
+
+
+def init_checkpoint(exe, init_checkpoint_path, main_program):
+    """
+    init checkpoint
+    """
+    assert os.path.exists(
+        init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path
+
+    def existed_persitables(var):
+        if not fluid.io.is_persistable(var):
+            return False
+        return os.path.exists(os.path.join(init_checkpoint_path, var.name))
+
+    fluid.io.load_vars(
+        exe,
+        init_checkpoint_path,
+        main_program=main_program,
+        predicate=existed_persitables)
+    print("Load model from {}".format(init_checkpoint_path))