Unverified commit d7f8866b, authored by 0YuanZhang0, committed by GitHub

Fix windows (#3646)

* fix_dgu_ade_dnet

* fix_readme
Parent: e2e782c5
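Most of the changes below follow one pattern: bare `open()` calls are replaced with `io.open(..., encoding="utf8")` so file I/O no longer depends on the platform default encoding (notably on Windows), and the bash helper scripts are replaced with Python equivalents. A minimal sketch of that file-reading pattern, with an illustrative helper name and path that are not part of this commit:

```
import io

def read_lines(path):
    # io.open with an explicit encoding behaves the same under Python 2 and 3
    # and on Windows, where the locale default encoding is usually not UTF-8.
    fr = io.open(path, "r", encoding="utf8")
    try:
        return [line.strip() for line in fr]
    finally:
        fr.close()
```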
@@ -82,7 +82,7 @@ label_data (the second-stage finetuning dataset)
    Download the dataset and related models:
```
-cd ade && bash prepare_data_and_model.sh
+python ade/prepare_data_and_model.py
```
    Data path: data/input/data/
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tarfile
import shutil
import urllib
import sys
import io
import os

URLLIB = urllib
if sys.version_info >= (3, 0):
    URLLIB = urllib.request

DATA_MODEL_PATH = {"DATA_PATH": "https://baidu-nlp.bj.bcebos.com/auto_dialogue_evaluation_dataset-1.0.0.tar.gz",
                   "TRAINED_MODEL": "https://baidu-nlp.bj.bcebos.com/auto_dialogue_evaluation_models.2.0.0.tar.gz"}

PATH_MAP = {'DATA_PATH': "./data/input",
            'TRAINED_MODEL': './data/saved_models'}


def un_tar(tar_name, dir_name):
    try:
        t = tarfile.open(tar_name)
        t.extractall(path=dir_name)
        return True
    except Exception as e:
        print(e)
        return False


def download_model_and_data():
    print("Downloading ade data, pretrain model and trained models......")
    print("This process is quite long, please wait patiently............")
    for path in ['./data/input/data', './data/saved_models/trained_models']:
        if not os.path.exists(path):
            continue
        shutil.rmtree(path)
    for path_key in DATA_MODEL_PATH:
        filename = os.path.basename(DATA_MODEL_PATH[path_key])
        URLLIB.urlretrieve(DATA_MODEL_PATH[path_key], os.path.join("./", filename))
        state = un_tar(filename, PATH_MAP[path_key])
        if not state:
            print("Tar %s error....." % path_key)
            return False
        os.remove(filename)
    return True


if __name__ == "__main__":
    state = download_model_and_data()
    if not state:
        exit(1)
    print("Downloading data and models success......")
#!/bin/bash
#check data directory
cd ..
echo "Start download data and models.............."
if [ ! -d "data" ]; then
echo "Directory data does not exist, make new data directory"
mkdir data
fi
cd data
#check configure file
if [ ! -d "config" ]; then
echo "config directory not exist........"
exit 255
else
if [ ! -f "config/ade.yaml" ]; then
echo "config file dgu.yaml has been lost........"
exit 255
fi
fi
#check and download input data
if [ ! -d "input" ]; then
echo "Directory input does not exist, make new input directory"
mkdir input
fi
cd input
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/auto_dialogue_evaluation_dataset-1.0.0.tar.gz
tar -zxvf auto_dialogue_evaluation_dataset-1.0.0.tar.gz
rm auto_dialogue_evaluation_dataset-1.0.0.tar.gz
cd ..
#check and download pretrain model
if [ ! -d "pretrain_model" ]; then
echo "Directory pretrain_model does not exist, make new pretrain_model directory"
mkdir pretrain_model
fi
#check and download inferenece model
if [ ! -d "inference_models" ]; then
echo "Directory inferenece_model does not exist, make new inferenece_model directory"
mkdir inference_models
fi
#check output
if [ ! -d "output" ]; then
echo "Directory output does not exist, make new output directory"
mkdir output
fi
#check saved model
if [ ! -d "saved_models" ]; then
echo "Directory saved_models does not exist, make new saved_models directory"
mkdir saved_models
fi
cd saved_models
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/auto_dialogue_evaluation_models.2.0.0.tar.gz
tar -xvf auto_dialogue_evaluation_models.2.0.0.tar.gz
rm auto_dialogue_evaluation_models.2.0.0.tar.gz
echo "Finish.............."
+# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +14,7 @@
# limitations under the License.
"""Reader for auto dialogue evaluation"""
+import io
import sys
import time
import random
@@ -34,12 +36,12 @@ class DataProcessor(object):
        """load examples"""
        examples = []
        index = 0
-        with open(self.data_file, 'r') as fr:
+        fr = io.open(self.data_file, 'r', encoding="utf8")
        for line in fr:
            if index !=0 and index % 100 == 0:
                print("processing data: %d" % index)
            index += 1
            examples.append(line.strip())
        return examples

    def get_num_examples(self, phase):
@@ -47,7 +49,7 @@ class DataProcessor(object):
        if phase not in ['train', 'dev', 'test']:
            raise ValueError(
                "Unknown phase, which should be in ['train', 'dev', 'test'].")
-        count = len(open(self.data_file, 'rU').readlines())
+        count = len(io.open(self.data_file, 'r', encoding="utf8").readlines())
        self.num_examples[phase] = count
        return self.num_examples[phase]
......
+# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,6 +18,7 @@ from __future__ import division
from __future__ import print_function

import os
+import io
import sys
import argparse
import json
@@ -38,8 +40,8 @@ class JsonConfig(object):

    def _parse(self, config_path):
        try:
-            with open(config_path) as json_file:
+            json_file = io.open(config_path, 'r', encoding="utf8")
            config_dict = json.load(json_file)
        except:
            raise IOError("Error in parsing bert model config file '%s'" %
                          config_path)
@@ -214,9 +216,9 @@ class PDConfig(object):
            raise Warning("the json file %s does not exist." % file_path)
            return

-        with open(file_path, "r") as fin:
+        fin = io.open(file_path, "r", encoding="utf8")
        self.json_config = json.loads(fin.read())
        fin.close()

        if fuse_args:
            for name in self.json_config:
@@ -238,9 +240,9 @@ class PDConfig(object):
            raise Warning("the yaml file %s does not exist." % file_path)
            return

-        with open(file_path, "r") as fin:
+        fin = io.open(file_path, "r", encoding="utf8")
        self.yaml_config = yaml.load(fin, Loader=yaml.SafeLoader)
        fin.close()

        if fuse_args:
            for name in self.yaml_config:
......
+# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +14,7 @@
# limitations under the License.
"""evaluation metrics"""
+import io
import os
import sys

import numpy as np
@@ -24,22 +26,22 @@ from ade.utils.configure import PDConfig

def do_eval(args):
    """evaluate metrics"""
    labels = []
-    with open(args.evaluation_file, 'r') as fr:
+    fr = io.open(args.evaluation_file, 'r', encoding="utf8")
    for line in fr:
        tokens = line.strip().split('\t')
        assert len(tokens) == 3
        label = int(tokens[2])
        labels.append(label)

    scores = []
-    with open(args.output_prediction_file, 'r') as fr:
+    fr = io.open(args.output_prediction_file, 'r', encoding="utf8")
    for line in fr:
        tokens = line.strip().split('\t')
        assert len(tokens) == 2
        score = tokens[1].strip("[]").split()
        score = np.array(score)
        score = score.astype(np.float64)
        scores.append(score)

    if args.loss_type == 'CLS':
        recall_dict = evaluate.evaluate_Recall(list(zip(scores, labels)))
......
@@ -18,7 +18,6 @@ import sys
import six
import numpy as np
import time
-import multiprocessing
import paddle
import paddle.fluid as fluid
......
+# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,13 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""predict auto dialogue evaluation task"""
+import io
import os
import sys
import six
import time
import numpy as np
-import multiprocessing
import paddle
import paddle.fluid as fluid
@@ -109,9 +109,9 @@ def do_predict(args):
    scores = scores[: num_test_examples]
    print("Write the predicted results into the output_prediction_file")
-    with open(args.output_prediction_file, 'w') as fw:
+    fw = io.open(args.output_prediction_file, 'w', encoding="utf8")
    for index, score in enumerate(scores):
        fw.write("%s\t%s\n" % (index, score))

    print("finish........................................")
......
+# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,13 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""train auto dialogue evaluation task"""
+import io
import os
import sys
import six
import time
import numpy as np
-import multiprocessing
import paddle
import paddle.fluid as fluid
@@ -76,8 +76,7 @@ def do_train(args):
        dev_count = fluid.core.get_cuda_device_count()
        place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
    else:
-        dev_count = int(
-            os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
+        dev_count = int(os.environ.get('CPU_NUM', 1))
        place = fluid.CPUPlace()

    processor = reader.DataProcessor(
@@ -115,9 +114,9 @@ def do_train(args):
    if args.word_emb_init:
        print("start loading word embedding init ...")
        if six.PY2:
-            word_emb = np.array(pickle.load(open(args.word_emb_init, 'rb'))).astype('float32')
+            word_emb = np.array(pickle.load(io.open(args.word_emb_init, 'rb'))).astype('float32')
        else:
-            word_emb = np.array(pickle.load(open(args.word_emb_init, 'rb'), encoding="bytes")).astype('float32')
+            word_emb = np.array(pickle.load(io.open(args.word_emb_init, 'rb'), encoding="bytes")).astype('float32')
        set_word_embedding(word_emb, place)
        print("finish init word embedding ...")
......
@@ -63,7 +63,7 @@ SWDA: Switchboard Dialogue Act Corpus;
    Download the dataset and related models:
```
-cd dgu && bash prepare_data_and_model.sh
+python dgu/prepare_data_and_model.py
```
    Data path: data/input/data
......
+# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,6 +19,7 @@ from __future__ import division
from __future__ import print_function

import os
+import io
import sys
import six
import json
@@ -33,8 +35,8 @@ class BertConfig(object):

    def _parse(self, config_path):
        try:
-            with open(config_path) as json_file:
+            json_file = io.open(config_path, 'r', encoding="utf8")
            config_dict = json.load(json_file)
        except Exception:
            raise IOError("Error in parsing bert model config file '%s'" %
                          config_path)
......
+# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,6 +15,7 @@
"""evaluate task metrics"""

import sys
+import io


class EvalDA(object):
@@ -33,18 +35,18 @@ class EvalDA(object):
        """
        pred_label = []
        refer_label = []
-        with open(self.refer_file, 'r') as fr:
+        fr = io.open(self.refer_file, 'r', encoding="utf8")
        for line in fr:
            label = line.rstrip('\n').split('\t')[1]
            refer_label.append(int(label))
        idx = 0
-        with open(self.pred_file, 'r') as fr:
+        fr = io.open(self.pred_file, 'r', encoding="utf8")
        for line in fr:
            elems = line.rstrip('\n').split('\t')
            if len(elems) != 2 or not elems[0].isdigit():
                continue
            tag_id = int(elems[1])
            pred_label.append(tag_id)
        return pred_label, refer_label

    def evaluate(self):
@@ -78,18 +80,18 @@ class EvalATISIntent(object):
        """
        pred_label = []
        refer_label = []
-        with open(self.refer_file, 'r') as fr:
+        fr = io.open(self.refer_file, 'r', encoding="utf8")
        for line in fr:
            label = line.rstrip('\n').split('\t')[0]
            refer_label.append(int(label))
        idx = 0
-        with open(self.pred_file, 'r') as fr:
+        fr = io.open(self.pred_file, 'r', encoding="utf8")
        for line in fr:
            elems = line.rstrip('\n').split('\t')
            if len(elems) != 2 or not elems[0].isdigit():
                continue
            tag_id = int(elems[1])
            pred_label.append(tag_id)
        return pred_label, refer_label

    def evaluate(self):
@@ -123,18 +125,18 @@ class EvalATISSlot(object):
        """
        pred_label = []
        refer_label = []
-        with open(self.refer_file, 'r') as fr:
+        fr = io.open(self.refer_file, 'r', encoding="utf8")
        for line in fr:
            labels = line.rstrip('\n').split('\t')[1].split()
            labels = [int(l) for l in labels]
            refer_label.append(labels)
-        with open(self.pred_file, 'r') as fr:
+        fr = io.open(self.pred_file, 'r', encoding="utf8")
        for line in fr:
            if len(line.split('\t')) != 2 or not line[0].isdigit():
                continue
            labels = line.rstrip('\n').split('\t')[1].split()[1:]
            labels = [int(l) for l in labels]
            pred_label.append(labels)
        pred_label_equal = []
        refer_label_equal = []
        assert len(refer_label) == len(pred_label)
@@ -208,19 +210,19 @@ class EvalUDC(object):
        """
        data = []
        refer_label = []
-        with open(self.refer_file, 'r') as fr:
+        fr = io.open(self.refer_file, 'r', encoding="utf8")
        for line in fr:
            label = line.rstrip('\n').split('\t')[0]
            refer_label.append(label)
        idx = 0
-        with open(self.pred_file, 'r') as fr:
+        fr = io.open(self.pred_file, 'r', encoding="utf8")
        for line in fr:
            elems = line.rstrip('\n').split('\t')
            if len(elems) != 2 or not elems[0].isdigit():
                continue
            match_prob = elems[1]
            data.append((float(match_prob), int(refer_label[idx])))
            idx += 1
        return data

    def get_p_at_n_in_m(self, data, n, m, ind):
@@ -281,17 +283,17 @@ class EvalDSTC2(object):
        """
        pred_label = []
        refer_label = []
-        with open(self.refer_file, 'r') as fr:
+        fr = io.open(self.refer_file, 'r', encoding="utf8")
        for line in fr:
            line = line.strip('\n')
            labels = [int(l) for l in line.split('\t')[-1].split()]
            labels = sorted(list(set(labels)))
            refer_label.append(" ".join([str(l) for l in labels]))
        all_pred = []
-        with open(self.pred_file, 'r') as fr:
+        fr = io.open(self.pred_file, 'r', encoding="utf8")
        for line in fr:
            line = line.strip('\n')
            all_pred.append(line)
        all_pred = all_pred[len(all_pred) - len(refer_label):]
        for line in all_pred:
            labels = [int(l) for l in line.split('\t')[-1].split()]
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tarfile
import shutil
import urllib
import sys
import io
import os

URLLIB = urllib
if sys.version_info >= (3, 0):
    URLLIB = urllib.request

DATA_MODEL_PATH = {"DATA_PATH": "https://baidu-nlp.bj.bcebos.com/dmtk_data_1.0.0.tar.gz",
                   "PRETRAIN_MODEL": "https://bert-models.bj.bcebos.com/uncased_L-12_H-768_A-12.tar.gz",
                   "TRAINED_MODEL": "https://baidu-nlp.bj.bcebos.com/dgu_models_2.0.0.tar.gz"}

PATH_MAP = {'DATA_PATH': "./data/input",
            'PRETRAIN_MODEL': './data/pretrain_model',
            'TRAINED_MODEL': './data/saved_models'}


def un_tar(tar_name, dir_name):
    try:
        t = tarfile.open(tar_name)
        t.extractall(path=dir_name)
        return True
    except Exception as e:
        print(e)
        return False


def download_model_and_data():
    print("Downloading dgu data, pretrain model and trained models......")
    print("This process is quite long, please wait patiently............")
    for path in ['./data/input/data', './data/pretrain_model/uncased_L-12_H-768_A-12', './data/saved_models/trained_models']:
        if not os.path.exists(path):
            continue
        shutil.rmtree(path)
    for path_key in DATA_MODEL_PATH:
        filename = os.path.basename(DATA_MODEL_PATH[path_key])
        URLLIB.urlretrieve(DATA_MODEL_PATH[path_key], os.path.join("./", filename))
        state = un_tar(filename, PATH_MAP[path_key])
        if not state:
            print("Tar %s error....." % path_key)
            return False
        os.remove(filename)
    return True


if __name__ == "__main__":
    state = download_model_and_data()
    if not state:
        exit(1)
    print("Downloading data and models success......")
#!/bin/bash
#check data directory
cd ..
echo "Start download data and models.............."
if [ ! -d "data" ]; then
echo "Directory data does not exist, make new data directory"
mkdir data
fi
cd data
#check configure file
if [ ! -d "config" ]; then
echo "config directory not exist........"
exit 255
else
if [ ! -f "config/dgu.yaml" ]; then
echo "config file dgu.yaml has been lost........"
exit 255
fi
fi
#check and download input data
if [ ! -d "input" ]; then
echo "Directory input does not exist, make new input directory"
mkdir input
fi
cd input
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/dmtk_data_1.0.0.tar.gz
tar -xvf dmtk_data_1.0.0.tar.gz
rm dmtk_data_1.0.0.tar.gz
cd ..
#check and download pretrain model
if [ ! -d "pretrain_model" ]; then
echo "Directory pretrain_model does not exist, make new pretrain_model directory"
mkdir pretrain_model
fi
cd pretrain_model
wget --no-check-certificate https://bert-models.bj.bcebos.com/uncased_L-12_H-768_A-12.tar.gz
tar -xvf uncased_L-12_H-768_A-12.tar.gz
rm uncased_L-12_H-768_A-12.tar.gz
cd ..
#check and download inferenece model
if [ ! -d "inference_models" ]; then
echo "Directory inferenece_model does not exist, make new inferenece_model directory"
mkdir inference_models
fi
#check output
if [ ! -d "output" ]; then
echo "Directory output does not exist, make new output directory"
mkdir output
fi
#check saved model
if [ ! -d "saved_models" ]; then
echo "Directory saved_models does not exist, make new saved_models directory"
mkdir saved_models
fi
cd saved_models
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/dgu_models_2.0.0.tar.gz
tar -xvf dgu_models_2.0.0.tar.gz
rm dgu_models_2.0.0.tar.gz
cd ..
echo "Finish.............."
+# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +14,7 @@
# limitations under the License.
"""data reader"""
import os
+import io
import csv
import sys
import types
@@ -107,12 +109,12 @@ class DataProcessor(object):
    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file."""
-        with open(input_file, "r") as f:
+        f = io.open(input_file, "r", encoding="utf8")
        reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
        lines = []
        for line in reader:
            lines.append(line)
        return lines

    def get_num_examples(self, phase):
        """Get number of examples for train, dev or test."""
......
scripts: directory of data-processing scripts that convert the official public datasets into the training-data format required by the models
Run command:
-sh run_build_data.sh [udc|swda|mrda|atis|dstc2]
+python run_build_data.py [udc|swda|mrda|atis|dstc2]
1) To generate the train/dev/test sets for the MATCHING task:
-sh run_build_data.sh udc
+python run_build_data.py udc
The generated data is in dialogue_general_understanding/data/input/data/udc
2) To generate the train/dev/test sets for the DA task:
-sh run_build_data.sh swda
+python run_build_data.py swda
-sh run_build_data.sh mrda
+python run_build_data.py mrda
The generated data is in dialogue_general_understanding/data/input/data/swda and dialogue_general_understanding/data/input/data/mrda respectively
3) To generate the train/dev/test sets for the DST task:
-sh run_build_data.sh dstc2
+python run_build_data.py dstc2
The generated data is in dialogue_general_understanding/data/input/data/dstc2
4) To generate the train/dev/test sets for the intent-detection and slot-filling tasks:
-sh run_build_data.sh atis
+python run_build_data.py atis
Slot-filling data is generated in dialogue_general_understanding/data/input/data/atis/atis_slot
Intent-detection data is generated in dialogue_general_understanding/data/input/data/atis/atis_intent
......
+# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,6 +19,7 @@ import json
import sys
import csv
import os
+import io
import re
@@ -51,8 +53,8 @@ class ATIS(object):
            os.makedirs(self.out_intent_dir)
        src_examples = []
        json_file = os.path.join(self.src_dir, "%s.json" % data_type)
-        with open(json_file, 'r') as load_f:
+        load_f = io.open(json_file, 'r', encoding="utf8")
        json_dict = json.load(load_f)
        examples = json_dict['rasa_nlu_data']['common_examples']
        for example in examples:
            text = example.get('text')
@@ -66,62 +68,62 @@
        parser intent dataset
        """
        out_filename = "%s/%s.txt" % (self.out_intent_dir, data_type)
-        with open(out_filename, 'w') as fw:
+        fw = io.open(out_filename, 'w', encoding="utf8")
        for example in examples:
            if example[1] not in self.intent_dict:
                self.intent_dict[example[1]] = self.intent_id
                self.intent_id += 1
            fw.write("%s\t%s\n" % (self.intent_dict[example[1]], example[0].lower()))
-        with open(self.map_tag_intent, 'w') as fw:
+        fw = io.open(self.map_tag_intent, 'w', encoding="utf8")
        for tag in self.intent_dict:
            fw.write("%s\t%s\n" % (tag, self.intent_dict[tag]))

    def _parser_slot_data(self, examples, data_type):
        """
        parser slot dataset
        """
        out_filename = "%s/%s.txt" % (self.out_slot_dir, data_type)
-        with open(out_filename, 'w') as fw:
+        fw = io.open(out_filename, 'w', encoding="utf8")
        for example in examples:
            tags = []
            text = example[0]
            entities = example[2]
            if not entities:
                tags = [str(self.slot_dict['O'])] * len(text.strip().split())
                continue
            for i in range(len(entities)):
                enty = entities[i]
                start = enty['start']
                value_num = len(enty['value'].split())
                tags_slot = []
                for j in range(value_num):
                    if j == 0:
                        bround_tag = "B"
                    else:
                        bround_tag = "I"
                    tag = "%s-%s" % (bround_tag, enty['entity'])
                    if tag not in self.slot_dict:
                        self.slot_dict[tag] = self.slot_id
                        self.slot_id += 1
                    tags_slot.append(str(self.slot_dict[tag]))
                if i == 0:
                    if start not in [0, 1]:
                        prefix_num = len(text[: start].strip().split())
                        tags.extend([str(self.slot_dict['O'])] * prefix_num)
                    tags.extend(tags_slot)
                else:
                    prefix_num = len(text[entities[i - 1]['end']: start].strip().split())
                    tags.extend([str(self.slot_dict['O'])] * prefix_num)
                    tags.extend(tags_slot)
            if entities[-1]['end'] < len(text):
                suffix_num = len(text[entities[-1]['end']:].strip().split())
                tags.extend([str(self.slot_dict['O'])] * suffix_num)
            fw.write("%s\t%s\n" % (text.encode('utf8'), " ".join(tags).encode('utf8')))
-        with open(self.map_tag_slot, 'w') as fw:
+        fw = io.open(self.map_tag_slot, 'w', encoding="utf8")
        for slot in self.slot_dict:
            fw.write("%s\t%s\n" % (slot, self.slot_dict[slot]))

    def get_train_dataset(self):
        """
......
+# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,6 +18,7 @@ import json
import sys
import csv
import os
+import io
import re

import commonlib
@@ -55,17 +57,17 @@ class DSTC2(object):
        """
        tag_id = 1
        self.map_tag_dict['none'] = 0
-        with open(self.onto_json, 'r') as fr:
+        fr = io.open(self.onto_json, 'r', encoding="utf8")
        ontology = json.load(fr)
        slots_values = ontology['informable']
        for slot in slots_values:
            for value in slots_values[slot]:
                key = "%s_%s" % (slot, value)
                self.map_tag_dict[key] = tag_id
                tag_id += 1
            key = "%s_none" % (slot)
            self.map_tag_dict[key] = tag_id
            tag_id += 1

    def _parser_dataset(self, data_type):
        """
@@ -79,31 +81,33 @@
            os.makedirs(self.out_asr_dir)
        out_file = os.path.join(self.out_dir, "%s.txt" % data_type)
        out_asr_file = os.path.join(self.out_asr_dir, "%s.txt" % data_type)
-        with open(out_file, 'w') as fw, open(out_asr_file, 'w') as fw_asr:
+        fw = io.open(out_file, 'w', encoding="utf8")
+        fw_asr = io.open(out_asr_file, 'w', encoding="utf8")
        data_list = self.data_dict.get(data_type)
        for fn in data_list:
            log_file = os.path.join(fn, "log.json")
            label_file = os.path.join(fn, "label.json")
-            with open(log_file, 'r') as f_log, open(label_file, 'r') as f_label:
+            f_log = io.open(log_file, 'r', encoding="utf8")
+            f_label = io.open(label_file, 'r', encoding="utf8")
            log_json = json.load(f_log)
            label_json = json.load(f_label)
            session_id = log_json['session-id']
            assert len(label_json["turns"]) == len(log_json["turns"])
            for i in range(len(label_json["turns"])):
                log_turn = log_json["turns"][i]
                label_turn = label_json["turns"][i]
                assert log_turn["turn-index"] == label_turn["turn-index"]
                labels = ["%s_%s" % (slot, label_turn["goal-labels"][slot]) for slot in label_turn["goal-labels"]]
                labels_ids = " ".join([str(self.map_tag_dict.get(label, self.map_tag_dict["%s_none" % label.split('_')[0]])) for label in labels])
                mach = log_turn['output']['transcript']
                user = label_turn['transcription']
                if not labels_ids.strip():
                    labels_ids = self.map_tag_dict['none']
                out = "%s\t%s\1%s\t%s" % (session_id, mach, user, labels_ids)
                user_asr = log_turn['input']['live']['asr-hyps'][0]['asr-hyp'].strip()
                out_asr = "%s\t%s\1%s\t%s" % (session_id, mach, user_asr, labels_ids)
                fw.write("%s\n" % out.encode('utf8'))
                fw_asr.write("%s\n" % out_asr.encode('utf8'))

    def get_train_dataset(self):
        """
@@ -127,9 +131,9 @@ class DSTC2(object):
        """
        get tag and map ids file
        """
-        with open(self.map_tag, 'w') as fw:
+        fw = io.open(self.map_tag, 'w', encoding="utf8")
        for elem in self.map_tag_dict:
            fw.write("%s\t%s\n" % (elem, self.map_tag_dict[elem]))

    def main(self):
        """
......
+# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,6 +17,7 @@
import sys
import csv
import os
+import io
import re

import commonlib
@@ -64,18 +66,18 @@ class MRDA(object):
        dadb_list = self.data_dict[data_type]
        for dadb_key in dadb_list:
            dadb_file = self.dadb_dict[dadb_key]
-            with open(dadb_file, 'r') as fr:
+            fr = io.open(dadb_file, 'r', encoding="utf8")
            row = csv.reader(fr, delimiter = ',')
            for line in row:
                elems = line
                conv_id = elems[2]
                conv_id_list.append(conv_id)
                if len(elems) != 14:
                    continue
                error_code = elems[3]
                da_tag = elems[-9]
                da_ori_tag = elems[-6]
                dadb_dict[conv_id] = (error_code, da_ori_tag, da_tag)
        return dadb_dict, conv_id_list

    def load_trans(self, data_type):
@@ -84,16 +86,16 @@ class MRDA(object):
        trans_list = self.data_dict[data_type]
        for trans_key in trans_list:
            trans_file = self.trans_dict[trans_key]
-            with open(trans_file, 'r') as fr:
+            fr = io.open(trans_file, 'r', encoding="utf8")
            row = csv.reader(fr, delimiter = ',')
            for line in row:
                elems = line
                if len(elems) != 3:
                    continue
                conv_id = elems[0]
                text = elems[1]
                text_process = elems[2]
                trans_dict[conv_id] = (text, text_process)
        return trans_dict

    def _parser_dataset(self, data_type):
@@ -103,23 +105,23 @@ class MRDA(object):
        out_filename = "%s/%s.txt" % (self.out_dir, data_type)
        dadb_dict, conv_id_list = self.load_dadb(data_type)
        trans_dict = self.load_trans(data_type)
-        with open(out_filename, 'w') as fw:
+        fw = io.open(out_filename, 'w', encoding="utf8")
        for elem in conv_id_list:
            v_dadb = dadb_dict[elem]
            v_trans = trans_dict[elem]
            da_tag = v_dadb[2]
            if da_tag not in self.tag_dict:
                continue
            tag = self.tag_dict[da_tag]
            if tag == "Z":
                continue
            if tag not in self.map_tag_dict:
                self.map_tag_dict[tag] = self.tag_id
                self.tag_id += 1
            caller = elem.split('_')[0].split('-')[-1]
            conv_no = elem.split('_')[0].split('-')[0]
            out = "%s\t%s\t%s\t%s" % (conv_no, self.map_tag_dict[tag], caller, v_trans[0])
            fw.write("%s\n" % out)

    def get_train_dataset(self):
        """
@@ -143,9 +145,9 @@ class MRDA(object):
        """
        get tag and map ids file
        """
-        with open(self.map_tag, 'w') as fw:
+        fw = io.open(self.map_tag, 'w', encoding="utf8")
        for elem in self.map_tag_dict:
            fw.write("%s\t%s\n" % (elem, self.map_tag_dict[elem]))

    def main(self):
        """
......
+# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,6 +17,7 @@
import sys
import csv
import os
+import io
import re

import commonlib
@@ -56,18 +58,18 @@ class SWDA(object):
        parser train dev test dataset
        """
        out_filename = "%s/%s.txt" % (self.out_dir, data_type)
-        with open(out_filename, 'w') as fw:
+        fw = io.open(out_filename, 'w', encoding='utf8')
        for name in self.data_dict[data_type]:
            file_path = self.file_dict[name]
-            with open(file_path, 'r') as fr:
+            fr = io.open(file_path, 'r', encoding="utf8")
            idx = 0
            row = csv.reader(fr, delimiter = ',')
            for r in row:
                if idx == 0:
                    idx += 1
                    continue
                out = self._parser_utterence(r)
                fw.write("%s\n" % out)

    def _clean_text(self, text):
        """
@@ -209,9 +211,9 @@ class SWDA(object):
        """
        get tag and map ids file
        """
-        with open(self.map_tag, 'w') as fw:
+        fw = io.open(self.map_tag, 'w', encoding='utf8')
        for elem in self.map_tag_dict:
            fw.write("%s\t%s\n" % (elem, self.map_tag_dict[elem]))

    def main(self):
        """
......
+# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +14,7 @@
# limitations under the License.
"""common function"""
import sys
+import io
import os
@@ -48,13 +50,13 @@ def load_dict(conf):
    load swda dataset config
    """
    conf_dict = dict()
-    with open(conf, 'r') as fr:
+    fr = io.open(conf, 'r', encoding="utf8")
    for line in fr:
        line = line.strip()
        elems = line.split('\t')
        if elems[0] not in conf_dict:
            conf_dict[elems[0]] = []
        conf_dict[elems[0]].append(elems[1])
    return conf_dict
@@ -63,11 +65,11 @@ def load_voc(conf):
    load map dict
    """
    map_dict = {}
-    with open(conf, 'r') as fr:
+    fr = io.open(conf, 'r', encoding="utf8")
    for line in fr:
        line = line.strip()
        elems = line.split('\t')
        map_dict[elems[0]] = elems[1]
    return map_dict
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import sys
import os

from build_atis_dataset import ATIS
from build_dstc2_dataset import DSTC2
from build_mrda_dataset import MRDA
from build_swda_dataset import SWDA


if __name__ == "__main__":
    task_name = sys.argv[1]
    task_name = task_name.lower()
    if task_name not in ['swda', 'mrda', 'atis', 'dstc2', 'udc']:
        print("task name error: we support [swda|mrda|atis|dstc2|udc]")
        exit(1)

    if task_name == 'swda':
        swda_inst = SWDA()
        swda_inst.main()
    elif task_name == 'mrda':
        mrda_inst = MRDA()
        mrda_inst.main()
    elif task_name == 'atis':
        atis_inst = ATIS()
        atis_inst.main()
        shutil.copyfile("../../data/input/data/atis/atis_slot/test.txt", "../../data/input/data/atis/atis_slot/dev.txt")
        shutil.copyfile("../../data/input/data/atis/atis_intent/test.txt", "../../data/input/data/atis/atis_intent/dev.txt")
    elif task_name == 'dstc2':
        dstc_inst = DSTC2()
        dstc_inst.main()
    else:
        exit(0)
#!/bin/bash
TASK_DATA=$1
typeset -l TASK_DATA
if [ "${TASK_DATA}" = "udc" ]
then
exit 0
elif [ "${TASK_DATA}" = "swda" ]
then
python build_swda_dataset.py
elif [ "${TASK_DATA}" = "mrda" ]
then
python build_mrda_dataset.py
elif [[ "${TASK_DATA}" =~ "atis" ]]
then
python build_atis_dataset.py
cat ../../data/input/data/atis/atis_slot/test.txt > ../../data/input/data/atis/atis_slot/dev.txt
cat ../../data/input/data/atis/atis_intent/test.txt > ../../data/input/data/atis/atis_intent/dev.txt
elif [ "${TASK_DATA}" = "dstc2" ]
then
python build_dstc2_dataset.py
else
echo "can not support $TASK_DATA , please choose [swda|mrda|atis|dstc2|multi-woz]"
fi
+# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -21,6 +22,7 @@ from __future__ import print_function
import collections
import unicodedata
import six
+import io


def convert_to_unicode(text):
@@ -69,7 +71,7 @@ def printable_text(text):
def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
-    fin = open(vocab_file)
+    fin = io.open(vocab_file, 'r', encoding="utf8")
    for num, line in enumerate(fin):
        items = convert_to_unicode(line.strip()).split("\t")
        if len(items) > 2:
......
+# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -16,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

+import io
import os
import sys
import argparse
@@ -38,8 +40,8 @@ class JsonConfig(object):

    def _parse(self, config_path):
        try:
-            with open(config_path) as json_file:
+            json_file = io.open(config_path, 'r', encoding="utf8")
            config_dict = json.load(json_file)
        except:
            raise IOError("Error in parsing bert model config file '%s'" %
                          config_path)
@@ -212,9 +214,9 @@ class PDConfig(object):
            raise Warning("the json file %s does not exist." % file_path)
            return

-        with open(file_path, "r") as fin:
+        fin = io.open(file_path, "r", encoding="utf8")
        self.json_config = json.loads(fin.read())
        fin.close()

        if fuse_args:
            for name in self.json_config:
@@ -236,9 +238,9 @@ class PDConfig(object):
            raise Warning("the yaml file %s does not exist." % file_path)
            return

-        with open(file_path, "r") as fin:
+        fin = io.open(file_path, "r", encoding="utf8")
        self.yaml_config = yaml.load(fin, Loader=yaml.SafeLoader)
        fin.close()

        if fuse_args:
            for name in self.yaml_config:
......
+# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import io
import os
import sys

import numpy as np
@@ -142,15 +144,16 @@ def do_predict(args):
    np.set_printoptions(precision=4, suppress=True)

    print("Write the predicted results into the output_prediction_file")
-    with open(args.output_prediction_file, 'w') as fw:
+    fw = io.open(args.output_prediction_file, 'w', encoding="utf8")
    if task_name not in ['atis_slot']:
        for index, result in enumerate(all_results):
            tags = pred_func(result)
            fw.write("%s\t%s\n" % (index, tags))
    else:
        tags = pred_func(all_results, args.max_seq_len)
        for index, tag in enumerate(tags):
            fw.write("%s\t%s\n" % (index, tag))


if __name__ == "__main__":
......
@@ -21,7 +21,6 @@ import os
import sys
import time
import numpy as np
-import multiprocessing
import paddle
import paddle.fluid as fluid
@@ -111,8 +110,7 @@ def do_train(args):
    if args.use_cuda:
        dev_count = fluid.core.get_cuda_device_count()
    else:
-        dev_count = int(
-            os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
+        dev_count = int(os.environ.get('CPU_NUM', 1))

    batch_generator = processor.data_generator(
        batch_size=args.batch_size,
......
@@ -44,10 +44,10 @@ In our MTL experiments, we use BERT as our shared encoder. The parameters are in
```
1、cd scripts
-2、download cased_model_01.tar.gz from link
+2、# download cased_model_01.tar.gz from link
3、mkdir cased_model_01 && mv cased_model_01.tar.gz cased_model_01 && cd cased_model_01 && tar -xvf cased_model_01.tar.gz && cd ..
4、python convert_model_params.py --init_tf_checkpoint cased_model_01/model.ckpt --fluid_params_dir params
-5、mkdir fluid_models && mv cased_model_01/vocab.txt cased_model_01/bert_config.json params fluid_models
+5、mkdir squad2_model && mv cased_model_01/vocab.txt cased_model_01/bert_config.json params squad2_model
```
Alternatively, user can directly **download the parameters that we have converted**:
......
@@ -21,8 +21,8 @@ import json
import random
import collections
import numpy as np
-import tokenization
-from batching import prepare_batch_data
+from task_reader import tokenization
+from task_reader.batching import prepare_batch_data


class MRQAExample(object):
......