Unverified commit 7e2a66b0, authored by Li Fuchen, committed by GitHub

add license for nlp models (#3390)

add license for nlp models
Parent 88d125f2
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Deep Attention Matching Network Deep Attention Matching Network
""" """
...@@ -5,6 +18,7 @@ Deep Attention Matching Network ...@@ -5,6 +18,7 @@ Deep Attention Matching Network
import argparse import argparse
import six import six
def parse_args(): def parse_args():
""" """
Deep Attention Matching Network Config Deep Attention Matching Network Config
...@@ -12,14 +26,14 @@ def parse_args(): ...@@ -12,14 +26,14 @@ def parse_args():
parser = argparse.ArgumentParser("DAM Config") parser = argparse.ArgumentParser("DAM Config")
parser.add_argument( parser.add_argument(
'--do_train', '--do_train',
type=bool, type=bool,
default=False, default=False,
help='Whether to perform training.') help='Whether to perform training.')
parser.add_argument( parser.add_argument(
'--do_test', '--do_test',
type=bool, type=bool,
default=False, default=False,
help='Whether to perform training.') help='Whether to perform training.')
parser.add_argument( parser.add_argument(
......
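A note on the `type=bool` arguments above: argparse calls `bool()` on the raw string, and `bool('False')` is `True`, so any non-empty value enables the flag. A minimal sketch of the usual workaround (not part of this commit; `strtobool` is from the standard library):

import argparse
from distutils.util import strtobool

parser = argparse.ArgumentParser("DAM Config")
# strtobool parses 'true'/'false'/'1'/'0'/'yes'/'no', unlike bool('False') == True.
parser.add_argument(
    '--do_train',
    type=lambda v: bool(strtobool(v)),
    default=False,
    help='Whether to perform training.')
print(parser.parse_args(['--do_train', 'False']).do_train)  # False, as intended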
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Evaluation Evaluation
""" """
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Layers Layers
""" """
...@@ -77,7 +90,7 @@ def dot_product_attention(query, ...@@ -77,7 +90,7 @@ def dot_product_attention(query,
""" """
logits = fluid.layers.matmul( logits = fluid.layers.matmul(
x=query, y=key, transpose_y=True, alpha=d_key ** (-0.5)) x=query, y=key, transpose_y=True, alpha=d_key**(-0.5))
if (q_mask is not None) and (k_mask is not None): if (q_mask is not None) and (k_mask is not None):
if mask_cache is not None and q_mask.name in mask_cache and k_mask.name in mask_cache[ if mask_cache is not None and q_mask.name in mask_cache and k_mask.name in mask_cache[
...@@ -87,7 +100,7 @@ def dot_product_attention(query, ...@@ -87,7 +100,7 @@ def dot_product_attention(query,
mask = fluid.layers.matmul(x=q_mask, y=k_mask, transpose_y=True) mask = fluid.layers.matmul(x=q_mask, y=k_mask, transpose_y=True)
another_mask = fluid.layers.scale( another_mask = fluid.layers.scale(
mask, mask,
scale=float(2 ** 32 - 1), scale=float(2**32 - 1),
bias=float(-1), bias=float(-1),
bias_after_scale=False) bias_after_scale=False)
if mask_cache is not None: if mask_cache is not None:
......
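To see what the masking above computes: with `bias_after_scale=False`, `scale` evaluates `(mask - 1) * (2**32 - 1)`, which is 0 where both the query and key positions are real tokens and a huge negative bias where either is padding. A small numpy illustration (not the committed code):

import numpy as np

q_mask = np.array([[1.], [1.], [0.]])  # [len_q, 1]: 1 = real token, 0 = padding
k_mask = np.array([[1.], [0.], [1.]])  # [len_k, 1]
mask = q_mask @ k_mask.T               # 1 only where query and key are both real
another_mask = (mask - 1.) * float(2**32 - 1)
print(another_mask)  # 0 where attention is allowed, about -4.3e9 where masked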
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Deep Attention Matching Network Deep Attention Matching Network
""" """
...@@ -174,9 +187,8 @@ def train(args): ...@@ -174,9 +187,8 @@ def train(args):
print("device count %d" % dev_count) print("device count %d" % dev_count)
print("theoretical memory usage: ") print("theoretical memory usage: ")
print( print(fluid.contrib.memory_usage(
fluid.contrib.memory_usage( program=train_program, batch_size=args.batch_size))
program=train_program, batch_size=args.batch_size))
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(train_startup) exe.run(train_startup)
...@@ -247,9 +259,8 @@ def train(args): ...@@ -247,9 +259,8 @@ def train(args):
if (args.save_path is not None) and (step % save_step == 0): if (args.save_path is not None) and (step % save_step == 0):
save_path = os.path.join(args.save_path, "step_" + str(step)) save_path = os.path.join(args.save_path, "step_" + str(step))
print("Save model at step %d ... " % step) print("Save model at step %d ... " % step)
print( print(time.strftime('%Y-%m-%d %H:%M:%S',
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
time.localtime(time.time())))
fluid.io.save_persistables(exe, save_path, train_program) fluid.io.save_persistables(exe, save_path, train_program)
score_path = os.path.join(args.save_path, 'score.' + str(step)) score_path = os.path.join(args.save_path, 'score.' + str(step))
...@@ -294,9 +305,8 @@ def train(args): ...@@ -294,9 +305,8 @@ def train(args):
save_path = os.path.join(args.save_path, save_path = os.path.join(args.save_path,
"step_" + str(step)) "step_" + str(step))
print("Save model at step %d ... " % step) print("Save model at step %d ... " % step)
print( print(time.strftime('%Y-%m-%d %H:%M:%S',
time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
time.localtime(time.time())))
fluid.io.save_persistables(exe, save_path, train_program) fluid.io.save_persistables(exe, save_path, train_program)
score_path = os.path.join(args.save_path, score_path = os.path.join(args.save_path,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Deep Attention Matching Network Deep Attention Matching Network
""" """
...@@ -12,6 +25,7 @@ class Net(object): ...@@ -12,6 +25,7 @@ class Net(object):
""" """
Deep attention matching network Deep attention matching network
""" """
def __init__(self, max_turn_num, max_turn_len, vocab_size, emb_size, def __init__(self, max_turn_num, max_turn_len, vocab_size, emb_size,
stack_num, channel1_num, channel2_num): stack_num, channel1_num, channel2_num):
""" """
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Reader for deep attention matching network Reader for deep attention matching network
""" """
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Utils Utils
""" """
...@@ -20,7 +33,7 @@ def mkdir(path): ...@@ -20,7 +33,7 @@ def mkdir(path):
""" """
Mkdir Mkdir
""" """
if not os.path.isdir(path): if not os.path.isdir(path):
if os.path.split(path)[0]: if os.path.split(path)[0]:
mkdir(os.path.split(path)[0]) mkdir(os.path.split(path)[0])
else: else:
......
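The recursive `mkdir` above creates missing parent directories one level at a time. A sketch of the same behavior using only the standard library (an equivalent for comparison, not the repo's code; on Python 3 a plain `os.makedirs(path, exist_ok=True)` also suffices):

import errno
import os

def mkdir_p(path):
    # Create path and any missing parents; ignore "already exists".
    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno != errno.EEXIST:
            raise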
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/python
#-*- coding:utf-8 -*-
@@ -25,9 +38,8 @@ def compute_paragraph_score(sample):
        doc['segmented_paragraphs_scores'] = []
        for p_idx, para_tokens in enumerate(doc['segmented_paragraphs']):
            if len(question) > 0:
                related_score = metric_max_over_ground_truths(
                    f1_score, para_tokens, [question])
            else:
                related_score = 0.0
            doc['segmented_paragraphs_scores'].append(related_score)
@@ -63,7 +75,7 @@ def dup_remove(doc):
    prev_del_num = 0
    del_num = 0
    for p_idx in del_ids:
        if p_idx < para_id:
            prev_del_num += 1
        del doc["segmented_paragraphs"][p_idx - del_num]
        del doc["segmented_paragraphs_scores"][p_idx - del_num]
@@ -142,7 +154,8 @@ def paragraph_selection(sample, mode):
        para_infos = []
        for p_idx, (para_tokens, para_scores) in \
                enumerate(zip(doc['segmented_paragraphs'], doc['segmented_paragraphs_scores'])):
            para_infos.append(
                (para_tokens, para_scores, len(para_tokens), p_idx))
        para_infos.sort(key=lambda x: (-x[1], x[2]))
        topN_idx = []
        for para_info in para_infos[:topN]:
@@ -158,7 +171,7 @@ def paragraph_selection(sample, mode):
                break
            if doc_id == d_idx and id == para_id and mode == "train":
                continue
            total_len += 1 + doc['paragraphs_length'][id]
            final_idx.append(id)
        total_segmented_content = copy.deepcopy(segmented_title)
        final_idx.sort()
@@ -168,7 +181,8 @@ def paragraph_selection(sample, mode):
            incre_len += 1 + doc['paragraphs_length'][id]
            if doc_id == d_idx and id == para_id:
                incre_len += 1
            total_segmented_content += [splitter] + doc['segmented_paragraphs'][
                id]
        if doc_id == d_idx:
            answer_start = incre_len + sample['answer_spans'][0][0]
            answer_end = incre_len + sample['answer_spans'][0][1]
@@ -191,9 +205,9 @@ if __name__ == "__main__":
        try:
            sample = json.loads(line, encoding='utf8')
        except:
            print >> sys.stderr, "Invalid input json format - '{}' will be ignored".format(
                line)
            continue
        compute_paragraph_score(sample)
        paragraph_selection(sample, mode)
        print(json.dumps(sample, encoding='utf8', ensure_ascii=False))
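`metric_max_over_ground_truths(f1_score, para_tokens, [question])` above ranks each paragraph by its token overlap with the question. The two helpers are imported from elsewhere in the repo; a rough sketch of the standard token-level F1 they are expected to compute (an illustration, not the exact implementation):

from collections import Counter

def f1_score(pred_tokens, ref_tokens):
    # Token-level F1 over the multiset intersection of the two token lists.
    common = Counter(pred_tokens) & Counter(ref_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = 1.0 * num_same / len(pred_tokens)
    recall = 1.0 * num_same / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)

def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    # Best score over all acceptable references.
    return max(metric_fn(prediction, gt) for gt in ground_truths)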
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#coding=utf8
import os, sys, json
import nltk


def _nltk_tokenize(sequence):
    tokens = nltk.word_tokenize(sequence)
@@ -11,10 +25,12 @@ def _nltk_tokenize(sequence):
    token_words = []
    for token in tokens:
        cur_char_offset = sequence.find(token, cur_char_offset)
        token_offsets.append(
            [cur_char_offset, cur_char_offset + len(token) - 1])
        token_words.append(token)
    return token_offsets, token_words


def segment(input_js):
    _, input_js['segmented_question'] = _nltk_tokenize(input_js['question'])
    for doc_id, doc in enumerate(input_js['documents']):
@@ -36,7 +52,7 @@ if __name__ == '__main__':
        exit()
    nltk.download('punkt')
    for line in open(sys.argv[1]):
        dureader_js = json.loads(line.strip())
        segment(dureader_js)
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#coding=utf8
import sys
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import json
import pandas as pd


if __name__ == '__main__':
    if len(sys.argv) != 3:
@@ -11,4 +24,4 @@ if __name__ == '__main__':
    df = pd.read_json(infile)
    with open(outfile, 'w') as f:
        for row in df.iterrows():
            f.write(row[1].to_json() + '\n')
\ No newline at end of file
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class TrainTaskConfig(object):
    # support both CPU and GPU now.
    use_gpu = True
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
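As a hedged illustration of that convention (not an entry from this file; the name and lengths below are made up): a fluid data layer whose batch dimension is left as -1 at compile time looks like

import paddle.fluid as fluid

src_word = fluid.layers.data(
    name="src_word",
    shape=[-1, 256, 1],       # -1: batch size unknown until runtime
    dtype="int64",
    append_batch_size=False)  # the -1 above already stands for the batch dim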
@@ -65,43 +78,37 @@ input_descs = {
# Names of word embedding table which might be reused for weight sharing.
word_emb_param_names = (
    "src_word_emb_table",
    "trg_word_emb_table", )
phone_emb_param_name = "phone_emb_table"
# Names of position encoding table which will be initialized externally.
pos_enc_param_names = (
    "src_pos_enc_table",
    "trg_pos_enc_table", )
# separated inputs for different usages.
encoder_data_input_fields = (
    "src_word",
    "src_pos",
    "src_slf_attn_bias",
    "src_phone",
    "src_phone_mask", )
decoder_data_input_fields = (
    "trg_word",
    "trg_pos",
    "trg_slf_attn_bias",
    "trg_src_attn_bias",
    "enc_output", )
label_data_input_fields = (
    "lbl_word",
    "lbl_weight", )
# In fast decoder, trg_pos (only containing the current time step) is generated
# by ops and trg_slf_attn_bias is not needed.
fast_decoder_data_input_fields = (
    "trg_word",
    "init_score",
    "init_idx",
    "trg_src_attn_bias", )
# Set seed for CE
dropout_seed = None
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import multiprocessing
@@ -86,10 +99,8 @@ def parse_args():
    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
    phone_dict = reader.DataReader.load_dict(args.phoneme_vocab_fpath)
    dict_args = [
        "src_vocab_size", str(len(src_dict)), "trg_vocab_size",
        str(len(trg_dict)), "phone_vocab_size", str(len(phone_dict)), "bos_idx",
        str(src_dict[args.special_token[0]]), "eos_idx",
        str(src_dict[args.special_token[1]]), "unk_idx",
        str(src_dict[args.special_token[2]])
@@ -147,10 +158,10 @@ def prepare_batch_input(insts, data_input_names, src_pad_idx, phone_pad_idx,
        # beamsearch_op must use tensors with lod
        init_score = to_lodtensor(
            np.zeros_like(
                trg_word, dtype="float32").reshape(-1, 1),
            place, [range(trg_word.shape[0] + 1)] * 2)
        trg_word = to_lodtensor(trg_word, place, [range(trg_word.shape[0] + 1)] * 2)
        init_idx = np.asarray(range(len(insts)), dtype="int32")

    data_input_dict = dict(
@@ -315,7 +326,8 @@ def fast_infer(args):
            sub_start = seq_ids.lod()[1][start + j]
            sub_end = seq_ids.lod()[1][start + j + 1]
            hyps[i].append(" ".join([
                trg_idx2word[idx]
                for idx in post_process_seq(
                    np.array(seq_ids)[sub_start:sub_end])
            ]))
            scores[i].append(np.array(seq_scores)[sub_end - 1])
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial

import numpy as np
@@ -51,12 +64,12 @@ def position_encoding_init(n_position, d_pos_vec):
    channels = d_pos_vec
    position = np.arange(n_position)
    num_timescales = channels // 2
    log_timescale_increment = (np.log(float(1e4) / float(1)) /
                               (num_timescales - 1))
    inv_timescales = np.exp(np.arange(
        num_timescales)) * -log_timescale_increment
    scaled_time = np.expand_dims(position, 1) * np.expand_dims(inv_timescales,
                                                               0)
    signal = np.concatenate([np.sin(scaled_time), np.cos(scaled_time)], axis=1)
    signal = np.pad(signal, [[0, 0], [0, np.mod(channels, 2)]], 'constant')
    position_enc = signal
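A quick sanity check of `position_encoding_init` (illustrative usage; this assumes the full function returns `position_enc` as a numpy array, as in the complete file):

pos_enc = position_encoding_init(n_position=50, d_pos_vec=8)
print(pos_enc.shape)  # (50, 8): one sinusoidal encoding row per position
print(pos_enc[0])     # position 0: the sin half is all 0, the cos half all 1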
@@ -91,17 +104,15 @@ def multi_head_attention(queries,
        """
        Add linear projection to queries, keys, and values.
        """
        q = layers.fc(input=queries,
                      size=d_key * n_head,
                      bias_attr=False,
                      num_flatten_dims=2)
        # For encoder-decoder attention in inference, insert the ops and vars
        # into global block to use as cache among beam search.
        fc_layer = wrap_layer_with_block(
            layers.fc, fluid.default_main_program().current_block(
            ).parent_idx) if cache is not None and static_kv else layers.fc
        k = fc_layer(
            input=keys,
            size=d_key * n_head,
@@ -132,12 +143,12 @@ def multi_head_attention(queries,
            # into global block to use as cache among beam search.
            reshape_layer = wrap_layer_with_block(
                layers.reshape,
                fluid.default_main_program().current_block(
                ).parent_idx) if cache is not None and static_kv else layers.reshape
            transpose_layer = wrap_layer_with_block(
                layers.transpose,
                fluid.default_main_program().current_block().
                parent_idx) if cache is not None and static_kv else layers.transpose
            reshaped_k = reshape_layer(
                x=keys, shape=[0, 0, n_head, d_key], inplace=True)
            k = transpose_layer(x=reshaped_k, perm=[0, 2, 1, 3])
@@ -214,8 +225,10 @@ def multi_head_attention(queries,
    out = __combine_heads(ctx_multiheads)

    # Project back to the model size.
    proj_out = layers.fc(input=out,
                         size=d_model,
                         bias_attr=False,
                         num_flatten_dims=2)
    return proj_out


@@ -225,14 +238,13 @@ def positionwise_feed_forward(x, d_inner_hid, d_hid, dropout_rate):
    This module consists of two linear transformations with a ReLU activation
    in between, which is applied to each position separately and identically.
    """
    hidden = layers.fc(input=x,
                       size=d_inner_hid,
                       num_flatten_dims=2,
                       act="relu")
    if dropout_rate:
        hidden = layers.dropout(
            hidden, dropout_prob=dropout_rate, seed=dropout_seed, is_test=False)
    out = layers.fc(input=hidden, size=d_hid, num_flatten_dims=2)
    return out
@@ -313,8 +325,7 @@ def prepare_encoder(src_word,
        param_attr=fluid.ParamAttr(
            name=pos_enc_param_names[0], trainable=False))
    src_pos_enc.stop_gradient = True
    enc_input = (1 - beta) * src_word_emb + beta * mean_phone_emb + src_pos_enc
    return layers.dropout(
        enc_input, dropout_prob=dropout_rate, seed=dropout_seed,
        is_test=False) if dropout_rate else enc_input
@@ -374,8 +385,8 @@ def encoder_layer(enc_input,
    """
    attn_output = multi_head_attention(
        pre_process_layer(enc_input, preprocess_cmd,
                          prepostprocess_dropout), None, None, attn_bias, d_key,
        d_value, d_model, n_head, attention_dropout)
    attn_output = post_process_layer(enc_input, attn_output, postprocess_cmd,
                                     prepostprocess_dropout)
    ffd_output = positionwise_feed_forward(
@@ -415,8 +426,7 @@ def encoder(enc_input,
            attention_dropout,
            relu_dropout,
            preprocess_cmd,
            postprocess_cmd, )
        enc_input = enc_output
    enc_output = pre_process_layer(enc_output, preprocess_cmd,
                                   prepostprocess_dropout)
@@ -459,8 +469,7 @@ def decoder_layer(dec_input,
        dec_input,
        slf_attn_output,
        postprocess_cmd,
        prepostprocess_dropout, )
    enc_attn_output = multi_head_attention(
        pre_process_layer(slf_attn_output, preprocess_cmd,
                          prepostprocess_dropout),
@@ -479,21 +488,18 @@ def decoder_layer(dec_input,
        slf_attn_output,
        enc_attn_output,
        postprocess_cmd,
        prepostprocess_dropout, )
    ffd_output = positionwise_feed_forward(
        pre_process_layer(enc_attn_output, preprocess_cmd,
                          prepostprocess_dropout),
        d_inner_hid,
        d_model,
        relu_dropout, )
    dec_output = post_process_layer(
        enc_attn_output,
        ffd_output,
        postprocess_cmd,
        prepostprocess_dropout, )
    return dec_output
@@ -632,8 +638,7 @@ def transformer(src_vocab_size,
        postprocess_cmd,
        weight_sharing,
        beta,
        enc_inputs, )
    predict = wrap_decoder(
        trg_vocab_size,
@@ -651,14 +656,14 @@ def transformer(src_vocab_size,
        postprocess_cmd,
        weight_sharing,
        dec_inputs,
        enc_output, )

    # Padding indices do not contribute to the total loss; the weights are used
    # to cancel padding indices when calculating the loss.
    if label_smooth_eps:
        label = layers.label_smooth(
            label=layers.one_hot(
                input=label, depth=trg_vocab_size),
            epsilon=label_smooth_eps)

    cost = layers.softmax_with_cross_entropy(
@@ -730,8 +735,7 @@ def wrap_encoder(src_vocab_size,
        attention_dropout,
        relu_dropout,
        preprocess_cmd,
        postprocess_cmd, )
    return enc_output
@@ -803,8 +807,9 @@ def wrap_decoder(trg_vocab_size,
                 word_emb_param_names[0]),
        transpose_y=True)
    else:
        predict = layers.fc(input=dec_output,
                            size=trg_vocab_size,
                            bias_attr=False)
    if dec_inputs is None:
        # Return probs for independent decoder program.
        predict = layers.softmax(predict)
@@ -879,8 +884,7 @@ def fast_decode(src_vocab_size,
        force_cpu=True)
    step_idx = layers.fill_constant(
        shape=[1], dtype=start_tokens.dtype, value=0, force_cpu=True)
    cond = layers.less_than(x=step_idx, y=max_len)  # default force_cpu=True
    while_op = layers.While(cond)
    # array states will be stored for each step.
    ids = layers.array_write(
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import six
import os
@@ -302,9 +315,8 @@ class DataReader(object):
            f = tarfile.open(fpaths[0], "r")
            for line in f.extractfile(tar_fname):
                fields = line.strip("\n").split(self._field_delimiter)
                if (not self._only_src and len(fields) == 2) or (
                        self._only_src and len(fields) == 1):
                    yield fields
        else:
            for fpath in fpaths:
@@ -381,5 +393,5 @@ class DataReader(object):
                    for idx in batch_ids]
        else:
            yield [(self._src_seq_ids[idx], self._src_phone_ids[idx],
                    self._trg_seq_ids[idx][:-1], self._trg_seq_ids[idx][1:])
                   for idx in batch_ids]
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import copy
@@ -141,10 +154,8 @@ def parse_args():
    trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
    phone_dict = reader.DataReader.load_dict(args.phoneme_vocab_fpath)
    dict_args = [
        "src_vocab_size", str(len(src_dict)), "trg_vocab_size",
        str(len(trg_dict)), "phone_vocab_size", str(len(phone_dict)), "bos_idx",
        str(src_dict[args.special_token[0]]), "eos_idx",
        str(src_dict[args.special_token[1]]), "unk_idx",
        str(src_dict[args.special_token[2]])
@@ -157,8 +168,8 @@ def parse_args():

def append_nccl2_prepare(startup_prog, trainer_id, worker_endpoints,
                         current_endpoint):
    assert (trainer_id >= 0 and len(worker_endpoints) > 1 and
            current_endpoint in worker_endpoints)
    eps = copy.deepcopy(worker_endpoints)
    eps.remove(current_endpoint)
    nccl_id_var = startup_prog.global_block().create_var(
@@ -189,8 +200,8 @@ def pad_phoneme_data(phoneme_seqs, pad_idx, max_seq_len):
    batch_size = len(phoneme_seqs)

    phoneme_data = pad_idx * np.ones(
        (batch_size, max_seq_len, max_ph_seq_len), dtype=np.int64)
    phoneme_mask = np.zeros(
        (batch_size, max_seq_len, max_ph_seq_len), dtype=np.int64)

    for i in range(batch_size):
        cur_ph_seq = phoneme_seqs[i]
@@ -237,17 +248,16 @@ def pad_batch_data(insts,
        if is_target:
            # This is used to avoid attention on paddings and subsequent
            # words.
            slf_attn_bias_data = np.ones((inst_data.shape[0], max_len, max_len))
            slf_attn_bias_data = np.triu(slf_attn_bias_data,
                                         1).reshape([-1, 1, max_len, max_len])
            slf_attn_bias_data = np.tile(slf_attn_bias_data,
                                         [1, n_head, 1, 1]) * [-1e9]
        else:
            # This is used to avoid attention on paddings.
            slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
                                           (max_len - len(inst))
                                           for inst in insts])
            slf_attn_bias_data = np.tile(
                slf_attn_bias_data.reshape([-1, 1, 1, max_len]),
                [1, n_head, max_len, 1])
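A small numpy illustration (not part of the commit) of the target-side bias built above: `np.triu(..., 1)` marks strictly-future positions, and scaling by -1e9 turns them into logits that softmax maps to ~0:

import numpy as np

max_len = 4
bias = np.triu(np.ones((max_len, max_len)), 1) * -1e9
print(bias)  # row i is 0 up to position i and -1e9 afterwards, so step i
             # cannot attend to words that have not been generated yet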
@@ -359,8 +369,8 @@ def prepare_data_generator(args,
        for item in data_reader():
            inst_num_per_part = len(item) // count
            for i in range(count):
                yield item[inst_num_per_part * i:inst_num_per_part * (i + 1
                                                                      )]

    return __impl__
@@ -401,8 +411,8 @@ def prepare_feed_dict_list(data_generator, init_flag, count):
            feed_dict_list.append(pos_enc_tables)
        else:
            feed_dict_list[idx] = dict(
                list(pos_enc_tables.items()) + list(feed_dict_list[idx]
                                                    .items()))

    return feed_dict_list if len(feed_dict_list) == count else None
@@ -487,11 +497,10 @@ def test_context(exe, train_exe, dev_count):
        data_generator = test_data()
        while True:
            try:
                feed_dict_list = prepare_feed_dict_list(data_generator, False,
                                                        dev_count)
                outs = test_exe.run(fetch_list=[sum_cost.name, token_num.name],
                                    feed=feed_dict_list)
            except (StopIteration, fluid.core.EOFException):
                # The current pass is over.
                if args.use_py_reader:
@@ -562,10 +571,10 @@ def train_loop(exe,
    # the best cross-entropy value with label smoothing
    loss_normalizer = -((1. - TrainTaskConfig.label_smooth_eps) * np.log(
        (1. - TrainTaskConfig.label_smooth_eps
         )) + TrainTaskConfig.label_smooth_eps *
                        np.log(TrainTaskConfig.label_smooth_eps / (
                            ModelHyperParams.trg_vocab_size - 1) + 1e-20))
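`loss_normalizer` is the cross-entropy a perfect model would still pay once the labels are smoothed: -[(1 - eps) * log(1 - eps) + eps * log(eps / (V - 1) + 1e-20)] for smoothing eps and target vocabulary size V. A standalone check with illustrative values (not the configured ones):

import numpy as np

eps, vocab_size = 0.1, 10000
loss_normalizer = -((1. - eps) * np.log(1. - eps) +
                    eps * np.log(eps / (vocab_size - 1) + 1e-20))
print(loss_normalizer)  # ~1.246; subtracting it reports the loss relative to 0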
    step_idx = 0
    init_flag = True
@@ -583,8 +592,8 @@ def train_loop(exe,
        batch_id = 0
        while True:
            try:
                feed_dict_list = prepare_feed_dict_list(data_generator,
                                                        init_flag, dev_count)
                outs = train_exe.run(
                    fetch_list=[sum_cost.name, token_num.name]
                    if step_idx % args.fetch_steps == 0 else [],
@@ -609,12 +618,11 @@ def train_loop(exe,
                else:
                    logging.info(
                        "step_idx: %d, epoch: %d, batch: %d, avg loss: %f, "
                        "normalized loss: %f, ppl: %f, speed: %.2f step/s" %
                        (step_idx, pass_id, batch_id, total_avg_cost,
                         total_avg_cost - loss_normalizer, np.exp(
                             [min(total_avg_cost, 100)]),
                         args.fetch_steps / (time.time() - avg_batch_time)))
                    avg_batch_time = time.time()

                if step_idx % TrainTaskConfig.save_freq == 0 and step_idx > 0:
@@ -643,8 +651,9 @@ def train_loop(exe,
            val_avg_cost, val_ppl = test()
            logging.info(
                "epoch: %d, val avg loss: %f, val normalized loss: %f, val ppl: %f,"
                " consumed %fs" % (pass_id, val_avg_cost,
                                   val_avg_cost - loss_normalizer, val_ppl,
                                   time_consumed))
        else:
            logging.info("epoch: %d, consumed %fs" % (pass_id, time_consumed))
        if not args.enable_ce:
@@ -734,8 +743,8 @@ def train(args):
    if args.local:
        logging.info("local start_up:")
        train_loop(exe, train_prog, startup_prog, dev_count, sum_cost, avg_cost,
                   token_num, predict, pyreader)
    else:
        if args.update_method == "nccl2":
            trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
...
import os

filePath = os.getcwd()


def get_all_files(dir):
    """Collect the paths of all files under dir; os.walk already recurses."""
    fileDirList = []
    for root, dirs, files in os.walk(dir):
        for file in files:
            fileDirList.append(os.path.join(root, file))
    return fileDirList


fileDirList = get_all_files(filePath)
for code in fileDirList:
    split = os.path.splitext(code)
    if (split[1] == '.py' and not '__init__' in split[0] and
            not '_ce' in split[0]):
        with open(code, 'r') as fz:
            content = fz.read()
        # Skip files that already carry a copyright notice.
        if content.find('Copyright') >= 0:
            continue
        string = "# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.\n" \
                 "#\n" \
                 "# Licensed under the Apache License, Version 2.0 (the \"License\");\n" \
                 "# you may not use this file except in compliance with the License.\n" \
                 "# You may obtain a copy of the License at\n" \
                 "#\n" \
                 "# http://www.apache.org/licenses/LICENSE-2.0\n" \
                 "#\n" \
                 "# Unless required by applicable law or agreed to in writing, software\n" \
                 "# distributed under the License is distributed on an \"AS IS\" BASIS,\n" \
                 "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n" \
                 "# See the License for the specific language governing permissions and\n" \
                 "# limitations under the License.\n" + content
        with open(code, 'w') as f:
            f.write(string)
        print("file %s write success!" % code)

print("read and write success!")
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
EmoTect config EmoTect config
""" """
...@@ -9,10 +22,12 @@ from __future__ import print_function ...@@ -9,10 +22,12 @@ from __future__ import print_function
import six import six
import json import json
class EmoTectConfig(object): class EmoTectConfig(object):
""" """
EmoTect Config EmoTect Config
""" """
def __init__(self, config_path): def __init__(self, config_path):
self._config_dict = self._parse(config_path) self._config_dict = self._parse(config_path)
...@@ -21,7 +36,8 @@ class EmoTectConfig(object): ...@@ -21,7 +36,8 @@ class EmoTectConfig(object):
with open(config_path) as json_file: with open(config_path) as json_file:
config_dict = json.load(json_file) config_dict = json.load(json_file)
except Exception: except Exception:
raise IOError("Error in parsing emotect model config file '%s'" % config_path) raise IOError("Error in parsing emotect model config file '%s'" %
config_path)
else: else:
return config_dict return config_dict
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
EmoTect Reader, data converters for classification data. EmoTect Reader, data converters for classification data.
""" """
...@@ -10,14 +23,13 @@ import numpy as np ...@@ -10,14 +23,13 @@ import numpy as np
from utils import load_vocab from utils import load_vocab
from utils import data_reader from utils import data_reader
class EmoTectProcessor(object): class EmoTectProcessor(object):
""" """
Processor class for data convertors for EmoTect. Processor class for data convertors for EmoTect.
""" """
def __init__(self,
data_dir, def __init__(self, data_dir, vocab_path, random_seed=None):
vocab_path,
random_seed=None):
self.data_dir = data_dir self.data_dir = data_dir
self.vocab = load_vocab(vocab_path) self.vocab = load_vocab(vocab_path)
self.num_examples = {"train": -1, "dev": -1, "test": -1, "infer": -1} self.num_examples = {"train": -1, "dev": -1, "test": -1, "infer": -1}
...@@ -27,29 +39,33 @@ class EmoTectProcessor(object): ...@@ -27,29 +39,33 @@ class EmoTectProcessor(object):
""" """
Load training examples Load training examples
""" """
return data_reader(os.path.join(self.data_dir, "train.tsv"), return data_reader(
self.vocab, self.num_examples, "train", epoch) os.path.join(self.data_dir, "train.tsv"), self.vocab,
self.num_examples, "train", epoch)
def get_dev_examples(self, data_dir): def get_dev_examples(self, data_dir):
""" """
Load dev examples Load dev examples
""" """
return data_reader(os.path.join(self.data_dir, "dev.tsv"), return data_reader(
self.vocab, self.num_examples, "dev") os.path.join(self.data_dir, "dev.tsv"), self.vocab,
self.num_examples, "dev")
def get_test_examples(self, data_dir): def get_test_examples(self, data_dir):
""" """
Load test examples Load test examples
""" """
return data_reader(os.path.join(self.data_dir, "test.tsv"), return data_reader(
self.vocab, self.num_examples, "test") os.path.join(self.data_dir, "test.tsv"), self.vocab,
self.num_examples, "test")
def get_infer_examples(self, data_dir): def get_infer_examples(self, data_dir):
""" """
Load infer querys Load infer querys
""" """
return data_reader(os.path.join(self.data_dir, "infer.tsv"), return data_reader(
self.vocab, self.num_examples, "infer") os.path.join(self.data_dir, "infer.tsv"), self.vocab,
self.num_examples, "infer")
def get_labels(self): def get_labels(self):
""" """
...@@ -63,7 +79,8 @@ class EmoTectProcessor(object): ...@@ -63,7 +79,8 @@ class EmoTectProcessor(object):
""" """
if phase not in ['train', 'dev', 'test', 'infer']: if phase not in ['train', 'dev', 'test', 'infer']:
raise ValueError( raise ValueError(
"Unknown phase, which should be in ['train', 'dev', 'test', 'infer'].") "Unknown phase, which should be in ['train', 'dev', 'test', 'infer']."
)
return self.num_examples[phase] return self.num_examples[phase]
def get_train_progress(self): def get_train_progress(self):
...@@ -77,14 +94,18 @@ class EmoTectProcessor(object): ...@@ -77,14 +94,18 @@ class EmoTectProcessor(object):
Generate data for train, dev or test Generate data for train, dev or test
""" """
if phase == "train": if phase == "train":
return paddle.batch(self.get_train_examples(self.data_dir, epoch), batch_size) return paddle.batch(
self.get_train_examples(self.data_dir, epoch), batch_size)
elif phase == "dev": elif phase == "dev":
return paddle.batch(self.get_dev_examples(self.data_dir), batch_size) return paddle.batch(
self.get_dev_examples(self.data_dir), batch_size)
elif phase == "test": elif phase == "test":
return paddle.batch(self.get_test_examples(self.data_dir), batch_size) return paddle.batch(
self.get_test_examples(self.data_dir), batch_size)
elif phase == "infer": elif phase == "infer":
return paddle.batch(self.get_infer_examples(self.data_dir), batch_size) return paddle.batch(
self.get_infer_examples(self.data_dir), batch_size)
else: else:
raise ValueError( raise ValueError(
"Unknown phase, which should be in ['train', 'dev', 'test', 'infer'].") "Unknown phase, which should be in ['train', 'dev', 'test', 'infer']."
)
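A hedged usage sketch of the processor above (the paths, sizes, and the keyword names of `data_generator` are assumptions based on the visible body, not confirmed by this diff):

processor = EmoTectProcessor(
    data_dir="data/", vocab_path="data/vocab.txt", random_seed=0)
# data_generator returns a paddle.batch reader for the requested phase.
train_reader = processor.data_generator(
    batch_size=256, phase="train", epoch=10)
for batch in train_reader():
    pass  # each batch is a list of (token_ids, label) examples to feed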
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Emotion Detection Task Emotion Detection Task
""" """
@@ -25,37 +38,48 @@ import utils

parser = argparse.ArgumentParser(__doc__)
model_g = utils.ArgumentGroup(parser, "model", "model configuration and paths.")
model_g.add_arg("config_path", str, None,
                "Path to the json file for EmoTect model config.")
model_g.add_arg("init_checkpoint", str, None,
                "Init checkpoint to resume training from.")
model_g.add_arg("output_dir", str, None, "Directory path to save checkpoints")

train_g = utils.ArgumentGroup(parser, "training", "training options.")
train_g.add_arg("epoch", int, 10, "Number of epochs for training.")
train_g.add_arg("save_steps", int, 10000,
                "The steps interval to save checkpoints.")
train_g.add_arg("validation_steps", int, 1000,
                "The steps interval to evaluate model performance.")
train_g.add_arg("lr", float, 0.002, "The Learning rate value for training.")

log_g = utils.ArgumentGroup(parser, "logging", "logging related")
log_g.add_arg("skip_steps", int, 10, "The steps interval to print loss.")
log_g.add_arg("verbose", bool, False, "Whether to output verbose log")

data_g = utils.ArgumentGroup(
    parser, "data", "Data paths, vocab paths and data processing options")
data_g.add_arg("data_dir", str, None, "Directory path to training data.")
data_g.add_arg("vocab_path", str, None, "Vocabulary path.")
data_g.add_arg("batch_size", int, 256,
               "Total examples' number in batch for training.")
data_g.add_arg("random_seed", int, 0, "Random seed.")

run_type_g = utils.ArgumentGroup(parser, "run_type", "running type options.")
run_type_g.add_arg("use_cuda", bool, False, "If set, use GPU for training.")
run_type_g.add_arg("task_name", str, None,
                   "The name of task to perform sentiment classification.")
run_type_g.add_arg("do_train", bool, False, "Whether to perform training.")
run_type_g.add_arg("do_val", bool, False, "Whether to perform evaluation.")
run_type_g.add_arg("do_infer", bool, False, "Whether to perform inference.")
parser.add_argument(
    '--enable_ce',
    action='store_true',
    help='If set, run the task with continuous evaluation logs.')

args = parser.parse_args()
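The `add_arg` calls above come from `utils.ArgumentGroup`, whose body is not shown in this diff. Judging from the call sites, it is roughly a thin wrapper over argparse; the `str2bool` coercion is an assumption, motivated by the fact that `type=bool` treats any non-empty string as True:

import argparse

def str2bool(value):
    # argparse's bool() is truthy for any non-empty string, so coerce manually
    return value.lower() in ("true", "t", "1")

class ArgumentGroup(object):
    def __init__(self, parser, title, des):
        self._group = parser.add_argument_group(title=title, description=des)

    def add_arg(self, name, dtype, default, help_text):
        # hypothetical body, reconstructed from the call pattern above
        dtype = str2bool if dtype == bool else dtype
        self._group.add_argument(
            "--" + name, type=dtype, default=default, help=help_text)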
def create_model(args,
                 pyreader_name,
                 emotect_config,

@@ -98,11 +122,17 @@ def create_model(args,
    if is_infer:
        data = fluid.layers.read_file(pyreader)
        probs = network(
            data,
            None,
            emotect_config["vocab_size"],
            class_dim=num_labels,
            is_infer=True)
        return pyreader, probs

    data, label = fluid.layers.read_file(pyreader)
    avg_loss, probs = network(
        data, label, emotect_config["vocab_size"], class_dim=num_labels)
    num_seqs = fluid.layers.create_tensor(dtype='int64')
    accuracy = fluid.layers.accuracy(input=probs, label=label, total=num_seqs)
    return pyreader, avg_loss, accuracy, num_seqs
@@ -118,8 +148,8 @@ def evaluate(exe, test_program, test_pyreader, fetch_list, eval_phase):
    while True:
        try:
            np_loss, np_acc, np_num_seqs = exe.run(program=test_program,
                                                   fetch_list=fetch_list,
                                                   return_numpy=False)
            np_loss = np.array(np_loss)
            np_acc = np.array(np_acc)
            np_num_seqs = np.array(np_num_seqs)

@@ -131,8 +161,8 @@ def evaluate(exe, test_program, test_pyreader, fetch_list, eval_phase):
            break
    time_end = time.time()
    print("[%s evaluation] avg loss: %f, avg acc: %f, elapsed time: %f s" %
          (eval_phase, np.sum(total_cost) / np.sum(total_num_seqs),
           np.sum(total_acc) / np.sum(total_num_seqs), time_end - time_begin))


def infer(exe, infer_program, infer_pyreader, fetch_list, infer_phase):

@@ -141,10 +171,11 @@ def infer(exe, infer_program, infer_pyreader, fetch_list, infer_phase):
    while True:
        try:
            batch_probs = exe.run(program=infer_program,
                                  fetch_list=fetch_list,
                                  return_numpy=True)
            for probs in batch_probs[0]:
                print("%d\t%f\t%f\t%f" %
                      (np.argmax(probs), probs[0], probs[1], probs[2]))
        except fluid.core.EOFException as e:
            infer_pyreader.reset()
            break
@@ -165,9 +196,10 @@ def main(args):
    exe = fluid.Executor(place)

    task_name = args.task_name.lower()
    processor = reader.EmoTectProcessor(
        data_dir=args.data_dir,
        vocab_path=args.vocab_path,
        random_seed=args.random_seed)
    num_labels = len(processor.get_labels())

    if not (args.do_train or args.do_val or args.do_infer):

@@ -180,9 +212,7 @@ def main(args):
    if args.do_train:
        train_data_generator = processor.data_generator(
            batch_size=args.batch_size, phase='train', epoch=args.epoch)

        num_train_examples = processor.get_num_examples(phase="train")
        max_train_steps = args.epoch * num_train_examples // args.batch_size + 1

@@ -210,7 +240,7 @@ def main(args):
            lower_mem, upper_mem, unit = fluid.contrib.memory_usage(
                program=train_program, batch_size=args.batch_size)
            print("Theoretical memory usage in training: %.3f - %.3f %s" %
                  (lower_mem, upper_mem, unit))
    if args.do_val:
        test_prog = fluid.Program()

@@ -241,17 +271,12 @@ def main(args):
    if args.do_train:
        if args.init_checkpoint:
            utils.init_checkpoint(
                exe, args.init_checkpoint, main_program=startup_prog)
    elif args.do_val or args.do_infer:
        if not args.init_checkpoint:
            raise ValueError("args 'init_checkpoint' should be set if"
                             "only doing validation or infer!")
        utils.init_checkpoint(exe, args.init_checkpoint, main_program=test_prog)

    if args.do_train:
        train_exe = exe

@@ -288,22 +313,27 @@ def main(args):
                    total_num_seqs.extend(np_num_seqs)

                    if args.verbose:
                        verbose = "train pyreader queue size: %d, " % train_pyreader.queue.size(
                        )
                        print(verbose)

                    time_end = time.time()
                    used_time = time_end - time_begin
                    print("step: %d, avg loss: %f, "
                          "avg acc: %f, speed: %f steps/s" %
                          (steps, np.sum(total_cost) / np.sum(total_num_seqs),
                           np.sum(total_acc) / np.sum(total_num_seqs),
                           args.skip_steps / used_time))
                    ce_info.append([
                        np.sum(total_cost) / np.sum(total_num_seqs),
                        np.sum(total_acc) / np.sum(total_num_seqs), used_time
                    ])
                    total_cost, total_acc, total_num_seqs = [], [], []
                    time_begin = time.time()

                if steps % args.save_steps == 0:
                    save_path = os.path.join(args.output_dir,
                                             "step_" + str(steps))
                    fluid.io.save_persistables(exe, save_path, train_program)

                if steps % args.validation_steps == 0:

@@ -315,8 +345,8 @@ def main(args):
                            phase='dev',
                            epoch=1))
                    evaluate(test_exe, test_prog, test_pyreader,
                             [loss.name, accuracy.name, num_seqs.name],
                             "dev")

            except fluid.core.EOFException:
                save_path = os.path.join(args.output_dir, "step_" + str(steps))
@@ -336,33 +366,25 @@ def main(args):
        except:
            print("ce info error")
        print("kpis\teach_step_duration_%s_card%s\t%s" %
              (task_name, card_num, ce_time))
        print("kpis\ttrain_loss_%s_card%s\t%f" % (task_name, card_num, ce_loss))
        print("kpis\ttrain_acc_%s_card%s\t%f" % (task_name, card_num, ce_acc))

    # evaluate on test set
    if not args.do_train and args.do_val:
        test_pyreader.decorate_paddle_reader(
            processor.data_generator(
                batch_size=args.batch_size, phase='test', epoch=1))
        print("Final test result:")
        evaluate(test_exe, test_prog, test_pyreader,
                 [loss.name, accuracy.name, num_seqs.name], "test")

    # infer
    if args.do_infer:
        infer_pyreader.decorate_paddle_reader(
            processor.data_generator(
                batch_size=args.batch_size, phase='infer', epoch=1))
        infer(test_exe, test_prog, infer_pyreader, [probs.name], "infer")


def get_cards():

...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Emotion Detection Task, based on ERNIE Emotion Detection Task, based on ERNIE
""" """
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
EmoTect utilities. EmoTect utilities.
""" """
...@@ -16,6 +29,7 @@ import paddle ...@@ -16,6 +29,7 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np import numpy as np
def str2bool(value): def str2bool(value):
""" """
String to Boolean String to Boolean
...@@ -29,6 +43,7 @@ class ArgumentGroup(object): ...@@ -29,6 +43,7 @@ class ArgumentGroup(object):
""" """
Argument Class Argument Class
""" """
def __init__(self, parser, title, des): def __init__(self, parser, title, des):
self._group = parser.add_argument_group(title=title, description=des) self._group = parser.add_argument_group(title=title, description=des)
@@ -92,27 +107,33 @@ def data_reader(file_path, word_dict, num_examples, phrase, epoch=1):
                cols = line.strip().split("\t")
                if len(cols) != 1:
                    query = cols[-1]
                    wids = [
                        word_dict[x] if x in word_dict else unk_id
                        for x in query.strip().split(" ")
                    ]
                    all_data.append((wids, ))
            else:
                cols = line.strip().split("\t")
                if len(cols) != 2:
                    sys.stderr.write("[NOTICE] Error Format Line!")
                    continue
                label = int(cols[0])
                wids = [
                    word_dict[x] if x in word_dict else unk_id
                    for x in cols[1].split(" ")
                ]
                all_data.append((wids, label))
    num_examples[phrase] = len(all_data)

    if phrase == "infer":

        def reader():
            """
            Infer reader function
            """
            for wids in all_data:
                yield wids

        return reader

    def reader():

@@ -124,6 +145,7 @@ def data_reader(file_path, word_dict, num_examples, phrase, epoch=1):
        random.shuffle(all_data)
        for wids, label in all_data:
            yield wids, label

    return reader

...
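A hypothetical invocation of the `data_reader` above (toy vocabulary and file path); note that the function fills the passed-in `num_examples` dict as a side effect, keyed by the `phrase` argument:

word_dict = {"happy": 0, "sad": 1}  # toy vocab; the real one comes from vocab_path
num_examples = {}                   # populated per phrase by data_reader
train_reader = data_reader("train.tsv", word_dict, num_examples, "train", epoch=1)
for wids, label in train_reader():
    pass  # feed (word ids, label) pairs into the network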
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
The function lex_net(args) define the lexical analysis network structure The function lex_net(args) define the lexical analysis network structure
""" """
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#coding: utf-8
"""
The file_reader converts raw corpus to input.
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
This file is used to train the model. This file is used to train the model.
""" """
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#coding=utf-8
"""
evaluate wordseg for LAC and other open-source wordseg tools

@@ -20,7 +33,7 @@ def to_unicode(string):
def to_set(words):
    """ cut list to set of (string, off) """
    off = 0
    s = set()
    for w in words:
        if w:
            s.add((off, w))

@@ -145,7 +158,7 @@ def get_pkuseg_result(sentences):
    seg = pkuseg.pkuseg()
    preds = []
    for sentence in sentences:
        sent_seg = " ".join(seg.cut(sentence))
        sent_seg = to_unicode(sent_seg)
        preds.append(sent_seg)
    return preds

@@ -161,7 +174,8 @@ def get_hanlp_result(sentences):
    preds = []
    for sentence in sentences:
        arraylist = HanLP.segment(sentence)
        sent_seg = " ".join(
            [term.toString().split("/")[0] for term in arraylist])
        sent_seg = to_unicode(sent_seg)
        preds.append(sent_seg)
    return preds

...
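The `(off, w)` pairs produced by `to_set` make segmentation scoring a set intersection; a sketch of how such sets would feed an F1 score (the `off += len(w)` update and the F1 formula are assumptions, since the evaluation body is elided in this diff):

def to_set(words):
    """cut list to set of (off, word) pairs"""
    off, s = 0, set()
    for w in words:
        if w:
            s.add((off, w))
            off += len(w)  # assumed: advance offset by word length
    return s

def seg_f1(pred_words, gold_words):
    pred, gold = to_set(pred_words), to_set(gold_words)
    correct = len(pred & gold)
    p = correct / float(len(pred)) if pred else 0.0
    r = correct / float(len(gold)) if gold else 0.0
    return 2 * p * r / (p + r) if p + r else 0.0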
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
The file_reader converts raw corpus to input. The file_reader converts raw corpus to input.
""" """
...@@ -9,7 +22,10 @@ import glob ...@@ -9,7 +22,10 @@ import glob
def load_kv_dict(dict_path, def load_kv_dict(dict_path,
reverse=False, delimiter="\t", key_func=None, value_func=None): reverse=False,
delimiter="\t",
key_func=None,
value_func=None):
""" """
Load key-value dict from file Load key-value dict from file
""" """
...@@ -34,11 +50,14 @@ def load_kv_dict(dict_path, ...@@ -34,11 +50,14 @@ def load_kv_dict(dict_path,
class Dataset(object): class Dataset(object):
"""data reader""" """data reader"""
def __init__(self, args, mode="train"): def __init__(self, args, mode="train"):
# read dict # read dict
self.word2id_dict = load_kv_dict(args.word_dict_path, reverse=True, value_func=int) self.word2id_dict = load_kv_dict(
args.word_dict_path, reverse=True, value_func=int)
self.id2word_dict = load_kv_dict(args.word_dict_path) self.id2word_dict = load_kv_dict(args.word_dict_path)
self.label2id_dict = load_kv_dict(args.label_dict_path, reverse=True, value_func=int) self.label2id_dict = load_kv_dict(
args.label_dict_path, reverse=True, value_func=int)
self.id2label_dict = load_kv_dict(args.label_dict_path) self.id2label_dict = load_kv_dict(args.label_dict_path)
self.word_replace_dict = load_kv_dict(args.word_rep_dict_path) self.word_replace_dict = load_kv_dict(args.word_rep_dict_path)
...@@ -78,12 +97,12 @@ class Dataset(object): ...@@ -78,12 +97,12 @@ class Dataset(object):
label_ids.append(label_id) label_ids.append(label_id)
return label_ids return label_ids
def file_reader(self, filename, max_seq_len=64, mode="train"): def file_reader(self, filename, max_seq_len=64, mode="train"):
""" """
yield (word_idx, target_idx) one by one from file, yield (word_idx, target_idx) one by one from file,
or yield (word_idx, ) in `infer` mode or yield (word_idx, ) in `infer` mode
""" """
def wrapper(): def wrapper():
fread = io.open(filename, "r", encoding="utf-8") fread = io.open(filename, "r", encoding="utf-8")
headline = next(fread) headline = next(fread)
...@@ -93,9 +112,11 @@ class Dataset(object): ...@@ -93,9 +112,11 @@ class Dataset(object):
for line in fread: for line in fread:
words = line.strip("\n").split("\002") words = line.strip("\n").split("\002")
word_ids = self.word_to_ids(words) word_ids = self.word_to_ids(words)
yield word_ids[0:max_seq_len], [0 for _ in word_ids][0: max_seq_len] yield word_ids[0:max_seq_len], [0 for _ in word_ids][
0:max_seq_len]
else: else:
assert len(headline) == 2 and headline[0] == "text_a" and headline[1] == "label" assert len(headline) == 2 and headline[
0] == "text_a" and headline[1] == "label"
for line in fread: for line in fread:
words, labels = line.strip("\n").split("\t") words, labels = line.strip("\n").split("\t")
word_ids = self.word_to_ids(words.split("\002")) word_ids = self.word_to_ids(words.split("\002"))
...@@ -109,9 +130,21 @@ class Dataset(object): ...@@ -109,9 +130,21 @@ class Dataset(object):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser(__doc__) parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--word_dict_path", type=str, default="./conf/word.dic", help="word dict") parser.add_argument(
parser.add_argument("--label_dict_path", type=str, default="./conf/tag.dic", help="label dict") "--word_dict_path",
parser.add_argument("--word_rep_dict_path", type=str, default="./conf/q2b.dic", help="word replace dict") type=str,
default="./conf/word.dic",
help="word dict")
parser.add_argument(
"--label_dict_path",
type=str,
default="./conf/tag.dic",
help="label dict")
parser.add_argument(
"--word_rep_dict_path",
type=str,
default="./conf/q2b.dic",
help="word replace dict")
args = parser.parse_args() args = parser.parse_args()
dataset = Dataset(args) dataset = Dataset(args)
data_generator = dataset.file_reader("data/train.tsv") data_generator = dataset.file_reader("data/train.tsv")
......
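The body of `load_kv_dict` is elided above; a reimplementation sketch consistent with its signature and the call sites (`reverse=True, value_func=int` for the word2id and label2id directions) might be:

import io

def load_kv_dict(dict_path, reverse=False, delimiter="\t",
                 key_func=None, value_func=None):
    result = {}
    for line in io.open(dict_path, "r", encoding="utf-8"):
        terms = line.strip("\n").split(delimiter)
        if len(terms) != 2:
            continue
        # dict files are assumed to be "id<TAB>token" lines; reverse=True
        # builds token -> id instead of id -> token
        if reverse:
            value, key = terms
        else:
            key, value = terms
        if key_func:
            key = key_func(key)
        if value_func:
            value = value_func(value)
        result[key] = value
    return result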
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Sentiment Classification Task Sentiment Classification Task
""" """
...@@ -28,7 +41,6 @@ from models.representation.ernie import ernie_encoder ...@@ -28,7 +41,6 @@ from models.representation.ernie import ernie_encoder
from models.sequence_labeling import nets from models.sequence_labeling import nets
import utils import utils
# yapf: disable # yapf: disable
parser = argparse.ArgumentParser(__doc__) parser = argparse.ArgumentParser(__doc__)
model_g = utils.ArgumentGroup(parser, "model", "model configuration and paths.") model_g = utils.ArgumentGroup(parser, "model", "model configuration and paths.")
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Baidu's open-source Lexical Analysis tool for Chinese, including: Baidu's open-source Lexical Analysis tool for Chinese, including:
1. Word Segmentation, 1. Word Segmentation,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
util tools util tools
""" """
...@@ -19,6 +32,7 @@ class ArgumentGroup(object): ...@@ -19,6 +32,7 @@ class ArgumentGroup(object):
""" """
Put arguments to one group Put arguments to one group
""" """
def __init__(self, parser, title, des): def __init__(self, parser, title, des):
"""none""" """none"""
self._group = parser.add_argument_group(title=title, description=des) self._group = parser.add_argument_group(title=title, description=des)
...@@ -86,7 +100,7 @@ def parse_result(words, crf_decode, dataset): ...@@ -86,7 +100,7 @@ def parse_result(words, crf_decode, dataset):
sent_len = offset_list[sent_index + 1] - offset_list[sent_index] sent_len = offset_list[sent_index + 1] - offset_list[sent_index]
last_word = "" last_word = ""
last_tag = "" last_tag = ""
for tag_index in range(sent_len): # iterate every word in sent for tag_index in range(sent_len): # iterate every word in sent
index = tag_index + offset_list[sent_index] index = tag_index + offset_list[sent_index]
cur_word_id = str(words[index][0]) cur_word_id = str(words[index][0])
cur_tag_id = str(crf_decode[index][0]) cur_tag_id = str(crf_decode[index][0])
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
This module provide nets for text classification This module provide nets for text classification
""" """
import paddle.fluid as fluid import paddle.fluid as fluid
def bow_net(data, def bow_net(data,
label, label,
dict_dim, dict_dim,
...@@ -192,14 +206,14 @@ def gru_net(data, ...@@ -192,14 +206,14 @@ def gru_net(data,
def textcnn_net(data, def textcnn_net(data,
label, label,
dict_dim, dict_dim,
emb_dim=128, emb_dim=128,
hid_dim=128, hid_dim=128,
hid_dim2=96, hid_dim2=96,
class_dim=2, class_dim=2,
win_sizes=None, win_sizes=None,
is_infer=False): is_infer=False):
""" """
Textcnn_net Textcnn_net
""" """
......
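The body of `textcnn_net` is elided here; judging from the signature it is a standard multi-window TextCNN. A minimal sketch under that assumption, reusing the default sizes above (the layer wiring is an illustration, not the repo's exact implementation):

import paddle.fluid as fluid

def textcnn_sketch(data, dict_dim, emb_dim=128, hid_dim=128, hid_dim2=96,
                   class_dim=2, win_sizes=(1, 2, 3)):
    emb = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])
    # one sequence conv + max-pool per window size, concatenated
    convs = [
        fluid.nets.sequence_conv_pool(
            input=emb, num_filters=hid_dim, filter_size=w,
            act="tanh", pool_type="max") for w in win_sizes
    ]
    convs_out = fluid.layers.concat(input=convs, axis=1)
    fc = fluid.layers.fc(input=convs_out, size=hid_dim2, act="tanh")
    return fluid.layers.fc(input=fc, size=class_dim, act="softmax")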
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
bow class bow class
""" """
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
cnn class cnn class
""" """
...@@ -30,8 +43,8 @@ class CNN(object): ...@@ -30,8 +43,8 @@ class CNN(object):
left_emb = emb_layer.ops(left) left_emb = emb_layer.ops(left)
right_emb = emb_layer.ops(right) right_emb = emb_layer.ops(right)
# Presentation context # Presentation context
cnn_layer = layers.SequenceConvPoolLayer( cnn_layer = layers.SequenceConvPoolLayer(self.filter_size,
self.filter_size, self.num_filters, "conv") self.num_filters, "conv")
left_cnn = cnn_layer.ops(left_emb) left_cnn = cnn_layer.ops(left_emb)
right_cnn = cnn_layer.ops(right_emb) right_cnn = cnn_layer.ops(right_emb)
# matching layer # matching layer
......
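Read together with the layer library later in this diff, `CNN.ops` is plain composition of reusable layer objects. A sketch of the same wiring with hypothetical sizes (the EmbeddingLayer constructor arguments are assumptions, since only its `ops` body appears here):

# SequenceConvPoolLayer takes (filter_size, num_filters, name)
emb_layer = layers.EmbeddingLayer(dict_size=10000, emb_dim=128, name="emb")
cnn_layer = layers.SequenceConvPoolLayer(3, 256, "conv")
left_repr = cnn_layer.ops(emb_layer.ops(left))    # shared weights on both sides
right_repr = cnn_layer.ops(emb_layer.ops(right))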
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
gru class gru class
""" """
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
hinge loss hinge loss
""" """
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
log loss log loss
""" """
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
softmax loss softmax loss
""" """
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
lstm class lstm class
""" """
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
MMDNN class MMDNN class
""" """
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
optimizer calss optimizer calss
""" """
...@@ -43,5 +56,8 @@ class AdamOptimizer(object): ...@@ -43,5 +56,8 @@ class AdamOptimizer(object):
Adam optimizer operation Adam optimizer operation
""" """
adam = fluid.optimizer.AdamOptimizer( adam = fluid.optimizer.AdamOptimizer(
self.learning_rate, beta1=self.beta1, beta2=self.beta2, epsilon=self.epsilon) self.learning_rate,
beta1=self.beta1,
beta2=self.beta2,
epsilon=self.epsilon)
adam.minimize(loss) adam.minimize(loss)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
network layers network layers
""" """
...@@ -23,9 +36,11 @@ class EmbeddingLayer(object): ...@@ -23,9 +36,11 @@ class EmbeddingLayer(object):
""" """
operation operation
""" """
emb = fluid.layers.embedding(input=input, size=[ emb = fluid.layers.embedding(
self.dict_size, self.emb_dim], is_sparse=True, input=input,
param_attr=attr.ParamAttr(name=self.name)) size=[self.dict_size, self.emb_dim],
is_sparse=True,
param_attr=attr.ParamAttr(name=self.name))
return emb return emb
...@@ -44,8 +59,7 @@ class SequencePoolLayer(object): ...@@ -44,8 +59,7 @@ class SequencePoolLayer(object):
""" """
operation operation
""" """
pool = fluid.layers.sequence_pool( pool = fluid.layers.sequence_pool(input=input, pool_type=self.pool_type)
input=input, pool_type=self.pool_type)
return pool return pool
...@@ -66,9 +80,12 @@ class FCLayer(object): ...@@ -66,9 +80,12 @@ class FCLayer(object):
""" """
operation operation
""" """
fc = fluid.layers.fc(input=input, size=self.fc_dim, param_attr=attr.ParamAttr( fc = fluid.layers.fc(input=input,
name="%s.w" % self.name), size=self.fc_dim,
bias_attr=attr.ParamAttr(name="%s.b" % self.name), act=self.act, name=self.name) param_attr=attr.ParamAttr(name="%s.w" % self.name),
bias_attr=attr.ParamAttr(name="%s.b" % self.name),
act=self.act,
name=self.name)
return fc return fc
...@@ -88,12 +105,16 @@ class DynamicGRULayer(object): ...@@ -88,12 +105,16 @@ class DynamicGRULayer(object):
""" """
operation operation
""" """
proj = fluid.layers.fc(input=input, size=self.gru_dim * 3, proj = fluid.layers.fc(
param_attr=attr.ParamAttr(name="%s_fc.w" % self.name), input=input,
bias_attr=attr.ParamAttr(name="%s_fc.b" % self.name)) size=self.gru_dim * 3,
gru = fluid.layers.dynamic_gru(input=proj, size=self.gru_dim, param_attr=attr.ParamAttr(name="%s_fc.w" % self.name),
param_attr=attr.ParamAttr(name="%s.w" % self.name), bias_attr=attr.ParamAttr(name="%s_fc.b" % self.name))
bias_attr=attr.ParamAttr(name="%s.b" % self.name)) gru = fluid.layers.dynamic_gru(
input=proj,
size=self.gru_dim,
param_attr=attr.ParamAttr(name="%s.w" % self.name),
bias_attr=attr.ParamAttr(name="%s.b" % self.name))
return gru return gru
...@@ -113,12 +134,16 @@ class DynamicLSTMLayer(object): ...@@ -113,12 +134,16 @@ class DynamicLSTMLayer(object):
""" """
operation operation
""" """
proj = fluid.layers.fc(input=input, size=self.lstm_dim * 4, proj = fluid.layers.fc(
param_attr=attr.ParamAttr(name="%s_fc.w" % self.name), input=input,
bias_attr=attr.ParamAttr(name="%s_fc.b" % self.name)) size=self.lstm_dim * 4,
lstm, _ = fluid.layers.dynamic_lstm(input=proj, size=self.lstm_dim * 4, param_attr=attr.ParamAttr(name="%s_fc.w" % self.name),
param_attr=attr.ParamAttr(name="%s.w" % self.name), bias_attr=attr.ParamAttr(name="%s_fc.b" % self.name))
bias_attr=attr.ParamAttr(name="%s.b" % self.name)) lstm, _ = fluid.layers.dynamic_lstm(
input=proj,
size=self.lstm_dim * 4,
param_attr=attr.ParamAttr(name="%s.w" % self.name),
bias_attr=attr.ParamAttr(name="%s.b" % self.name))
return lstm return lstm
...@@ -161,9 +186,12 @@ class SequenceConvPoolLayer(object): ...@@ -161,9 +186,12 @@ class SequenceConvPoolLayer(object):
""" """
operation operation
""" """
conv = fluid.nets.sequence_conv_pool(input=input, filter_size=self.filter_size, conv = fluid.nets.sequence_conv_pool(
num_filters=self.num_filters, input=input,
param_attr=attr.ParamAttr(name=self.name), act="relu") filter_size=self.filter_size,
num_filters=self.num_filters,
param_attr=attr.ParamAttr(name=self.name),
act="relu")
return conv return conv
...@@ -259,7 +287,8 @@ class SoftmaxWithCrossEntropyLayer(object): ...@@ -259,7 +287,8 @@ class SoftmaxWithCrossEntropyLayer(object):
""" """
operation operation
""" """
loss = fluid.layers.softmax_with_cross_entropy(logits=input, label=label) loss = fluid.layers.softmax_with_cross_entropy(
logits=input, label=label)
return loss return loss
...@@ -354,8 +383,8 @@ class ConstantLayer(object): ...@@ -354,8 +383,8 @@ class ConstantLayer(object):
""" """
operation operation
""" """
constant = fluid.layers.fill_constant_batch_size_like( constant = fluid.layers.fill_constant_batch_size_like(input, shape,
input, shape, dtype, value) dtype, value)
return constant return constant
...@@ -396,6 +425,7 @@ class SoftsignLayer(object): ...@@ -396,6 +425,7 @@ class SoftsignLayer(object):
softsign = fluid.layers.softsign(input) softsign = fluid.layers.softsign(input)
return softsign return softsign
# class MatmulLayer(object): # class MatmulLayer(object):
# def __init__(self, transpose_x, transpose_y): # def __init__(self, transpose_x, transpose_y):
# self.transpose_x = transpose_x # self.transpose_x = transpose_x
...@@ -405,7 +435,6 @@ class SoftsignLayer(object): ...@@ -405,7 +435,6 @@ class SoftsignLayer(object):
# matmul = fluid.layers.matmul(x, y, self.transpose_x, self.transpose_y) # matmul = fluid.layers.matmul(x, y, self.transpose_x, self.transpose_y)
# return matmul # return matmul
# class Conv2dLayer(object): # class Conv2dLayer(object):
# def __init__(self, num_filters, filter_size, act, name): # def __init__(self, num_filters, filter_size, act, name):
# self.num_filters = num_filters # self.num_filters = num_filters
...@@ -417,7 +446,6 @@ class SoftsignLayer(object): ...@@ -417,7 +446,6 @@ class SoftsignLayer(object):
# conv = fluid.layers.conv2d(input, self.num_filters, self.filter_size, param_attr=attr.ParamAttr(name="%s.w" % self.name), bias_attr=attr.ParamAttr(name="%s.b" % self.name), act=self.act) # conv = fluid.layers.conv2d(input, self.num_filters, self.filter_size, param_attr=attr.ParamAttr(name="%s.w" % self.name), bias_attr=attr.ParamAttr(name="%s.b" % self.name), act=self.act)
# return conv # return conv
# class Pool2dLayer(object): # class Pool2dLayer(object):
# def __init__(self, pool_size, pool_type): # def __init__(self, pool_size, pool_type):
# self.pool_size = pool_size # self.pool_size = pool_size
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# The placeholder for batch_size in compile time. Must be -1 currently to be
# consistent with some ops' infer-shape output in compile time, such as the
# sequence_expand op used in beamsearch decoder.
...
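A small illustration of the placeholder described in that comment (variable names hypothetical): declaring a data layer with a -1 leading dimension and `append_batch_size=False` keeps compile-time infer-shape consistent, while the real batch size arrives only when data is fed.

import paddle.fluid as fluid

batch_size = -1  # compile-time placeholder, resolved at run time
src_word = fluid.layers.data(
    name="src_word",
    shape=[batch_size, 1],  # -1 lets infer-shape ops accept any batch
    dtype="int64",
    append_batch_size=False)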
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import numpy as np
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
This module provides ErnieModel and ErnieConfig This module provides ErnieModel and ErnieConfig
""" """
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
The function lex_net(args) define the lexical analysis network structure The function lex_net(args) define the lexical analysis network structure
""" """
...@@ -96,8 +109,7 @@ def lex_net(word, target, args, vocab_size, num_labels): ...@@ -96,8 +109,7 @@ def lex_net(word, target, args, vocab_size, num_labels):
input=emission, input=emission,
label=target, label=target,
param_attr=fluid.ParamAttr( param_attr=fluid.ParamAttr(
name='crfw', name='crfw', learning_rate=crf_lr))
learning_rate=crf_lr))
crf_decode = fluid.layers.crf_decoding( crf_decode = fluid.layers.crf_decoding(
input=emission, param_attr=fluid.ParamAttr(name='crfw')) input=emission, param_attr=fluid.ParamAttr(name='crfw'))
avg_cost = fluid.layers.mean(x=crf_cost) avg_cost = fluid.layers.mean(x=crf_cost)
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer encoder.""" """Transformer encoder."""
from __future__ import absolute_import from __future__ import absolute_import
...@@ -100,7 +113,7 @@ def multi_head_attention(queries, ...@@ -100,7 +113,7 @@ def multi_head_attention(queries,
""" """
Scaled Dot-Product Attention Scaled Dot-Product Attention
""" """
scaled_q = layers.scale(x=q, scale=d_key ** -0.5) scaled_q = layers.scale(x=q, scale=d_key**-0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True) product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
if attn_bias: if attn_bias:
product += attn_bias product += attn_bias
......
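The scaling above is the standard scaled dot-product attention, softmax(Q K^T / sqrt(d_key)) V; a numpy sketch of the same computation, with dropout and multi-head reshaping omitted:

import numpy as np

def scaled_dot_product_attention(q, k, v, d_key, attn_bias=None):
    # scale queries by d_key ** -0.5, as in the encoder code above
    logits = np.matmul(q * d_key ** -0.5, np.swapaxes(k, -1, -2))
    if attn_bias is not None:
        logits += attn_bias  # e.g. large negative values at padded positions
    weights = np.exp(logits - logits.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return np.matmul(weights, v)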
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
This module provides reader for classification and sequence labing This module provides reader for classification and sequence labing
""" """
...@@ -18,6 +31,7 @@ from preprocess.padding import pad_batch_data ...@@ -18,6 +31,7 @@ from preprocess.padding import pad_batch_data
class BaseReader(object): class BaseReader(object):
"""BaseReader for classify and sequence labeling task""" """BaseReader for classify and sequence labeling task"""
def __init__(self, def __init__(self,
vocab_path, vocab_path,
label_map_config=None, label_map_config=None,
...@@ -211,6 +225,7 @@ class BaseReader(object): ...@@ -211,6 +225,7 @@ class BaseReader(object):
class ClassifyReader(BaseReader): class ClassifyReader(BaseReader):
"""ClassifyReader""" """ClassifyReader"""
def _read_tsv(self, input_file, quotechar=None): def _read_tsv(self, input_file, quotechar=None):
"""Reads a tab separated value file.""" """Reads a tab separated value file."""
with open(input_file, "r") as f: with open(input_file, "r") as f:
...@@ -239,7 +254,10 @@ class ClassifyReader(BaseReader): ...@@ -239,7 +254,10 @@ class ClassifyReader(BaseReader):
# padding # padding
padded_token_ids, input_mask, seq_lens = pad_batch_data( padded_token_ids, input_mask, seq_lens = pad_batch_data(
batch_token_ids, pad_idx=self.pad_id, return_input_mask=True, return_seq_lens=True) batch_token_ids,
pad_idx=self.pad_id,
return_input_mask=True,
return_seq_lens=True)
padded_text_type_ids = pad_batch_data( padded_text_type_ids = pad_batch_data(
batch_text_type_ids, pad_idx=self.pad_id) batch_text_type_ids, pad_idx=self.pad_id)
padded_position_ids = pad_batch_data( padded_position_ids = pad_batch_data(
...@@ -255,6 +273,7 @@ class ClassifyReader(BaseReader): ...@@ -255,6 +273,7 @@ class ClassifyReader(BaseReader):
class SequenceLabelReader(BaseReader): class SequenceLabelReader(BaseReader):
"""SequenceLabelReader""" """SequenceLabelReader"""
def _pad_batch_records(self, batch_records): def _pad_batch_records(self, batch_records):
batch_token_ids = [record.token_ids for record in batch_records] batch_token_ids = [record.token_ids for record in batch_records]
batch_text_type_ids = [record.text_type_ids for record in batch_records] batch_text_type_ids = [record.text_type_ids for record in batch_records]
...@@ -314,7 +333,9 @@ class SequenceLabelReader(BaseReader): ...@@ -314,7 +333,9 @@ class SequenceLabelReader(BaseReader):
position_ids = list(range(len(token_ids))) position_ids = list(range(len(token_ids)))
text_type_ids = [0] * len(token_ids) text_type_ids = [0] * len(token_ids)
no_entity_id = len(self.label_map) - 1 no_entity_id = len(self.label_map) - 1
labels = [label if label in self.label_map else u"O" for label in labels] labels = [
label if label in self.label_map else u"O" for label in labels
]
label_ids = [no_entity_id] + [ label_ids = [no_entity_id] + [
self.label_map[label] for label in labels self.label_map[label] for label in labels
] + [no_entity_id] ] + [no_entity_id]
...@@ -332,6 +353,7 @@ class SequenceLabelReader(BaseReader): ...@@ -332,6 +353,7 @@ class SequenceLabelReader(BaseReader):
class ExtractEmbeddingReader(BaseReader): class ExtractEmbeddingReader(BaseReader):
"""ExtractEmbeddingReader""" """ExtractEmbeddingReader"""
def _pad_batch_records(self, batch_records): def _pad_batch_records(self, batch_records):
batch_token_ids = [record.token_ids for record in batch_records] batch_token_ids = [record.token_ids for record in batch_records]
batch_text_type_ids = [record.text_type_ids for record in batch_records] batch_text_type_ids = [record.text_type_ids for record in batch_records]
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Mask, padding and batching. Mask, padding and batching.
""" """
......
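The `pad_batch_data` helper imported by the readers above is elided here; judging from its call sites (`pad_idx`, `return_input_mask`, `return_seq_lens`), it behaves roughly like this numpy sketch:

import numpy as np

def pad_batch_data(insts, pad_idx=0, return_input_mask=False,
                   return_seq_lens=False):
    # pad every instance to the batch max length with pad_idx
    max_len = max(len(inst) for inst in insts)
    padded = np.array(
        [inst + [pad_idx] * (max_len - len(inst)) for inst in insts],
        dtype="int64").reshape([-1, max_len, 1])
    outs = [padded]
    if return_input_mask:
        # 1.0 over real tokens, 0.0 over padding
        mask = np.array(
            [[1.0] * len(inst) + [0.0] * (max_len - len(inst))
             for inst in insts],
            dtype="float32").reshape([-1, max_len, 1])
        outs.append(mask)
    if return_seq_lens:
        outs.append(np.array([len(inst) for inst in insts], dtype="int64"))
    return outs if len(outs) > 1 else outs[0]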
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
The file_reader converts raw corpus to input. The file_reader converts raw corpus to input.
""" """
...@@ -5,6 +18,7 @@ import os ...@@ -5,6 +18,7 @@ import os
import __future__ import __future__
import io import io
def file_reader(file_dir, def file_reader(file_dir,
word2id_dict, word2id_dict,
label2id_dict, label2id_dict,
...@@ -15,6 +29,7 @@ def file_reader(file_dir, ...@@ -15,6 +29,7 @@ def file_reader(file_dir,
""" """
word_dict_len = max(map(int, word2id_dict.values())) + 1 word_dict_len = max(map(int, word2id_dict.values())) + 1
label_dict_len = max(map(int, label2id_dict.values())) + 1 label_dict_len = max(map(int, label2id_dict.values())) + 1
def reader(): def reader():
""" """
the data generator the data generator
...@@ -24,7 +39,8 @@ def file_reader(file_dir, ...@@ -24,7 +39,8 @@ def file_reader(file_dir,
for filename in files: for filename in files:
if not filename.startswith(filename_feature): if not filename.startswith(filename_feature):
continue continue
for line in io.open(os.path.join(root, filename), 'r', encoding='utf8'): for line in io.open(
os.path.join(root, filename), 'r', encoding='utf8'):
index += 1 index += 1
bad_line = False bad_line = False
line = line.strip("\n") line = line.strip("\n")
...@@ -52,8 +68,9 @@ def file_reader(file_dir, ...@@ -52,8 +68,9 @@ def file_reader(file_dir,
else: else:
target_idx.append(int(label2id_dict["O"])) target_idx.append(int(label2id_dict["O"]))
if len(word_idx) != len(target_idx): if len(word_idx) != len(target_idx):
continue continue
yield word_idx, target_idx yield word_idx, target_idx
return reader return reader
...@@ -68,6 +85,7 @@ def test_reader(file_dir, ...@@ -68,6 +85,7 @@ def test_reader(file_dir,
#print (word2id_dict) #print (word2id_dict)
word_dict_len = max(map(int, word2id_dict.values())) + 1 word_dict_len = max(map(int, word2id_dict.values())) + 1
label_dict_len = max(map(int, label2id_dict.values())) + 1 label_dict_len = max(map(int, label2id_dict.values())) + 1
#print word_dict_len #print word_dict_len
#print label_dict_len #print label_dict_len
def reader(): def reader():
...@@ -94,6 +112,7 @@ def test_reader(file_dir, ...@@ -94,6 +112,7 @@ def test_reader(file_dir,
else: else:
word_idx.append(int(word2id_dict["OOV"])) word_idx.append(int(word2id_dict["OOV"]))
yield word_idx, words yield word_idx, words
return reader return reader
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
This module provides wordseg tools This module provides wordseg tools
""" """
...@@ -11,12 +24,13 @@ import time ...@@ -11,12 +24,13 @@ import time
import sys import sys
import io import io
if sys.version_info > (3,): if sys.version_info > (3, ):
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8') sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
else: else:
reload(sys) reload(sys)
sys.setdefaultencoding("utf8") sys.setdefaultencoding("utf8")
def parse_args(): def parse_args():
""" """
Arguments Parse Arguments Parse
...@@ -26,32 +40,27 @@ def parse_args(): ...@@ -26,32 +40,27 @@ def parse_args():
'--batch_size', '--batch_size',
type=int, type=int,
default=5, default=5,
help='The size of a batch. (default: %(default)d)' help='The size of a batch. (default: %(default)d)')
)
parser.add_argument( parser.add_argument(
'--model_path', '--model_path',
type=str, type=str,
default='./conf/model', default='./conf/model',
help='A path to the model. (default: %(default)s)' help='A path to the model. (default: %(default)s)')
)
parser.add_argument( parser.add_argument(
'--test_data_dir', '--test_data_dir',
type=str, type=str,
default='./data/test_data', default='./data/test_data',
help='A directory with test data files. (default: %(default)s)' help='A directory with test data files. (default: %(default)s)')
)
parser.add_argument( parser.add_argument(
"--word_dict_path", "--word_dict_path",
type=str, type=str,
default="./conf/word.dic", default="./conf/word.dic",
help="The path of the word dictionary. (default: %(default)s)" help="The path of the word dictionary. (default: %(default)s)")
)
parser.add_argument( parser.add_argument(
"--label_dict_path", "--label_dict_path",
type=str, type=str,
default="./conf/tag.dic", default="./conf/tag.dic",
help="The path of the label dictionary. (default: %(default)s)" help="The path of the label dictionary. (default: %(default)s)")
)
parser.add_argument( parser.add_argument(
"--word_rep_dict_path", "--word_rep_dict_path",
type=str, type=str,
...@@ -104,17 +113,15 @@ def infer(args): ...@@ -104,17 +113,15 @@ def infer(args):
Tokenize Tokenize
""" """
id2word_dict = reader.load_dict(args.word_dict_path) id2word_dict = reader.load_dict(args.word_dict_path)
word2id_dict = reader.load_reverse_dict(args.word_dict_path) word2id_dict = reader.load_reverse_dict(args.word_dict_path)
id2label_dict = reader.load_dict(args.label_dict_path) id2label_dict = reader.load_dict(args.label_dict_path)
label2id_dict = reader.load_reverse_dict(args.label_dict_path) label2id_dict = reader.load_reverse_dict(args.label_dict_path)
q2b_dict = reader.load_dict(args.word_rep_dict_path) q2b_dict = reader.load_dict(args.word_rep_dict_path)
test_data = paddle.batch( test_data = paddle.batch(
reader.test_reader(args.test_data_dir, reader.test_reader(args.test_data_dir, word2id_dict, label2id_dict,
word2id_dict, q2b_dict),
label2id_dict, batch_size=args.batch_size)
q2b_dict),
batch_size = args.batch_size)
place = fluid.CPUPlace() place = fluid.CPUPlace()
#place = fluid.CUDAPlace(0) #place = fluid.CUDAPlace(0)
exe = fluid.Executor(place) exe = fluid.Executor(place)
...@@ -130,9 +137,9 @@ def infer(args): ...@@ -130,9 +137,9 @@ def infer(args):
#print(word_idx) #print(word_idx)
word_list = [x[1] for x in data] word_list = [x[1] for x in data]
(crf_decode, ) = exe.run(inference_program, (crf_decode, ) = exe.run(inference_program,
feed={"word":word_idx}, feed={"word": word_idx},
fetch_list=fetch_targets, fetch_list=fetch_targets,
return_numpy=False) return_numpy=False)
lod_info = (crf_decode.lod())[0] lod_info = (crf_decode.lod())[0]
np_data = np.array(crf_decode) np_data = np.array(crf_decode)
assert len(data) == len(lod_info) - 1 assert len(data) == len(lod_info) - 1
...@@ -145,7 +152,7 @@ def infer(args): ...@@ -145,7 +152,7 @@ def infer(args):
cur_full_tag = "" cur_full_tag = ""
words = word_list[sen_index] words = word_list[sen_index]
for tag_index in range(lod_info[sen_index], for tag_index in range(lod_info[sen_index],
lod_info[sen_index + 1]): lod_info[sen_index + 1]):
cur_word = words[word_index] cur_word = words[word_index]
cur_tag = id2label_dict[str(np_data[tag_index][0])] cur_tag = id2label_dict[str(np_data[tag_index][0])]
if cur_tag.endswith("-B") or cur_tag.endswith("O"): if cur_tag.endswith("-B") or cur_tag.endswith("O"):
......
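
The decode loop above indexes the CRF output through its LoD (level of detail) offsets: lod_info[i]:lod_info[i + 1] delimits the tag ids of the i-th sentence in the batch, and np_data holds one tag id per token. A self-contained sketch of that indexing with made-up offsets:

import numpy as np

# Hypothetical decode result: two sentences of lengths 3 and 2.
lod_info = [0, 3, 5]  # level-0 LoD offsets
np_data = np.array([[1], [2], [0], [1], [0]])  # one tag id per token

for sen_index in range(len(lod_info) - 1):
    tag_ids = [int(np_data[t][0])
               for t in range(lod_info[sen_index], lod_info[sen_index + 1])]
    print("sentence", sen_index, "tags:", tag_ids)
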
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Senta model. Senta model.
""" """
...@@ -11,10 +24,12 @@ import json ...@@ -11,10 +24,12 @@ import json
import numpy as np import numpy as np
import paddle.fluid as fluid import paddle.fluid as fluid
class SentaConfig(object): class SentaConfig(object):
""" """
Senta Config Senta Config
""" """
def __init__(self, config_path): def __init__(self, config_path):
self._config_dict = self._parse(config_path) self._config_dict = self._parse(config_path)
...@@ -24,7 +39,7 @@ class SentaConfig(object): ...@@ -24,7 +39,7 @@ class SentaConfig(object):
config_dict = json.load(json_file) config_dict = json.load(json_file)
except Exception: except Exception:
raise IOError("Error in parsing bert model config file '%s'" % raise IOError("Error in parsing bert model config file '%s'" %
config_path) config_path)
else: else:
return config_dict return config_dict
......
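
SentaConfig._parse relies on try/except/else: the else branch returns the dict only when json.load raised nothing, so any parse failure surfaces as a uniform IOError. A standalone sketch of the same pattern:

import json


def parse_config(config_path):
    # Same try/except/else shape as SentaConfig._parse above.
    try:
        with open(config_path) as json_file:
            config_dict = json.load(json_file)
    except Exception:
        raise IOError("Error in parsing model config file '%s'" % config_path)
    else:
        return config_dict
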
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Senta Reader Senta Reader
""" """
...@@ -12,15 +25,13 @@ from utils import data_reader ...@@ -12,15 +25,13 @@ from utils import data_reader
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
class SentaProcessor(object): class SentaProcessor(object):
""" """
Processor class for data convertors for senta Processor class for data convertors for senta
""" """
def __init__(self, def __init__(self, data_dir, vocab_path, random_seed=None):
data_dir,
vocab_path,
random_seed=None):
self.data_dir = data_dir self.data_dir = data_dir
self.vocab = load_vocab(vocab_path) self.vocab = load_vocab(vocab_path)
self.num_examples = {"train": -1, "dev": -1, "infer": -1} self.num_examples = {"train": -1, "dev": -1, "infer": -1}
...@@ -30,19 +41,22 @@ class SentaProcessor(object): ...@@ -30,19 +41,22 @@ class SentaProcessor(object):
""" """
Load training examples Load training examples
""" """
return data_reader((self.data_dir + "/train.tsv"), self.vocab, self.num_examples, "train", epoch) return data_reader((self.data_dir + "/train.tsv"), self.vocab,
self.num_examples, "train", epoch)
def get_dev_examples(self, data_dir, epoch): def get_dev_examples(self, data_dir, epoch):
""" """
Load dev examples Load dev examples
""" """
return data_reader((self.data_dir + "/dev.tsv"), self.vocab, self.num_examples, "dev", epoch) return data_reader((self.data_dir + "/dev.tsv"), self.vocab,
self.num_examples, "dev", epoch)
def get_test_examples(self, data_dir, epoch): def get_test_examples(self, data_dir, epoch):
""" """
Load test examples Load test examples
""" """
return data_reader((self.data_dir + "/test.tsv"), self.vocab, self.num_examples, "infer", epoch) return data_reader((self.data_dir + "/test.tsv"), self.vocab,
self.num_examples, "infer", epoch)
def get_labels(self): def get_labels(self):
""" """
...@@ -70,11 +84,14 @@ class SentaProcessor(object): ...@@ -70,11 +84,14 @@ class SentaProcessor(object):
Generate data for train, dev or infer Generate data for train, dev or infer
""" """
if phase == "train": if phase == "train":
return paddle.batch(self.get_train_examples(self.data_dir, epoch), batch_size) return paddle.batch(
self.get_train_examples(self.data_dir, epoch), batch_size)
elif phase == "dev": elif phase == "dev":
return paddle.batch(self.get_dev_examples(self.data_dir, epoch), batch_size) return paddle.batch(
self.get_dev_examples(self.data_dir, epoch), batch_size)
elif phase == "infer": elif phase == "infer":
return paddle.batch(self.get_test_examples(self.data_dir, epoch), batch_size) return paddle.batch(
self.get_test_examples(self.data_dir, epoch), batch_size)
else: else:
raise ValueError( raise ValueError(
"Unknown phase, which should be in ['train', 'dev', 'infer'].") "Unknown phase, which should be in ['train', 'dev', 'infer'].")
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Sentiment Classification Task Sentiment Classification Task
""" """
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Sentiment Classification Task Sentiment Classification Task
""" """
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Arguments for configuration Arguments for configuration
""" """
...@@ -31,6 +44,7 @@ class ArgumentGroup(object): ...@@ -31,6 +44,7 @@ class ArgumentGroup(object):
""" """
Argument Class Argument Class
""" """
def __init__(self, parser, title, des): def __init__(self, parser, title, des):
self._group = parser.add_argument_group(title=title, description=des) self._group = parser.add_argument_group(title=title, description=des)
...@@ -79,7 +93,7 @@ def init_checkpoint(exe, init_checkpoint_path, main_program): ...@@ -79,7 +93,7 @@ def init_checkpoint(exe, init_checkpoint_path, main_program):
predicate=existed_persitables) predicate=existed_persitables)
print("Load model from {}".format(init_checkpoint_path)) print("Load model from {}".format(init_checkpoint_path))
def data_reader(file_path, word_dict, num_examples, phrase, epoch): def data_reader(file_path, word_dict, num_examples, phrase, epoch):
""" """
Convert word sequence into slot Convert word sequence into slot
...@@ -95,15 +109,17 @@ def data_reader(file_path, word_dict, num_examples, phrase, epoch): ...@@ -95,15 +109,17 @@ def data_reader(file_path, word_dict, num_examples, phrase, epoch):
sys.stderr.write("[NOTICE] Error Format Line!") sys.stderr.write("[NOTICE] Error Format Line!")
continue continue
label = int(cols[1]) label = int(cols[1])
wids = [word_dict[x] if x in word_dict else unk_id wids = [
for x in cols[0].split(" ")] word_dict[x] if x in word_dict else unk_id
for x in cols[0].split(" ")
]
all_data.append((wids, label)) all_data.append((wids, label))
if phrase == "train": if phrase == "train":
random.shuffle(all_data) random.shuffle(all_data)
num_examples[phrase] = len(all_data) num_examples[phrase] = len(all_data)
def reader(): def reader():
""" """
Reader Function Reader Function
...@@ -111,8 +127,10 @@ def data_reader(file_path, word_dict, num_examples, phrase, epoch): ...@@ -111,8 +127,10 @@ def data_reader(file_path, word_dict, num_examples, phrase, epoch):
for epoch_index in range(epoch): for epoch_index in range(epoch):
for doc, label in all_data: for doc, label in all_data:
yield doc, label yield doc, label
return reader return reader
def load_vocab(file_path): def load_vocab(file_path):
""" """
load the given vocabulary load the given vocabulary
......
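
The comprehension in data_reader maps each token to its vocabulary id, falling back to unk_id for out-of-vocabulary words. A self-contained sketch of that conversion on a toy "text\tlabel" line:

# Toy vocabulary and one line, mirroring the data_reader logic above.
word_dict = {"i": 0, "love": 1, "paddle": 2}
unk_id = len(word_dict)
cols = "i love paddle very much\t1".split("\t")

label = int(cols[1])
wids = [
    word_dict[x] if x in word_dict else unk_id for x in cols[0].split(" ")
]
print(wids, label)  # [0, 1, 2, 3, 3] 1
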
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
SimNet config SimNet config
""" """
...@@ -21,12 +34,14 @@ class SimNetConfig(object): ...@@ -21,12 +34,14 @@ class SimNetConfig(object):
with open(config_path) as json_file: with open(config_path) as json_file:
config_dict = json.load(json_file) config_dict = json.load(json_file)
except Exception: except Exception:
raise IOError("Error in parsing simnet model config file '%s'" % config_path) raise IOError("Error in parsing simnet model config file '%s'" %
config_path)
else: else:
if config_dict["task_mode"] != self.task_mode: if config_dict["task_mode"] != self.task_mode:
raise ValueError( raise ValueError(
"the config '{}' does not match the task_mode '{}'".format(self.config_path, self.task_mode)) "the config '{}' does not match the task_mode '{}'".format(
self.config_path, self.task_mode))
return config_dict return config_dict
def __getitem__(self, key): def __getitem__(self, key):
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
comput unicom comput unicom
""" """
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
split unicom file split unicom file
""" """
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
SimNet reader SimNet reader
""" """
...@@ -25,15 +38,24 @@ class SimNetProcessor(object): ...@@ -25,15 +38,24 @@ class SimNetProcessor(object):
Reader with Pairwise Reader with Pairwise
""" """
if mode == "valid": if mode == "valid":
with codecs.open(self.args.valid_data_dir, "r", "utf-8") as file: with codecs.open(self.args.valid_data_dir, "r",
"utf-8") as file:
for line in file: for line in file:
query, title, label = line.strip().split("\t") query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len(label) == 0 or not label.isdigit() or int( if len(query) == 0 or len(title) == 0 or len(
label) not in [0, 1]: label) == 0 or not label.isdigit() or int(
logging.warning("line not match format in test file") label) not in [0, 1]:
logging.warning(
"line not match format in test file")
continue continue
query = [self.vocab[word] for word in query.split(" ") if word in self.vocab] query = [
title = [self.vocab[word] for word in title.split(" ") if word in self.vocab] self.vocab[word] for word in query.split(" ")
if word in self.vocab
]
title = [
self.vocab[word] for word in title.split(" ")
if word in self.vocab
]
if len(query) == 0: if len(query) == 0:
query = [0] query = [0]
if len(title) == 0: if len(title) == 0:
...@@ -43,27 +65,47 @@ class SimNetProcessor(object): ...@@ -43,27 +65,47 @@ class SimNetProcessor(object):
with codecs.open(self.args.test_data_dir, "r", "utf-8") as file: with codecs.open(self.args.test_data_dir, "r", "utf-8") as file:
for line in file: for line in file:
query, title, label = line.strip().split("\t") query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len(label) == 0 or not label.isdigit() or int( if len(query) == 0 or len(title) == 0 or len(
label) not in [0, 1]: label) == 0 or not label.isdigit() or int(
logging.warning("line not match format in test file") label) not in [0, 1]:
logging.warning(
"line not match format in test file")
continue continue
query = [self.vocab[word] for word in query.split(" ") if word in self.vocab] query = [
title = [self.vocab[word] for word in title.split(" ") if word in self.vocab] self.vocab[word] for word in query.split(" ")
if word in self.vocab
]
title = [
self.vocab[word] for word in title.split(" ")
if word in self.vocab
]
if len(query) == 0: if len(query) == 0:
query = [0] query = [0]
if len(title) == 0: if len(title) == 0:
title = [0] title = [0]
yield [query, title] yield [query, title]
else: else:
with codecs.open(self.args.train_data_dir, "r", "utf-8") as file: with codecs.open(self.args.train_data_dir, "r",
"utf-8") as file:
for line in file: for line in file:
query, pos_title, neg_title = line.strip().split("\t") query, pos_title, neg_title = line.strip().split("\t")
if len(query) == 0 or len(pos_title) == 0 or len(neg_title) == 0: if len(query) == 0 or len(pos_title) == 0 or len(
logging.warning("line not match format in test file") neg_title) == 0:
logging.warning(
"line not match format in test file")
continue continue
query = [self.vocab[word] for word in query.split(" ") if word in self.vocab] query = [
pos_title = [self.vocab[word] for word in pos_title.split(" ") if word in self.vocab] self.vocab[word] for word in query.split(" ")
neg_title = [self.vocab[word] for word in neg_title.split(" ") if word in self.vocab] if word in self.vocab
]
pos_title = [
self.vocab[word] for word in pos_title.split(" ")
if word in self.vocab
]
neg_title = [
self.vocab[word] for word in neg_title.split(" ")
if word in self.vocab
]
if len(query) == 0: if len(query) == 0:
query = [0] query = [0]
if len(pos_title) == 0: if len(pos_title) == 0:
...@@ -77,15 +119,24 @@ class SimNetProcessor(object): ...@@ -77,15 +119,24 @@ class SimNetProcessor(object):
Reader with Pointwise Reader with Pointwise
""" """
if mode == "valid": if mode == "valid":
with codecs.open(self.args.valid_data_dir, "r", "utf-8") as file: with codecs.open(self.args.valid_data_dir, "r",
"utf-8") as file:
for line in file: for line in file:
query, title, label = line.strip().split("\t") query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len(label) == 0 or not label.isdigit() or int( if len(query) == 0 or len(title) == 0 or len(
label) not in [0, 1]: label) == 0 or not label.isdigit() or int(
logging.warning("line not match format in test file") label) not in [0, 1]:
logging.warning(
"line not match format in test file")
continue continue
query = [self.vocab[word] for word in query.split(" ") if word in self.vocab] query = [
title = [self.vocab[word] for word in title.split(" ") if word in self.vocab] self.vocab[word] for word in query.split(" ")
if word in self.vocab
]
title = [
self.vocab[word] for word in title.split(" ")
if word in self.vocab
]
if len(query) == 0: if len(query) == 0:
query = [0] query = [0]
if len(title) == 0: if len(title) == 0:
...@@ -95,27 +146,44 @@ class SimNetProcessor(object): ...@@ -95,27 +146,44 @@ class SimNetProcessor(object):
with codecs.open(self.args.test_data_dir, "r", "utf-8") as file: with codecs.open(self.args.test_data_dir, "r", "utf-8") as file:
for line in file: for line in file:
query, title, label = line.strip().split("\t") query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len(label) == 0 or not label.isdigit() or int( if len(query) == 0 or len(title) == 0 or len(
label) not in [0, 1]: label) == 0 or not label.isdigit() or int(
logging.warning("line not match format in test file") label) not in [0, 1]:
logging.warning(
"line not match format in test file")
continue continue
query = [self.vocab[word] for word in query.split(" ") if word in self.vocab] query = [
title = [self.vocab[word] for word in title.split(" ") if word in self.vocab] self.vocab[word] for word in query.split(" ")
if word in self.vocab
]
title = [
self.vocab[word] for word in title.split(" ")
if word in self.vocab
]
if len(query) == 0: if len(query) == 0:
query = [0] query = [0]
if len(title) == 0: if len(title) == 0:
title = [0] title = [0]
yield [query, title] yield [query, title]
else: else:
with codecs.open(self.args.train_data_dir, "r", "utf-8") as file: with codecs.open(self.args.train_data_dir, "r",
"utf-8") as file:
for line in file: for line in file:
query, title, label = line.strip().split("\t") query, title, label = line.strip().split("\t")
if len(query) == 0 or len(title) == 0 or len(label) == 0 or not label.isdigit() or int( if len(query) == 0 or len(title) == 0 or len(
label) not in [0, 1]: label) == 0 or not label.isdigit() or int(
logging.warning("line not match format in test file") label) not in [0, 1]:
logging.warning(
"line not match format in test file")
continue continue
query = [self.vocab[word] for word in query.split(" ") if word in self.vocab] query = [
title = [self.vocab[word] for word in title.split(" ") if word in self.vocab] self.vocab[word] for word in query.split(" ")
if word in self.vocab
]
title = [
self.vocab[word] for word in title.split(" ")
if word in self.vocab
]
label = int(label) label = int(label)
if len(query) == 0: if len(query) == 0:
query = [0] query = [0]
...@@ -138,8 +206,14 @@ class SimNetProcessor(object): ...@@ -138,8 +206,14 @@ class SimNetProcessor(object):
if len(query) == 0 or len(title) == 0: if len(query) == 0 or len(title) == 0:
logging.warning("line not match format in test file") logging.warning("line not match format in test file")
continue continue
query = [self.vocab[word] for word in query.split(" ") if word in self.vocab] query = [
title = [self.vocab[word] for word in title.split(" ") if word in self.vocab] self.vocab[word] for word in query.split(" ")
if word in self.vocab
]
title = [
self.vocab[word] for word in title.split(" ")
if word in self.vocab
]
if len(query) == 0: if len(query) == 0:
query = [0] query = [0]
if len(title) == 0: if len(title) == 0:
......
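
Both readers consume tab-separated text: pointwise lines are query\ttitle\tlabel with label in {0, 1}, while pairwise training lines are query\tpos_title\tneg_title. A standalone sketch of the pointwise validation and id lookup with a toy vocabulary; the `or [0]` mirrors the empty-sequence fallback above:

# Toy vocabulary; real ones come from load_vocab.
vocab = {"cheap": 3, "phone": 7, "price": 9}
line = "cheap phone\tphone price\t1"
query, title, label = line.strip().split("\t")
assert label.isdigit() and int(label) in [0, 1]
query_ids = [vocab[w] for w in query.split(" ") if w in vocab] or [0]
title_ids = [vocab[w] for w in title.split(" ") if w in vocab] or [0]
print(query_ids, title_ids, int(label))  # [3, 7] [7, 9] 1
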
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
SimNet Task SimNet Task
""" """
...@@ -164,16 +177,16 @@ def train(conf_dict, args): ...@@ -164,16 +177,16 @@ def train(conf_dict, args):
infer_program = fluid.default_main_program().clone(for_test=True) infer_program = fluid.default_main_program().clone(for_test=True)
avg_cost = loss.compute(pred, label) avg_cost = loss.compute(pred, label)
avg_cost.persistable = True avg_cost.persistable = True
# operate Optimization # operate Optimization
optimizer.ops(avg_cost) optimizer.ops(avg_cost)
executor = fluid.Executor(place) executor = fluid.Executor(place)
executor.run(fluid.default_startup_program()) executor.run(fluid.default_startup_program())
if args.init_checkpoint is not None: if args.init_checkpoint is not None:
utils.init_checkpoint(executor, args.init_checkpoint, utils.init_checkpoint(executor, args.init_checkpoint,
fluid.default_startup_program()) fluid.default_startup_program())
# Get and run executor # Get and run executor
parallel_executor = fluid.ParallelExecutor( parallel_executor = fluid.ParallelExecutor(
use_cuda=args.use_cuda, use_cuda=args.use_cuda,
......
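
Note the ordering in train above: the inference program is cloned with clone(for_test=True) before optimizer.ops appends backward and optimization ops, so the clone stays forward-only. A minimal Fluid 1.x sketch of that ordering; the tiny regression network here is a stand-in, not the SimNet model:

import paddle.fluid as fluid

x = fluid.layers.data(name="x", shape=[1], dtype="float32")
y = fluid.layers.data(name="y", shape=[1], dtype="float32")
pred = fluid.layers.fc(input=x, size=1)
# Clone for inference *before* the optimizer adds backward ops.
infer_program = fluid.default_main_program().clone(for_test=True)
loss = fluid.layers.reduce_mean(
    fluid.layers.square_error_cost(input=pred, label=y))
fluid.optimizer.SGD(learning_rate=0.01).minimize(loss)
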
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --coding=utf-8
"""
SimNet utilities.
@@ -17,6 +30,7 @@ import paddle.fluid as fluid
******functions for file processing******
"""


def load_vocab(file_path):
    """
    load the given vocabulary
@@ -47,7 +61,8 @@ def get_result_file(args):
    """
    with codecs.open(args.test_data_dir, "r", "utf-8") as test_file:
        with codecs.open("predictions.txt", "r", "utf-8") as predictions_file:
            with codecs.open(args.test_result_path, "w",
                             "utf-8") as test_result_file:
                test_datas = [line.strip("\n") for line in test_file]
                predictions = [line.strip("\n") for line in predictions_file]
                for test_data, prediction in zip(test_datas, predictions):
@@ -287,7 +302,7 @@ def init_checkpoint(exe, init_checkpoint_path, main_program):
    """
    assert os.path.exists(
        init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path

    def existed_persitables(var):
        if not fluid.io.is_persistable(var):
            return False
@@ -299,4 +314,3 @@ def init_checkpoint(exe, init_checkpoint_path, main_program):
        main_program=main_program,
        predicate=existed_persitables)
    print("Load model from {}".format(init_checkpoint_path))
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import argparse
import time
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
import time
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
import numpy as np
import paddle.fluid as fluid
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import six
import numpy as np
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import six
import numpy as np
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import six
import numpy as np
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import six
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
import numpy as np
@@ -190,7 +203,8 @@ def make_one_batch_input(data_batches, index):
    turns = np.array(data_batches["turns"][index]).astype('int64')
    tt_turns_len = np.array(data_batches["tt_turns_len"][index]).astype('int64')
    every_turn_len = np.array(data_batches["every_turn_len"][index]).astype(
        'int64')
    response = np.array(data_batches["response"][index]).astype('int64')
    response_len = np.array(data_batches["response_len"][index]).astype('int64')
...
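
make_one_batch_input pulls the index-th pre-built batch out of the data_batches dict and casts each field to int64 for feeding. A toy illustration of the layout those keys assume; shapes and contents are made up:

import numpy as np

# Hypothetical store with a single batch of two dialogues, two turns each.
data_batches = {
    "turns": [[[101, 102], [103, 104]]],
    "tt_turns_len": [[2, 2]],
    "every_turn_len": [[[2, 2], [2, 2]]],
    "response": [[[201, 202]]],
    "response_len": [[2]],
}
turns = np.array(data_batches["turns"][0]).astype('int64')
print(turns.dtype, turns.shape)  # int64 (2, 2)
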
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
import os
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import math
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import time
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import numpy as np
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/python
#-*- coding:utf-8 -*-
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#coding=utf8
import os, sys, json
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#coding=utf8
import sys
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import json
import pandas as pd
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class TrainTaskConfig(object):
    # support both CPU and GPU now.
    use_gpu = True
...
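
TrainTaskConfig above follows the plain-class config pattern: class attributes form a flat namespace that callers read directly and override by assignment, with no instantiation. A tiny illustration with a hypothetical class:

class DemoTaskConfig(object):
    # Same style as TrainTaskConfig: class attributes as config fields.
    use_gpu = True


DemoTaskConfig.use_gpu = False  # override without creating an instance
print(DemoTaskConfig.use_gpu)
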
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import multiprocessing
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import numpy as np
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import contextlib
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import six
import os
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import ast
import copy
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import paddle.fluid as fluid
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" """
Conll03 dataset. Conll03 dataset.
""" """
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import logging
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle.fluid as fluid
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import contextlib
import paddle
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import contextlib
import paddle
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import unittest
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import time
import numpy as np
...
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import six
import sys
...