未验证 提交 d7f8866b 编写于 作者: 0 0YuanZhang0 提交者: GitHub

Fix windows (#3646)

* fix_dgu_ade_dnet

* fix_readme
上级 e2e782c5
......@@ -82,7 +82,7 @@ label_data(第二阶段finetuning数据集)
    数据集、相关模型下载:
```
cd ade && bash prepare_data_and_model.sh
python ade/prepare_data_and_model.py
```
    数据路径:data/input/data/
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tarfile
import shutil
import urllib
import sys
import io
import os
URLLIB=urllib
if sys.version_info >= (3, 0):
URLLIB=urllib.request
DATA_MODEL_PATH = {"DATA_PATH": "https://baidu-nlp.bj.bcebos.com/auto_dialogue_evaluation_dataset-1.0.0.tar.gz",
"TRAINED_MODEL": "https://baidu-nlp.bj.bcebos.com/auto_dialogue_evaluation_models.2.0.0.tar.gz"}
PATH_MAP = {'DATA_PATH': "./data/input",
'TRAINED_MODEL': './data/saved_models'}
def un_tar(tar_name, dir_name):
try:
t = tarfile.open(tar_name)
t.extractall(path = dir_name)
return True
except Exception as e:
print(e)
return False
def download_model_and_data():
print("Downloading ade data, pretrain model and trained models......")
print("This process is quite long, please wait patiently............")
for path in ['./data/input/data', './data/saved_models/trained_models']:
if not os.path.exists(path):
continue
shutil.rmtree(path)
for path_key in DATA_MODEL_PATH:
filename = os.path.basename(DATA_MODEL_PATH[path_key])
URLLIB.urlretrieve(DATA_MODEL_PATH[path_key], os.path.join("./", filename))
state = un_tar(filename, PATH_MAP[path_key])
if not state:
print("Tar %s error....." % path_key)
return False
os.remove(filename)
return True
if __name__ == "__main__":
state = download_model_and_data()
if not state:
exit(1)
print("Downloading data and models sucess......")
#!/bin/bash
#check data directory
cd ..
echo "Start download data and models.............."
if [ ! -d "data" ]; then
echo "Directory data does not exist, make new data directory"
mkdir data
fi
cd data
#check configure file
if [ ! -d "config" ]; then
echo "config directory not exist........"
exit 255
else
if [ ! -f "config/ade.yaml" ]; then
echo "config file dgu.yaml has been lost........"
exit 255
fi
fi
#check and download input data
if [ ! -d "input" ]; then
echo "Directory input does not exist, make new input directory"
mkdir input
fi
cd input
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/auto_dialogue_evaluation_dataset-1.0.0.tar.gz
tar -zxvf auto_dialogue_evaluation_dataset-1.0.0.tar.gz
rm auto_dialogue_evaluation_dataset-1.0.0.tar.gz
cd ..
#check and download pretrain model
if [ ! -d "pretrain_model" ]; then
echo "Directory pretrain_model does not exist, make new pretrain_model directory"
mkdir pretrain_model
fi
#check and download inferenece model
if [ ! -d "inference_models" ]; then
echo "Directory inferenece_model does not exist, make new inferenece_model directory"
mkdir inference_models
fi
#check output
if [ ! -d "output" ]; then
echo "Directory output does not exist, make new output directory"
mkdir output
fi
#check saved model
if [ ! -d "saved_models" ]; then
echo "Directory saved_models does not exist, make new saved_models directory"
mkdir saved_models
fi
cd saved_models
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/auto_dialogue_evaluation_models.2.0.0.tar.gz
tar -xvf auto_dialogue_evaluation_models.2.0.0.tar.gz
rm auto_dialogue_evaluation_models.2.0.0.tar.gz
echo "Finish.............."
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -13,6 +14,7 @@
# limitations under the License.
"""Reader for auto dialogue evaluation"""
import io
import sys
import time
import random
......@@ -34,7 +36,7 @@ class DataProcessor(object):
"""load examples"""
examples = []
index = 0
with open(self.data_file, 'r') as fr:
fr = io.open(self.data_file, 'r', encoding="utf8")
for line in fr:
if index !=0 and index % 100 == 0:
print("processing data: %d" % index)
......@@ -47,7 +49,7 @@ class DataProcessor(object):
if phase not in ['train', 'dev', 'test']:
raise ValueError(
"Unknown phase, which should be in ['train', 'dev', 'test'].")
count = len(open(self.data_file,'rU').readlines())
count = len(io.open(self.data_file, 'r', encoding="utf8").readlines())
self.num_examples[phase] = count
return self.num_examples[phase]
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,6 +18,7 @@ from __future__ import division
from __future__ import print_function
import os
import io
import sys
import argparse
import json
......@@ -38,7 +40,7 @@ class JsonConfig(object):
def _parse(self, config_path):
try:
with open(config_path) as json_file:
json_file = io.open(config_path, 'r', encoding="utf8")
config_dict = json.load(json_file)
except:
raise IOError("Error in parsing bert model config file '%s'" %
......@@ -214,7 +216,7 @@ class PDConfig(object):
raise Warning("the json file %s does not exist." % file_path)
return
with open(file_path, "r") as fin:
fin = io.open(file_path, "r", encoding="utf8")
self.json_config = json.loads(fin.read())
fin.close()
......@@ -238,7 +240,7 @@ class PDConfig(object):
raise Warning("the yaml file %s does not exist." % file_path)
return
with open(file_path, "r") as fin:
fin = io.open(file_path, "r", encoding="utf8")
self.yaml_config = yaml.load(fin, Loader=yaml.SafeLoader)
fin.close()
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -13,6 +14,7 @@
# limitations under the License.
"""evaluation metrics"""
import io
import os
import sys
import numpy as np
......@@ -24,7 +26,7 @@ from ade.utils.configure import PDConfig
def do_eval(args):
"""evaluate metrics"""
labels = []
with open(args.evaluation_file, 'r') as fr:
fr = io.open(args.evaluation_file, 'r', encoding="utf8")
for line in fr:
tokens = line.strip().split('\t')
assert len(tokens) == 3
......@@ -32,7 +34,7 @@ def do_eval(args):
labels.append(label)
scores = []
with open(args.output_prediction_file, 'r') as fr:
fr = io.open(args.output_prediction_file, 'r', encoding="utf8")
for line in fr:
tokens = line.strip().split('\t')
assert len(tokens) == 2
......
......@@ -18,7 +18,6 @@ import sys
import six
import numpy as np
import time
import multiprocessing
import paddle
import paddle.fluid as fluid
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -12,13 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""predict auto dialogue evaluation task"""
import io
import os
import sys
import six
import time
import numpy as np
import multiprocessing
import paddle
import paddle.fluid as fluid
......@@ -109,7 +109,7 @@ def do_predict(args):
scores = scores[: num_test_examples]
print("Write the predicted results into the output_prediction_file")
with open(args.output_prediction_file, 'w') as fw:
fw = io.open(args.output_prediction_file, 'w', encoding="utf8")
for index, score in enumerate(scores):
fw.write("%s\t%s\n" % (index, score))
print("finish........................................")
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -12,13 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""train auto dialogue evaluation task"""
import io
import os
import sys
import six
import time
import numpy as np
import multiprocessing
import paddle
import paddle.fluid as fluid
......@@ -76,8 +76,7 @@ def do_train(args):
dev_count = fluid.core.get_cuda_device_count()
place = fluid.CUDAPlace(int(os.getenv('FLAGS_selected_gpus', '0')))
else:
dev_count = int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
dev_count = int(os.environ.get('CPU_NUM', 1))
place = fluid.CPUPlace()
processor = reader.DataProcessor(
......@@ -115,9 +114,9 @@ def do_train(args):
if args.word_emb_init:
print("start loading word embedding init ...")
if six.PY2:
word_emb = np.array(pickle.load(open(args.word_emb_init, 'rb'))).astype('float32')
word_emb = np.array(pickle.load(io.open(args.word_emb_init, 'rb'))).astype('float32')
else:
word_emb = np.array(pickle.load(open(args.word_emb_init, 'rb'), encoding="bytes")).astype('float32')
word_emb = np.array(pickle.load(io.open(args.word_emb_init, 'rb'), encoding="bytes")).astype('float32')
set_word_embedding(word_emb, place)
print("finish init word embedding ...")
......
......@@ -63,7 +63,7 @@ SWDA:Switchboard Dialogue Act Corpus;
    数据集、相关模型下载:
```
cd dgu && bash prepare_data_and_model.sh
python dgu/prepare_data_and_model.py
```
    数据路径:data/input/data
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -18,6 +19,7 @@ from __future__ import division
from __future__ import print_function
import os
import io
import sys
import six
import json
......@@ -33,7 +35,7 @@ class BertConfig(object):
def _parse(self, config_path):
try:
with open(config_path) as json_file:
json_file = io.open(config_path, 'r', encoding="utf8")
config_dict = json.load(json_file)
except Exception:
raise IOError("Error in parsing bert model config file '%s'" %
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -14,6 +15,7 @@
"""evaluate task metrics"""
import sys
import io
class EvalDA(object):
......@@ -33,12 +35,12 @@ class EvalDA(object):
"""
pred_label = []
refer_label = []
with open(self.refer_file, 'r') as fr:
fr = io.open(self.refer_file, 'r', encoding="utf8")
for line in fr:
label = line.rstrip('\n').split('\t')[1]
refer_label.append(int(label))
idx = 0
with open(self.pred_file, 'r') as fr:
fr = io.open(self.pred_file, 'r', encoding="utf8")
for line in fr:
elems = line.rstrip('\n').split('\t')
if len(elems) != 2 or not elems[0].isdigit():
......@@ -78,12 +80,12 @@ class EvalATISIntent(object):
"""
pred_label = []
refer_label = []
with open(self.refer_file, 'r') as fr:
fr = io.open(self.refer_file, 'r', encoding="utf8")
for line in fr:
label = line.rstrip('\n').split('\t')[0]
refer_label.append(int(label))
idx = 0
with open(self.pred_file, 'r') as fr:
fr = io.open(self.pred_file, 'r', encoding="utf8")
for line in fr:
elems = line.rstrip('\n').split('\t')
if len(elems) != 2 or not elems[0].isdigit():
......@@ -123,12 +125,12 @@ class EvalATISSlot(object):
"""
pred_label = []
refer_label = []
with open(self.refer_file, 'r') as fr:
fr = io.open(self.refer_file, 'r', encoding="utf8")
for line in fr:
labels = line.rstrip('\n').split('\t')[1].split()
labels = [int(l) for l in labels]
refer_label.append(labels)
with open(self.pred_file, 'r') as fr:
fr = io.open(self.pred_file, 'r', encoding="utf8")
for line in fr:
if len(line.split('\t')) != 2 or not line[0].isdigit():
continue
......@@ -208,12 +210,12 @@ class EvalUDC(object):
"""
data = []
refer_label = []
with open(self.refer_file, 'r') as fr:
fr = io.open(self.refer_file, 'r', encoding="utf8")
for line in fr:
label = line.rstrip('\n').split('\t')[0]
refer_label.append(label)
idx = 0
with open(self.pred_file, 'r') as fr:
fr = io.open(self.pred_file, 'r', encoding="utf8")
for line in fr:
elems = line.rstrip('\n').split('\t')
if len(elems) != 2 or not elems[0].isdigit():
......@@ -281,14 +283,14 @@ class EvalDSTC2(object):
"""
pred_label = []
refer_label = []
with open(self.refer_file, 'r') as fr:
fr = io.open(self.refer_file, 'r', encoding="utf8")
for line in fr:
line = line.strip('\n')
labels = [int(l) for l in line.split('\t')[-1].split()]
labels = sorted(list(set(labels)))
refer_label.append(" ".join([str(l) for l in labels]))
all_pred = []
with open(self.pred_file, 'r') as fr:
fr = io.open(self.pred_file, 'r', encoding="utf8")
for line in fr:
line = line.strip('\n')
all_pred.append(line)
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tarfile
import shutil
import urllib
import sys
import io
import os
URLLIB=urllib
if sys.version_info >= (3, 0):
URLLIB=urllib.request
DATA_MODEL_PATH = {"DATA_PATH": "https://baidu-nlp.bj.bcebos.com/dmtk_data_1.0.0.tar.gz",
"PRETRAIN_MODEL": "https://bert-models.bj.bcebos.com/uncased_L-12_H-768_A-12.tar.gz",
"TRAINED_MODEL": "https://baidu-nlp.bj.bcebos.com/dgu_models_2.0.0.tar.gz"}
PATH_MAP = {'DATA_PATH': "./data/input",
'PRETRAIN_MODEL': './data/pretrain_model',
'TRAINED_MODEL': './data/saved_models'}
def un_tar(tar_name, dir_name):
try:
t = tarfile.open(tar_name)
t.extractall(path = dir_name)
return True
except Exception as e:
print(e)
return False
def download_model_and_data():
print("Downloading dgu data, pretrain model and trained models......")
print("This process is quite long, please wait patiently............")
for path in ['./data/input/data', './data/pretrain_model/uncased_L-12_H-768_A-12', './data/saved_models/trained_models']:
if not os.path.exists(path):
continue
shutil.rmtree(path)
for path_key in DATA_MODEL_PATH:
filename = os.path.basename(DATA_MODEL_PATH[path_key])
URLLIB.urlretrieve(DATA_MODEL_PATH[path_key], os.path.join("./", filename))
state = un_tar(filename, PATH_MAP[path_key])
if not state:
print("Tar %s error....." % path_key)
return False
os.remove(filename)
return True
if __name__ == "__main__":
state = download_model_and_data()
if not state:
exit(1)
print("Downloading data and models sucess......")
#!/bin/bash
#check data directory
cd ..
echo "Start download data and models.............."
if [ ! -d "data" ]; then
echo "Directory data does not exist, make new data directory"
mkdir data
fi
cd data
#check configure file
if [ ! -d "config" ]; then
echo "config directory not exist........"
exit 255
else
if [ ! -f "config/dgu.yaml" ]; then
echo "config file dgu.yaml has been lost........"
exit 255
fi
fi
#check and download input data
if [ ! -d "input" ]; then
echo "Directory input does not exist, make new input directory"
mkdir input
fi
cd input
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/dmtk_data_1.0.0.tar.gz
tar -xvf dmtk_data_1.0.0.tar.gz
rm dmtk_data_1.0.0.tar.gz
cd ..
#check and download pretrain model
if [ ! -d "pretrain_model" ]; then
echo "Directory pretrain_model does not exist, make new pretrain_model directory"
mkdir pretrain_model
fi
cd pretrain_model
wget --no-check-certificate https://bert-models.bj.bcebos.com/uncased_L-12_H-768_A-12.tar.gz
tar -xvf uncased_L-12_H-768_A-12.tar.gz
rm uncased_L-12_H-768_A-12.tar.gz
cd ..
#check and download inferenece model
if [ ! -d "inference_models" ]; then
echo "Directory inferenece_model does not exist, make new inferenece_model directory"
mkdir inference_models
fi
#check output
if [ ! -d "output" ]; then
echo "Directory output does not exist, make new output directory"
mkdir output
fi
#check saved model
if [ ! -d "saved_models" ]; then
echo "Directory saved_models does not exist, make new saved_models directory"
mkdir saved_models
fi
cd saved_models
wget --no-check-certificate https://baidu-nlp.bj.bcebos.com/dgu_models_2.0.0.tar.gz
tar -xvf dgu_models_2.0.0.tar.gz
rm dgu_models_2.0.0.tar.gz
cd ..
echo "Finish.............."
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -13,6 +14,7 @@
# limitations under the License.
"""data reader"""
import os
import io
import csv
import sys
import types
......@@ -107,7 +109,7 @@ class DataProcessor(object):
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
"""Reads a tab separated value file."""
with open(input_file, "r") as f:
f = io.open(input_file, "r", encoding="utf8")
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
lines = []
for line in reader:
......
scripts:运行数据处理脚本目录, 将官方公开数据集转换成模型所需训练数据格式
运行命令:
sh run_build_data.sh [udc|swda|mrda|atis|dstc2]
python run_build_data.py [udc|swda|mrda|atis|dstc2]
1)、生成MATCHING任务所需要的训练集、开发集、测试集时:
sh run_build_data.sh udc
python run_build_data.py udc
生成数据在dialogue_general_understanding/data/input/data/udc
2)、生成DA任务所需要的训练集、开发集、测试集时:
sh run_build_data.sh swda
sh run_build_data.sh mrda
python run_build_data.py swda
python run_build_data.py mrda
生成数据分别在dialogue_general_understanding/data/input/data/swda和dialogue_general_understanding/data/input/data/mrda
3)、生成DST任务所需的训练集、开发集、测试集时:
sh run_build_data.sh dstc2
python run_build_data.py dstc2
生成数据分别在dialogue_general_understanding/data/input/data/dstc2
4)、生成意图解析, 槽位识别任务所需训练集、开发集、测试集时:
sh run_build_data.sh atis
python run_build_data.py atis
生成槽位识别数据在dialogue_general_understanding/data/input/data/atis/atis_slot
生成意图识别数据在dialogue_general_understanding/data/input/data/atis/atis_intent
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -18,6 +19,7 @@ import json
import sys
import csv
import os
import io
import re
......@@ -51,7 +53,7 @@ class ATIS(object):
os.makedirs(self.out_intent_dir)
src_examples = []
json_file = os.path.join(self.src_dir, "%s.json" % data_type)
with open(json_file, 'r') as load_f:
load_f = io.open(json_file, 'r', encoding="utf8")
json_dict = json.load(load_f)
examples = json_dict['rasa_nlu_data']['common_examples']
for example in examples:
......@@ -66,14 +68,14 @@ class ATIS(object):
parser intent dataset
"""
out_filename = "%s/%s.txt" % (self.out_intent_dir, data_type)
with open(out_filename, 'w') as fw:
fw = io.open(out_filename, 'w', encoding="utf8")
for example in examples:
if example[1] not in self.intent_dict:
self.intent_dict[example[1]] = self.intent_id
self.intent_id += 1
fw.write("%s\t%s\n" % (self.intent_dict[example[1]], example[0].lower()))
with open(self.map_tag_intent, 'w') as fw:
fw = io.open(self.map_tag_intent, 'w', encoding="utf8")
for tag in self.intent_dict:
fw.write("%s\t%s\n" % (tag, self.intent_dict[tag]))
......@@ -82,7 +84,7 @@ class ATIS(object):
parser slot dataset
"""
out_filename = "%s/%s.txt" % (self.out_slot_dir, data_type)
with open(out_filename, 'w') as fw:
fw = io.open(out_filename, 'w', encoding="utf8")
for example in examples:
tags = []
text = example[0]
......@@ -119,7 +121,7 @@ class ATIS(object):
tags.extend([str(self.slot_dict['O'])] * suffix_num)
fw.write("%s\t%s\n" % (text.encode('utf8'), " ".join(tags).encode('utf8')))
with open(self.map_tag_slot, 'w') as fw:
fw = io.open(self.map_tag_slot, 'w', encoding="utf8")
for slot in self.slot_dict:
fw.write("%s\t%s\n" % (slot, self.slot_dict[slot]))
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -17,6 +18,7 @@ import json
import sys
import csv
import os
import io
import re
import commonlib
......@@ -55,7 +57,7 @@ class DSTC2(object):
"""
tag_id = 1
self.map_tag_dict['none'] = 0
with open(self.onto_json, 'r') as fr:
fr = io.open(self.onto_json, 'r', encoding="utf8")
ontology = json.load(fr)
slots_values = ontology['informable']
for slot in slots_values:
......@@ -79,12 +81,14 @@ class DSTC2(object):
os.makedirs(self.out_asr_dir)
out_file = os.path.join(self.out_dir, "%s.txt" % data_type)
out_asr_file = os.path.join(self.out_asr_dir, "%s.txt" % data_type)
with open(out_file, 'w') as fw, open(out_asr_file, 'w') as fw_asr:
fw = io.open(out_file, 'w', encoding="utf8")
fw_asr = io.open(out_asr_file, 'w', encoding="utf8")
data_list = self.data_dict.get(data_type)
for fn in data_list:
log_file = os.path.join(fn, "log.json")
label_file = os.path.join(fn, "label.json")
with open(log_file, 'r') as f_log, open(label_file, 'r') as f_label:
f_log = io.open(log_file, 'r', encoding="utf8")
f_label = io.open(label_file, 'r', encoding="utf8")
log_json = json.load(f_log)
label_json = json.load(f_label)
session_id = log_json['session-id']
......@@ -127,7 +131,7 @@ class DSTC2(object):
"""
get tag and map ids file
"""
with open(self.map_tag, 'w') as fw:
fw = io.open(self.map_tag, 'w', encoding="utf8")
for elem in self.map_tag_dict:
fw.write("%s\t%s\n" % (elem, self.map_tag_dict[elem]))
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -16,6 +17,7 @@
import sys
import csv
import os
import io
import re
import commonlib
......@@ -64,7 +66,7 @@ class MRDA(object):
dadb_list = self.data_dict[data_type]
for dadb_key in dadb_list:
dadb_file = self.dadb_dict[dadb_key]
with open(dadb_file, 'r') as fr:
fr = io.open(dadb_file, 'r', encoding="utf8")
row = csv.reader(fr, delimiter = ',')
for line in row:
elems = line
......@@ -84,7 +86,7 @@ class MRDA(object):
trans_list = self.data_dict[data_type]
for trans_key in trans_list:
trans_file = self.trans_dict[trans_key]
with open(trans_file, 'r') as fr:
fr = io.open(trans_file, 'r', encoding="utf8")
row = csv.reader(fr, delimiter = ',')
for line in row:
elems = line
......@@ -103,7 +105,7 @@ class MRDA(object):
out_filename = "%s/%s.txt" % (self.out_dir, data_type)
dadb_dict, conv_id_list = self.load_dadb(data_type)
trans_dict = self.load_trans(data_type)
with open(out_filename, 'w') as fw:
fw = io.open(out_filename, 'w', encoding="utf8")
for elem in conv_id_list:
v_dadb = dadb_dict[elem]
v_trans = trans_dict[elem]
......@@ -143,7 +145,7 @@ class MRDA(object):
"""
get tag and map ids file
"""
with open(self.map_tag, 'w') as fw:
fw = io.open(self.map_tag, 'w', encoding="utf8")
for elem in self.map_tag_dict:
fw.write("%s\t%s\n" % (elem, self.map_tag_dict[elem]))
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -16,6 +17,7 @@
import sys
import csv
import os
import io
import re
import commonlib
......@@ -56,10 +58,10 @@ class SWDA(object):
parser train dev test dataset
"""
out_filename = "%s/%s.txt" % (self.out_dir, data_type)
with open(out_filename, 'w') as fw:
fw = io.open(out_filename, 'w', encoding='utf8')
for name in self.data_dict[data_type]:
file_path = self.file_dict[name]
with open(file_path, 'r') as fr:
fr = io.open(file_path, 'r', encoding="utf8")
idx = 0
row = csv.reader(fr, delimiter = ',')
for r in row:
......@@ -209,7 +211,7 @@ class SWDA(object):
"""
get tag and map ids file
"""
with open(self.map_tag, 'w') as fw:
fw = io.open(self.map_tag, 'w', encoding='utf8')
for elem in self.map_tag_dict:
fw.write("%s\t%s\n" % (elem, self.map_tag_dict[elem]))
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -13,6 +14,7 @@
# limitations under the License.
"""common function"""
import sys
import io
import os
......@@ -48,7 +50,7 @@ def load_dict(conf):
load swda dataset config
"""
conf_dict = dict()
with open(conf, 'r') as fr:
fr = io.open(conf, 'r', encoding="utf8")
for line in fr:
line = line.strip()
elems = line.split('\t')
......@@ -63,7 +65,7 @@ def load_voc(conf):
load map dict
"""
map_dict = {}
with open(conf, 'r') as fr:
fr = io.open(conf, 'r', encoding="utf8")
for line in fr:
line = line.strip()
elems = line.split('\t')
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import shutil
import sys
import os
from build_atis_dataset import ATIS
from build_dstc2_dataset import DSTC2
from build_mrda_dataset import MRDA
from build_swda_dataset import SWDA
if __name__ == "__main__":
task_name = sys.argv[1]
task_name = task_name.lower()
if task_name not in ['swda', 'mrda', 'atis', 'dstc2', 'udc']:
print("task name error: we support [swda|mrda|atis|dstc2|udc]")
exit(1)
if task_name == 'swda':
swda_inst = SWDA()
swda_inst.main()
elif task_name == 'mrda':
mrda_inst = MRDA()
mrda_inst.main()
elif task_name == 'atis':
atis_inst = ATIS()
atis_inst.main()
shutil.copyfile("../../data/input/data/atis/atis_slot/test.txt", "../../data/input/data/atis/atis_slot/dev.txt")
shutil.copyfile("../../data/input/data/atis/atis_intent/test.txt", "../../data/input/data/atis/atis_intent/dev.txt")
elif task_name == 'dstc2':
dstc_inst = DSTC2()
dstc_inst.main()
else:
exit(0)
#!/bin/bash
TASK_DATA=$1
typeset -l TASK_DATA
if [ "${TASK_DATA}" = "udc" ]
then
exit 0
elif [ "${TASK_DATA}" = "swda" ]
then
python build_swda_dataset.py
elif [ "${TASK_DATA}" = "mrda" ]
then
python build_mrda_dataset.py
elif [[ "${TASK_DATA}" =~ "atis" ]]
then
python build_atis_dataset.py
cat ../../data/input/data/atis/atis_slot/test.txt > ../../data/input/data/atis/atis_slot/dev.txt
cat ../../data/input/data/atis/atis_intent/test.txt > ../../data/input/data/atis/atis_intent/dev.txt
elif [ "${TASK_DATA}" = "dstc2" ]
then
python build_dstc2_dataset.py
else
echo "can not support $TASK_DATA , please choose [swda|mrda|atis|dstc2|multi-woz]"
fi
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -21,6 +22,7 @@ from __future__ import print_function
import collections
import unicodedata
import six
import io
def convert_to_unicode(text):
......@@ -69,7 +71,7 @@ def printable_text(text):
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
fin = open(vocab_file)
fin = io.open(vocab_file, 'r', encoding="utf8")
for num, line in enumerate(fin):
items = convert_to_unicode(line.strip()).split("\t")
if len(items) > 2:
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -16,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import io
import os
import sys
import argparse
......@@ -38,7 +40,7 @@ class JsonConfig(object):
def _parse(self, config_path):
try:
with open(config_path) as json_file:
json_file = io.open(config_path, 'r', encoding="utf8")
config_dict = json.load(json_file)
except:
raise IOError("Error in parsing bert model config file '%s'" %
......@@ -212,7 +214,7 @@ class PDConfig(object):
raise Warning("the json file %s does not exist." % file_path)
return
with open(file_path, "r") as fin:
fin = io.open(file_path, "r", encoding="utf8")
self.json_config = json.loads(fin.read())
fin.close()
......@@ -236,7 +238,7 @@ class PDConfig(object):
raise Warning("the yaml file %s does not exist." % file_path)
return
with open(file_path, "r") as fin:
fin = io.open(file_path, "r", encoding="utf8")
self.yaml_config = yaml.load(fin, Loader=yaml.SafeLoader)
fin.close()
......
# -*- coding: utf-8 -*-
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -12,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import os
import sys
import numpy as np
......@@ -142,7 +144,8 @@ def do_predict(args):
np.set_printoptions(precision=4, suppress=True)
print("Write the predicted results into the output_prediction_file")
with open(args.output_prediction_file, 'w') as fw:
fw = io.open(args.output_prediction_file, 'w', encoding="utf8")
if task_name not in ['atis_slot']:
for index, result in enumerate(all_results):
tags = pred_func(result)
......
......@@ -21,7 +21,6 @@ import os
import sys
import time
import numpy as np
import multiprocessing
import paddle
import paddle.fluid as fluid
......@@ -111,8 +110,7 @@ def do_train(args):
if args.use_cuda:
dev_count = fluid.core.get_cuda_device_count()
else:
dev_count = int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
dev_count = int(os.environ.get('CPU_NUM', 1))
batch_generator = processor.data_generator(
batch_size=args.batch_size,
......
......@@ -44,10 +44,10 @@ In our MTL experiments, we use BERT as our shared encoder. The parameters are in
```
1、cd scripts
2、download cased_model_01.tar.gz from link
2、# download cased_model_01.tar.gz from link
3、mkdir cased_model_01 && mv cased_model_01.tar.gz cased_model_01 && cd cased_model_01 && tar -xvf cased_model_01.tar.gz && cd ..
4、python convert_model_params.py --init_tf_checkpoint cased_model_01/model.ckpt --fluid_params_dir params
5、mkdir fluid_models && mv cased_model_01/vocab.txt cased_model_01/bert_config.json params fluid_models
5、mkdir squad2_model && mv cased_model_01/vocab.txt cased_model_01/bert_config.json params squad2_model
```
Alternatively, user can directly **download the parameters that we have converted**:
......
......@@ -21,8 +21,8 @@ import json
import random
import collections
import numpy as np
import tokenization
from batching import prepare_batch_data
from task_reader import tokenization
from task_reader.batching import prepare_batch_data
class MRQAExample(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册