Unverified commit 7ef02b05, authored by XiaoguangHu, committed by GitHub

Remove legacy Research directory (#4428)

* delete contrib Research and legacy

* add acl2019-arnor

* add readme.md for PaddleKG

* update CoKE
Parent d7fbd956
# CoKE: Contextualized Knowledge Graph Embedding
## Introduction
This is the [PaddlePaddle](https://www.paddlepaddle.org.cn/) implementation of the [CoKE](https://arxiv.org/abs/1911.02168) model for Knowledge Graph Embedding (KGE).
CoKE is a novel KGE paradigm that learns dynamic, flexible, and fully contextualized entity and relation representations for a given Knowledge Graph (KG).
It takes a sequence of entities and relations as input, and uses [Transformer](https://arxiv.org/abs/1706.03762) to obtain contextualized representations for its components.
These representations are hence dynamically adaptive to the input, capturing contextual meanings of entities and relations therein.
Evaluation on a wide variety of public benchmarks verifies the superiority of CoKE in link prediction (also known as Knowledge Graph Completion, or KBC for short) and path query answering tasks.
CoKE performs consistently better than, or at least as well as, the current state of the art in almost every case.
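To make the input format concrete, here is a minimal sketch (illustrative only, not the repository's actual reader code) of how a link-prediction triple is cast as a masked sequence: either the head or the tail entity is replaced by a special `[MASK]` token, and the Transformer is trained to recover the masked element from its context.
```
# Minimal sketch of CoKE's input format for link prediction (KBC).
# The entity/relation strings below are made up for illustration; the real
# pipeline converts tokens to vocabulary ids (see the reader code in this commit).
MASK = "[MASK]"

def make_masked_examples(h, r, t):
    """Each triple yields two masked sequences: predict the tail and the head."""
    return [
        ([h, r, MASK], "MASK_TAIL", t),   # query (h, r, ?)
        ([MASK, r, t], "MASK_HEAD", h),   # query (?, r, t)
    ]

for tokens, mask_type, label in make_masked_examples(
        "paris", "capital_of", "france"):
    print(tokens, mask_type, "->", label)
```
Path queries are handled the same way, with the input sequence consisting of a start entity followed by the relations along the path.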
## Requirements
The code has been tested in the following environment:
- Python 3.6.5 with the following dependencies:
  - PaddlePaddle 1.5.0
  - numpy 1.16.3
- Python 2.7.14 for data preprocessing
- GPU environment:
  - CUDA 9.0, cuDNN v7 and NCCL 2.3.7
  - GPU: all datasets run on a single P40 GPU with the given configurations.
## Model Training and Evaluation
### step1. Download dataset files
Download dataset files used in our paper by running:
```
sh wget_datasets.sh
```
This will first download the 4 widely used KBC datasets ([FB15k&WN18](http://papers.nips.cc/paper/5071-translating-embeddings-for-modeling-multi-relational-data.pdf),
[FB15k-237](https://www.aclweb.org/anthology/W15-4007/),
[WN18RR](https://arxiv.org/abs/1707.01476))
and 2 path query answering datasets ([wordnet_paths and freebase_paths](https://arxiv.org/abs/1506.01094)).
It then organizes the train/valid/test files into the following `data` directory:
```
data
├── fb15k
│   ├── test.txt
│   ├── train.txt
│   └── valid.txt
├── fb15k237
│   ├── test.txt
│   ├── train.txt
│   └── valid.txt
├── pathqueryFB  # the original data name is: freebase_paths
│   ├── dev
│   ├── test
│   └── train
├── pathqueryWN  # the original data name is: wordnet_paths
│   ├── dev
│   ├── test
│   └── train
├── wn18
│   ├── test.txt
│   ├── train.txt
│   └── valid.txt
└── wn18rr
    ├── test.txt
    ├── train.txt
    └── valid.txt
```
### step2. Data preprocessing
Data preprocessing commands are given in `data_preprocess.sh`.
It takes the raw train/valid/test files as input and generates the CoKE training and evaluation files.
```
sh data_preprocess.sh
```
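For reference, the shell script drives the Python preprocessing code included in this commit. A rough Python equivalent for a single KBC dataset is sketched below; the paths assume the `data` layout from step 1 and are otherwise illustrative.
```
# Rough equivalent of the KBC part of data_preprocess.sh for one dataset,
# using kbc_data_preprocess.py from this commit. Paths are illustrative.
import os
from kbc_data_preprocess import kbc_data_preprocess

task_dir = "./data/fb15k"
kbc_data_preprocess(
    train_file=os.path.join(task_dir, "train.txt"),
    dev_file=os.path.join(task_dir, "valid.txt"),
    test_file=os.path.join(task_dir, "test.txt"),
    vocab_path=os.path.join(task_dir, "vocab.txt"),           # entity/relation vocabulary
    true_triple_path=os.path.join(task_dir, "all.txt"),       # all true triples, used for filtered metrics
    new_train_file=os.path.join(task_dir, "train.coke.txt"),  # triples tagged with MASK_HEAD / MASK_TAIL
    new_dev_file=os.path.join(task_dir, "valid.coke.txt"),
    new_test_file=os.path.join(task_dir, "test.coke.txt"))
```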
### step3. Training
Model training commands are given in `kbc_train.sh` for the KBC datasets and in `pathquery_train.sh` for the pathquery datasets.
These scripts take a configuration file and GPU id(s) as input arguments.
For example, the following commands train on *fb15k* and *pathqueryFB*, each with its own configuration file:
```
sh kbc_train.sh ./configs/fb15k_job_config.sh 0
sh pathquery_train.sh ./configs/pathqueryFB_job_config.sh 0
```
### step4. Evaluation
Model evaluation commands are given in `kbc_test.sh` for the KBC datasets and in `pathquery_test.sh` for the pathquery datasets.
These scripts take a configuration file and GPU id(s) as input arguments.
For example, the following commands evaluate on *fb15k* and *pathqueryFB*:
```
sh kbc_test.sh ./configs/fb15k_job_config.sh 0
sh pathquery_test.sh ./configs/pathqueryFB_job_config.sh 0
```
We also provide trained model checkpoints for the 4 KBC datasets. Download these models to the `kbc_models` directory using the following command:
```
sh wget_kbc_models.sh
```
The `kbc_models` directory contains the following files:
```
kbc_models
├── fb15k
│   ├── models
│   └── vocab.txt #md5: 0720db5edbda69e00c05441a615db152
├── fb15k237
│   ├── models
│   └── vocab.txt #md5: e843936790e48b3cbb35aa387d0d0fe5
├── wn18
│   ├── models
│   └── vocab.txt #md5: 4904a9300fc3e54aea026ecba7d2c78e
└── wn18rr
├── models
└── vocab.txt #md5: c76aecebf5fc682f0e7922aeba380dd6
```
Before evaluating with these models, check that your preprocessed `vocab.txt` files are identical to ours (the MD5 checksums are listed above).
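One quick way to do that comparison (a minimal sketch; the file path is illustrative) is to hash the file and compare it against the checksums above:
```
# Compare a preprocessed vocab.txt against the MD5 checksums listed above.
import hashlib

with open("data/fb15k/vocab.txt", "rb") as f:
    print(hashlib.md5(f.read()).hexdigest())
```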
## Results
Results on KBC datasets:
|Dataset | MRR | HITS@1 | HITS@3 | HITS@10 |
|---|---|---|---|---|
|FB15K | 0.852 | 0.823 |0.868 | 0.904 |
|FB15K237| 0.361 | 0.269 | 0.398 | 0.547 |
|WN18| 0.951 | 0.947 |0.954 | 0.960|
|WN18RR| 0.475 | 0.437 | 0.490 | 0.552 |
Results on path query datasets:
|Dataset | MQ | HITS@10 |
|---|---|---|
|Freebase | 0.948 | 0.764|
|WordNet |0.942 | 0.674 |
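For reference, the reported numbers follow the metric definitions in `evaluation.py` from this commit: MRR is the mean reciprocal rank of the correct entity, HITS@k is the fraction of test cases ranked within the top k, and MQ rescales each rank against the size of that query's candidate set. A small sketch with made-up ranks:
```
# Sketch of the reported metrics, mirroring compute_kbc_metrics and
# compute_pathquery_metrics in evaluation.py. The ranks below are made up.
import numpy as np

ranks = np.array([1, 3, 2, 15, 7])      # filtered ranks of the correct entity
mrr = np.mean(1.0 / ranks)              # Mean Reciprocal Rank
hits10 = np.mean(ranks <= 10)           # HITS@10

# Mean Quantile for a single path query with 50 candidate entities:
num_candidates, rank = 50, 7
mq = (num_candidates - rank) / (num_candidates - 1.0)
print(mrr, hits10, mq)
```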
## Reproducing the results
Here are the configurations we used to reproduce our results.
They are also given in the `configs/${TASK}_job_config.sh` files.
| Dataset | NetConfig (L: layers, H: hidden size, A: attention heads) | lr | softlabel | epoch | batch_size | dropout |
|---|---|---|---|---|---| ---|
|FB15K| L=6, H=256, A=4| 5e-4 | 0.8 | 300 | 512| 0.1 |
|WN18| L=6, H=256, A=4| 5e-4| 0.2 | 500 | 512 | 0.1 |
|FB15K237| L=6, H=256, A=4| 5e-4| 0.25 | 800 | 512 | 0.5 |
|WN18RR| L=6, H=256, A=4|3e-4 | 0.15 | 800 | 1024 | 0.1 |
|pathqueryFB | L=6, H=256, A=4 | 3e-4 | 1 | 10 | 2048 | 0.1 |
|pathqueryWN | L=6, H=256, A=4 | 3e-4 | 1 | 5 | 2048 | 0.1 |
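In the training scripts these hyperparameters end up in the network configuration read by `CoKEModel` (`model/coke.py` in this commit). The sketch below shows the FB15K row expressed with those configuration keys; values not listed in the table (vocab size, number of relations, etc.) are placeholders that depend on the preprocessed dataset.
```
# Illustrative network configuration for the FB15K row above, using the keys
# consumed by CoKEModel in model/coke.py. Values marked "placeholder" are
# dataset dependent and are not taken from the table.
fb15k_net_config = {
    "num_hidden_layers": 6,            # L
    "hidden_size": 256,                # H
    "num_attention_heads": 4,          # A
    "hidden_dropout_prob": 0.1,        # dropout
    "attention_probs_dropout_prob": 0.1,
    "intermediate_size": 512,          # placeholder
    "hidden_act": "gelu",              # placeholder
    "initializer_range": 0.02,         # placeholder
    "max_position_embeddings": 40,     # placeholder
    "vocab_size": 16396,               # placeholder, set by preprocessing
    "num_relations": 1345,             # placeholder, set by preprocessing
}
```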
## Citation
If you use any source code included in this project in your work, please cite the following paper:
```
@article{wang2019:coke,
title={CoKE: Contextualized Knowledge Graph Embedding},
author={Wang, Quan and Huang, Pingping and Wang, Haifeng and Dai, Songtai and Jiang, Wenbin and Liu, Jing and Lyu, Yajuan and Wu, Hua},
journal={arXiv:1911.02168},
year={2019}
}
```
## Copyright and License
Copyright 2019 Baidu.com, Inc. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0. Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
This work has been moved to a new address: [CoKE](https://github.com/PaddlePaddle/Research/tree/master/KG/CoKE)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" evaluation scripts for KBC and pathQuery tasks """
import json
import logging
import collections
import numpy as np
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)
def kbc_batch_evaluation(eval_i, all_examples, batch_results, tt):
r_hts_idx = collections.defaultdict(list)
scores_head = collections.defaultdict(list)
scores_tail = collections.defaultdict(list)
batch_r_hts_cnt = 0
b_size = len(batch_results)
for j in range(b_size):
result = batch_results[j]
i = eval_i + j
example = all_examples[i]
assert len(example.token_ids
) == 3, "For kbc task each example consists of 3 tokens"
h, r, t = example.token_ids
_mask_type = example.mask_type
if i % 2 == 0:
r_hts_idx[r].append((h, t))
batch_r_hts_cnt += 1
if _mask_type == "MASK_HEAD":
scores_head[(r, t)] = result
elif _mask_type == "MASK_TAIL":
scores_tail[(r, h)] = result
else:
raise ValueError("Unknown mask type in prediction example:%d" % i)
rank = {}
f_rank = {}
for r, hts in r_hts_idx.items():
r_rank = {'head': [], 'tail': []}
r_f_rank = {'head': [], 'tail': []}
for h, t in hts:
scores_t = scores_tail[(r, h)][:]
sortidx_t = np.argsort(scores_t)[::-1]
r_rank['tail'].append(np.where(sortidx_t == t)[0][0] + 1)
rm_idx = tt[r]['ts'][h]
rm_idx = [i for i in rm_idx if i != t]
for i in rm_idx:
scores_t[i] = -np.Inf
sortidx_t = np.argsort(scores_t)[::-1]
r_f_rank['tail'].append(np.where(sortidx_t == t)[0][0] + 1)
scores_h = scores_head[(r, t)][:]
sortidx_h = np.argsort(scores_h)[::-1]
r_rank['head'].append(np.where(sortidx_h == h)[0][0] + 1)
rm_idx = tt[r]['hs'][t]
rm_idx = [i for i in rm_idx if i != h]
for i in rm_idx:
scores_h[i] = -np.Inf
sortidx_h = np.argsort(scores_h)[::-1]
r_f_rank['head'].append(np.where(sortidx_h == h)[0][0] + 1)
rank[r] = r_rank
f_rank[r] = r_f_rank
h_pos = [p for k in rank.keys() for p in rank[k]['head']]
t_pos = [p for k in rank.keys() for p in rank[k]['tail']]
f_h_pos = [p for k in f_rank.keys() for p in f_rank[k]['head']]
f_t_pos = [p for k in f_rank.keys() for p in f_rank[k]['tail']]
ranks = np.asarray(h_pos + t_pos)
f_ranks = np.asarray(f_h_pos + f_t_pos)
return ranks, f_ranks
def pathquery_batch_evaluation(eval_i, all_examples, batch_results,
sen_negli_dict, trivial_sen_set):
""" evaluate the metrics for batch datas for pathquery datasets """
mqs = []
ranks = []
for j, result in enumerate(batch_results):
i = eval_i + j
example = all_examples[i]
token_ids, mask_type = example
assert mask_type in ["MASK_TAIL", "MASK_HEAD"
], " Unknown mask type in pathquery evaluation"
label = token_ids[-1] if mask_type == "MASK_TAIL" else token_ids[0]
sen = " ".join([str(x) for x in token_ids])
if sen in trivial_sen_set:
mq = rank = -1
else:
# candidate vocab set
cand_set = sen_negli_dict[sen]
assert label in set(
cand_set), "predict label must be in the candidate set"
cand_idx = np.sort(np.array(cand_set))
cand_ret = result[
cand_idx] #logits for candidate words(neg + gold words)
cand_ranks = np.argsort(cand_ret)[::-1]
pred_y = cand_idx[cand_ranks]
rank = (np.argwhere(pred_y == label).ravel().tolist())[0] + 1
mq = (len(cand_set) - rank) / (len(cand_set) - 1.0)
mqs.append(mq)
ranks.append(rank)
return mqs, ranks
def compute_kbc_metrics(rank_li, frank_li, output_evaluation_result_file):
""" combine the kbc rank results from batches into the final metrics """
rank_rets = np.array(rank_li).ravel()
frank_rets = np.array(frank_li).ravel()
mrr = np.mean(1.0 / rank_rets)
fmrr = np.mean(1.0 / frank_rets)
hits1 = np.mean(rank_rets <= 1.0)
hits3 = np.mean(rank_rets <= 3.0)
hits10 = np.mean(rank_rets <= 10.0)
# filtered metrics
fhits1 = np.mean(frank_rets <= 1.0)
fhits3 = np.mean(frank_rets <= 3.0)
fhits10 = np.mean(frank_rets <= 10.0)
eval_result = {
'mrr': mrr,
'hits1': hits1,
'hits3': hits3,
'hits10': hits10,
'fmrr': fmrr,
'fhits1': fhits1,
'fhits3': fhits3,
'fhits10': fhits10
}
with open(output_evaluation_result_file, "w") as fw:
fw.write(json.dumps(eval_result, indent=4) + "\n")
return eval_result
def compute_pathquery_metrics(mq_li, rank_li, output_evaluation_result_file):
""" combine the pathquery mq, rank results from batches into the final metrics """
rank_rets = np.array(rank_li).ravel()
_idx = np.where(rank_rets != -1)
non_trivial_eval_rets = rank_rets[_idx]
non_trivial_mq = np.array(mq_li).ravel()[_idx]
non_trivial_cnt = non_trivial_eval_rets.size
mq = np.mean(non_trivial_mq)
mr = np.mean(non_trivial_eval_rets)
mrr = np.mean(1.0 / non_trivial_eval_rets)
fhits10 = np.mean(non_trivial_eval_rets <= 10.0)
eval_result = {
'fcnt': non_trivial_cnt,
'mq': mq,
'mr': mr,
'fhits10': fhits10
}
with open(output_evaluation_result_file, "w") as fw:
fw.write(json.dumps(eval_result, indent=4) + "\n")
return eval_result
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
data preprocess for KBC datasets
"""
import os
import collections
import argparse
def get_unique_entities_relations(train_file, dev_file, test_file):
entity_lst = dict()
relation_lst = dict()
all_files = [train_file, dev_file, test_file]
for input_file in all_files:
print("dealing %s" % train_file)
with open(input_file, "r") as f:
for line in f.readlines():
tokens = line.strip().split("\t")
assert len(tokens) == 3
entity_lst[tokens[0]] = len(entity_lst)
entity_lst[tokens[2]] = len(entity_lst)
relation_lst[tokens[1]] = len(relation_lst)
print(">> Number of unique entities: %s" % len(entity_lst))
print(">> Number of unique relations: %s" % len(relation_lst))
return entity_lst, relation_lst
def write_vocab(output_file, entity_lst, relation_lst):
fout = open(output_file, "w")
fout.write("[PAD]" + "\n")
for i in range(95):
fout.write("[unused{}]\n".format(i))
fout.write("[UNK]" + "\n")
fout.write("[CLS]" + "\n")
fout.write("[SEP]" + "\n")
fout.write("[MASK]" + "\n")
for e in entity_lst.keys():
fout.write(e + "\n")
for r in relation_lst.keys():
fout.write(r + "\n")
vocab_size = 100 + len(entity_lst) + len(relation_lst)
print(">> vocab_size: %s" % vocab_size)
fout.close()
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
fin = open(vocab_file)
for num, line in enumerate(fin):
items = line.strip().split("\t")
if len(items) > 2:
break
token = items[0]
index = items[1] if len(items) == 2 else num
token = token.strip()
vocab[token] = int(index)
return vocab
def write_true_triples(train_file, dev_file, test_file, vocab, output_file):
true_triples = []
all_files = [train_file, dev_file, test_file]
for input_file in all_files:
with open(input_file, "r") as f:
for line in f.readlines():
h, r, t = line.strip('\r \n').split('\t')
assert (h in vocab) and (r in vocab) and (t in vocab)
hpos = vocab[h]
rpos = vocab[r]
tpos = vocab[t]
true_triples.append((hpos, rpos, tpos))
print(">> Number of true triples: %d" % len(true_triples))
fout = open(output_file, "w")
for hpos, rpos, tpos in true_triples:
fout.write(str(hpos) + "\t" + str(rpos) + "\t" + str(tpos) + "\n")
fout.close()
def generate_mask_type(input_file, output_file):
with open(output_file, "w") as fw:
with open(input_file, "r") as fr:
for line in fr.readlines():
fw.write(line.strip('\r \n') + "\tMASK_HEAD\n")
fw.write(line.strip('\r \n') + "\tMASK_TAIL\n")
def kbc_data_preprocess(train_file, dev_file, test_file, vocab_path,
true_triple_path, new_train_file, new_dev_file,
new_test_file):
entity_lst, relation_lst = get_unique_entities_relations(
train_file, dev_file, test_file)
write_vocab(vocab_path, entity_lst, relation_lst)
vocab = load_vocab(vocab_path)
write_true_triples(train_file, dev_file, test_file, vocab,
true_triple_path)
generate_mask_type(train_file, new_train_file)
generate_mask_type(dev_file, new_dev_file)
generate_mask_type(test_file, new_test_file)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--task",
type=str,
required=True,
default=None,
help="task name: fb15k, fb15k237, wn18rr, wn18, pathqueryFB, pathqueryWN"
)
parser.add_argument(
"--dir",
type=str,
required=True,
default=None,
help="task data directory")
parser.add_argument(
"--train",
type=str,
required=False,
default="train.txt",
help="train file name, default train.txt")
parser.add_argument(
"--valid",
type=str,
required=False,
default="valid.txt",
help="valid file name, default valid.txt")
parser.add_argument(
"--test",
type=str,
required=False,
default="test.txt",
help="test file name, default test.txt")
args = parser.parse_args()
return args
if __name__ == '__main__':
args = get_args()
task = args.task.lower()
assert task in ["fb15k", "wn18", "fb15k237", "wn18rr"]
raw_train_file = os.path.join(args.dir, args.train)
raw_dev_file = os.path.join(args.dir, args.valid)
raw_test_file = os.path.join(args.dir, args.test)
vocab_file = os.path.join(args.dir, "vocab.txt")
true_triple_file = os.path.join(args.dir, "all.txt")
new_train_file = os.path.join(args.dir, "train.coke.txt")
new_test_file = os.path.join(args.dir, "test.coke.txt")
new_dev_file = os.path.join(args.dir, "valid.coke.txt")
kbc_data_preprocess(raw_train_file, raw_dev_file, raw_test_file,
vocab_file, true_triple_file, new_train_file,
new_dev_file, new_test_file)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""CoKE model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import json
import logging
import numpy as np
import paddle.fluid as fluid
from model.transformer_encoder import encoder, pre_process_layer
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)
class CoKEModel(object):
def __init__(self,
src_ids,
position_ids,
input_mask,
config,
soft_label=0.9,
weight_sharing=True,
use_fp16=False):
self._emb_size = config['hidden_size']
self._n_layer = config['num_hidden_layers']
self._n_head = config['num_attention_heads']
self._voc_size = config['vocab_size']
self._n_relation = config['num_relations']
self._max_position_seq_len = config['max_position_embeddings']
self._hidden_act = config['hidden_act']
self._prepostprocess_dropout = config['hidden_dropout_prob']
self._attention_dropout = config['attention_probs_dropout_prob']
self._intermediate_size = config['intermediate_size']
self._soft_label = soft_label
self._weight_sharing = weight_sharing
self._word_emb_name = "word_embedding"
self._pos_emb_name = "pos_embedding"
self._dtype = "float16" if use_fp16 else "float32"
# Initialize all weights by the truncated normal initializer, and all biases
# will be initialized by constant zero by default.
self._param_initializer = fluid.initializer.TruncatedNormal(
scale=config['initializer_range'])
self._build_model(src_ids, position_ids, input_mask)
def _build_model(self, src_ids, position_ids, input_mask):
# padding id in vocabulary must be set to 0
emb_out = fluid.layers.embedding(
input=src_ids,
size=[self._voc_size, self._emb_size],
dtype=self._dtype,
param_attr=fluid.ParamAttr(
name=self._word_emb_name, initializer=self._param_initializer),
is_sparse=False)
position_emb_out = fluid.layers.embedding(
input=position_ids,
size=[self._max_position_seq_len, self._emb_size],
dtype=self._dtype,
param_attr=fluid.ParamAttr(
name=self._pos_emb_name, initializer=self._param_initializer))
emb_out = emb_out + position_emb_out
emb_out = pre_process_layer(
emb_out, 'nd', self._prepostprocess_dropout, name='pre_encoder')
if self._dtype == "float16":
input_mask = fluid.layers.cast(x=input_mask, dtype=self._dtype)
self_attn_mask = fluid.layers.matmul(
x=input_mask, y=input_mask, transpose_y=True)
self_attn_mask = fluid.layers.scale(
x=self_attn_mask, scale=10000.0, bias=-1.0, bias_after_scale=False)
n_head_self_attn_mask = fluid.layers.stack(
x=[self_attn_mask] * self._n_head, axis=1)
n_head_self_attn_mask.stop_gradient = True
self._enc_out = encoder(
enc_input=emb_out,
attn_bias=n_head_self_attn_mask,
n_layer=self._n_layer,
n_head=self._n_head,
d_key=self._emb_size // self._n_head,
d_value=self._emb_size // self._n_head,
d_model=self._emb_size,
d_inner_hid=self._intermediate_size,
prepostprocess_dropout=self._prepostprocess_dropout,
attention_dropout=self._attention_dropout,
relu_dropout=0,
hidden_act=self._hidden_act,
preprocess_cmd="",
postprocess_cmd="dan",
param_initializer=self._param_initializer,
name='encoder')
#def get_sequence_output(self):
# return self._enc_out
def get_pretraining_output(self, mask_label, mask_pos):
"""Get the loss & fc_out for training"""
mask_pos = fluid.layers.cast(x=mask_pos, dtype='int32')
reshaped_emb_out = fluid.layers.reshape(
x=self._enc_out, shape=[-1, self._emb_size])
# extract masked tokens' feature
mask_feat = fluid.layers.gather(input=reshaped_emb_out, index=mask_pos)
# transform: fc
mask_trans_feat = fluid.layers.fc(
input=mask_feat,
size=self._emb_size,
act=self._hidden_act,
param_attr=fluid.ParamAttr(
name='mask_lm_trans_fc.w_0',
initializer=self._param_initializer),
bias_attr=fluid.ParamAttr(name='mask_lm_trans_fc.b_0'))
# transform: layer norm
mask_trans_feat = pre_process_layer(
mask_trans_feat, 'n', name='mask_lm_trans')
mask_lm_out_bias_attr = fluid.ParamAttr(
name="mask_lm_out_fc.b_0",
initializer=fluid.initializer.Constant(value=0.0))
if self._weight_sharing:
fc_out = fluid.layers.matmul(
x=mask_trans_feat,
y=fluid.default_main_program().global_block().var(
self._word_emb_name),
transpose_y=True)
fc_out += fluid.layers.create_parameter(
shape=[self._voc_size],
dtype=self._dtype,
attr=mask_lm_out_bias_attr,
is_bias=True)
else:
fc_out = fluid.layers.fc(input=mask_trans_feat,
size=self._voc_size,
param_attr=fluid.ParamAttr(
name="mask_lm_out_fc.w_0",
initializer=self._param_initializer),
bias_attr=mask_lm_out_bias_attr)
# generate soft labels for the cross-entropy loss
one_hot_labels = fluid.layers.one_hot(
input=mask_label, depth=self._voc_size)
entity_indicator = fluid.layers.fill_constant_batch_size_like(
input=mask_label,
shape=[-1, (self._voc_size - self._n_relation)],
dtype='int64',
value=0)
relation_indicator = fluid.layers.fill_constant_batch_size_like(
input=mask_label,
shape=[-1, self._n_relation],
dtype='int64',
value=1)
is_relation = fluid.layers.concat(
input=[entity_indicator, relation_indicator], axis=-1)
soft_labels = one_hot_labels * self._soft_label \
+ (1.0 - one_hot_labels - is_relation) \
* ((1.0 - self._soft_label) / (self._voc_size - 1 - self._n_relation))
soft_labels.stop_gradient = True
mask_lm_loss = fluid.layers.softmax_with_cross_entropy(
logits=fc_out, label=soft_labels, soft_label=True)
mean_mask_lm_loss = fluid.layers.mean(mask_lm_loss)
return mean_mask_lm_loss, fc_out
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Transformer encoder."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from functools import partial, reduce
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.layer_helper import LayerHelper
def layer_norm(x,
begin_norm_axis=1,
epsilon=1e-12,
param_attr=None,
bias_attr=None):
"""
Replace the built-in layer_norm op with this function
"""
helper = LayerHelper('layer_norm', **locals())
mean = layers.reduce_mean(x, dim=begin_norm_axis, keep_dim=True)
shift_x = layers.elementwise_sub(x=x, y=mean, axis=0)
variance = layers.reduce_mean(
layers.square(shift_x), dim=begin_norm_axis, keep_dim=True)
r_stdev = layers.rsqrt(variance + epsilon)
norm_x = layers.elementwise_mul(x=shift_x, y=r_stdev, axis=0)
param_shape = [reduce(lambda x, y: x * y, norm_x.shape[begin_norm_axis:])]
param_dtype = norm_x.dtype
scale = helper.create_parameter(
attr=param_attr,
shape=param_shape,
dtype=param_dtype,
default_initializer=fluid.initializer.Constant(1.))
bias = helper.create_parameter(
attr=bias_attr,
shape=param_shape,
dtype=param_dtype,
is_bias=True,
default_initializer=fluid.initializer.Constant(0.))
out = layers.elementwise_mul(x=norm_x, y=scale, axis=-1)
out = layers.elementwise_add(x=out, y=bias, axis=-1)
return out
def multi_head_attention(queries,
keys,
values,
attn_bias,
d_key,
d_value,
d_model,
n_head=1,
dropout_rate=0.,
cache=None,
param_initializer=None,
name='multi_head_att'):
"""
Multi-Head Attention. Note that attn_bias is added to the logits before
computing the softmax activation to mask certain selected positions so that
they will not be considered in the attention weights.
"""
keys = queries if keys is None else keys
values = keys if values is None else values
if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
raise ValueError(
"Inputs: quries, keys and values should all be 3-D tensors.")
def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
"""
Add linear projection to queries, keys, and values.
"""
q = layers.fc(input=queries,
size=d_key * n_head,
num_flatten_dims=2,
param_attr=fluid.ParamAttr(
name=name + '_query_fc.w_0',
initializer=param_initializer),
bias_attr=name + '_query_fc.b_0')
k = layers.fc(input=keys,
size=d_key * n_head,
num_flatten_dims=2,
param_attr=fluid.ParamAttr(
name=name + '_key_fc.w_0',
initializer=param_initializer),
bias_attr=name + '_key_fc.b_0')
v = layers.fc(input=values,
size=d_value * n_head,
num_flatten_dims=2,
param_attr=fluid.ParamAttr(
name=name + '_value_fc.w_0',
initializer=param_initializer),
bias_attr=name + '_value_fc.b_0')
return q, k, v
def __split_heads(x, n_head):
"""
Reshape the last dimension of input tensor x so that it becomes two
dimensions and then transpose. Specifically, input a tensor with shape
[bs, max_sequence_length, n_head * hidden_dim] then output a tensor
with shape [bs, n_head, max_sequence_length, hidden_dim].
"""
hidden_size = x.shape[-1]
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped = layers.reshape(
x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)
# permute the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
def __combine_heads(x):
"""
Transpose and then reshape the last two dimensions of input tensor x
so that it becomes one dimension, which is reverse to __split_heads.
"""
if len(x.shape) == 3: return x
if len(x.shape) != 4:
raise ValueError("Input(x) should be a 4-D Tensor.")
trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
return layers.reshape(
x=trans_x,
shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
inplace=True)
def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
"""
Scaled Dot-Product Attention
"""
scaled_q = layers.scale(x=q, scale=d_key**-0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
if attn_bias:
product += attn_bias
weights = layers.softmax(product)
if dropout_rate:
weights = layers.dropout(
weights,
dropout_prob=dropout_rate,
dropout_implementation="upscale_in_train",
is_test=False)
out = layers.matmul(weights, v)
return out
q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
if cache is not None: # use cache and concat time steps
# Since the inplace reshape in __split_heads changes the shape of k and
# v, which is the cache input for next time step, reshape the cache
# input from the previous time step first.
k = cache["k"] = layers.concat(
[layers.reshape(
cache["k"], shape=[0, 0, d_model]), k], axis=1)
v = cache["v"] = layers.concat(
[layers.reshape(
cache["v"], shape=[0, 0, d_model]), v], axis=1)
q = __split_heads(q, n_head)
k = __split_heads(k, n_head)
v = __split_heads(v, n_head)
ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key,
dropout_rate)
out = __combine_heads(ctx_multiheads)
# Project back to the model size.
proj_out = layers.fc(input=out,
size=d_model,
num_flatten_dims=2,
param_attr=fluid.ParamAttr(
name=name + '_output_fc.w_0',
initializer=param_initializer),
bias_attr=name + '_output_fc.b_0')
return proj_out
def positionwise_feed_forward(x,
d_inner_hid,
d_hid,
dropout_rate,
hidden_act,
param_initializer=None,
name='ffn'):
"""
Position-wise Feed-Forward Networks.
This module consists of two linear transformations with a ReLU activation
in between, which is applied to each position separately and identically.
"""
hidden = layers.fc(input=x,
size=d_inner_hid,
num_flatten_dims=2,
act=hidden_act,
param_attr=fluid.ParamAttr(
name=name + '_fc_0.w_0',
initializer=param_initializer),
bias_attr=name + '_fc_0.b_0')
if dropout_rate:
hidden = layers.dropout(
hidden,
dropout_prob=dropout_rate,
dropout_implementation="upscale_in_train",
is_test=False)
out = layers.fc(input=hidden,
size=d_hid,
num_flatten_dims=2,
param_attr=fluid.ParamAttr(
name=name + '_fc_1.w_0',
initializer=param_initializer),
bias_attr=name + '_fc_1.b_0')
return out
def pre_post_process_layer(prev_out,
out,
process_cmd,
dropout_rate=0.,
name=''):
"""
Add residual connection, layer normalization and dropout to the out tensor
optionally according to the value of process_cmd.
This will be used before or after multi-head attention and position-wise
feed-forward networks.
"""
for cmd in process_cmd:
if cmd == "a": # add residual connection
out = out + prev_out if prev_out else out
elif cmd == "n": # add layer normalization
out_dtype = out.dtype
if out_dtype == fluid.core.VarDesc.VarType.FP16:
out = layers.cast(x=out, dtype="float32")
out = layer_norm(
out,
begin_norm_axis=len(out.shape) - 1,
param_attr=fluid.ParamAttr(
name=name + '_layer_norm_scale',
initializer=fluid.initializer.Constant(1.)),
bias_attr=fluid.ParamAttr(
name=name + '_layer_norm_bias',
initializer=fluid.initializer.Constant(0.)))
if out_dtype == fluid.core.VarDesc.VarType.FP16:
out = layers.cast(x=out, dtype="float16")
elif cmd == "d": # add dropout
if dropout_rate:
out = layers.dropout(
out,
dropout_prob=dropout_rate,
dropout_implementation="upscale_in_train",
is_test=False)
return out
pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer
def encoder_layer(enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
hidden_act,
preprocess_cmd="n",
postprocess_cmd="da",
param_initializer=None,
name=''):
"""The encoder layers that can be stacked to form a deep encoder.
This module consists of a multi-head (self) attention followed by
position-wise feed-forward networks, with both components accompanied
by post_process_layer to add residual connection, layer normalization
and dropout.
"""
attn_output = multi_head_attention(
pre_process_layer(
enc_input,
preprocess_cmd,
prepostprocess_dropout,
name=name + '_pre_att'),
None,
None,
attn_bias,
d_key,
d_value,
d_model,
n_head,
attention_dropout,
param_initializer=param_initializer,
name=name + '_multi_head_att')
attn_output = post_process_layer(
enc_input,
attn_output,
postprocess_cmd,
prepostprocess_dropout,
name=name + '_post_att')
ffd_output = positionwise_feed_forward(
pre_process_layer(
attn_output,
preprocess_cmd,
prepostprocess_dropout,
name=name + '_pre_ffn'),
d_inner_hid,
d_model,
relu_dropout,
hidden_act,
param_initializer=param_initializer,
name=name + '_ffn')
return post_process_layer(
attn_output,
ffd_output,
postprocess_cmd,
prepostprocess_dropout,
name=name + '_post_ffn')
def encoder(enc_input,
attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
hidden_act,
preprocess_cmd="n",
postprocess_cmd="da",
param_initializer=None,
name=''):
"""
The encoder is composed of a stack of identical layers returned by calling
encoder_layer.
"""
for i in range(n_layer):
enc_output = encoder_layer(
enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
hidden_act,
preprocess_cmd,
postprocess_cmd,
param_initializer=param_initializer,
name=name + '_layer_' + str(i))
enc_input = enc_output
enc_output = pre_process_layer(
enc_output,
preprocess_cmd,
prepostprocess_dropout,
name="post_encoder")
return enc_output
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Optimization and learning rate scheduling."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
from utils.fp16 import create_master_params_grads, master_param_to_train_param
def linear_warmup_decay(learning_rate, warmup_steps, num_train_steps):
""" Applies linear warmup of learning rate from 0 and decay to 0."""
with fluid.default_main_program()._lr_schedule_guard():
lr = fluid.layers.tensor.create_global_var(
shape=[1],
value=0.0,
dtype='float32',
persistable=True,
name="scheduled_learning_rate")
global_step = fluid.layers.learning_rate_scheduler._decay_step_counter(
)
with fluid.layers.control_flow.Switch() as switch:
with switch.case(global_step < num_train_steps * 0.1):
warmup_lr = learning_rate * (global_step /
(num_train_steps * 0.1))
fluid.layers.tensor.assign(warmup_lr, lr)
with switch.default():
decayed_lr = fluid.layers.learning_rate_scheduler.polynomial_decay(
learning_rate=learning_rate,
decay_steps=num_train_steps,
end_learning_rate=0.0,
power=1.0,
cycle=False)
fluid.layers.tensor.assign(decayed_lr, lr)
return lr
def optimization(loss,
warmup_steps,
num_train_steps,
learning_rate,
train_program,
startup_prog,
weight_decay,
scheduler='linear_warmup_decay',
use_fp16=False,
loss_scaling=1.0):
if warmup_steps > 0:
if scheduler == 'noam_decay':
scheduled_lr = fluid.layers.learning_rate_scheduler\
.noam_decay(1/(warmup_steps *(learning_rate ** 2)),
warmup_steps)
elif scheduler == 'linear_warmup_decay':
scheduled_lr = linear_warmup_decay(learning_rate, warmup_steps,
num_train_steps)
else:
raise ValueError("Unkown learning rate scheduler, should be "
"'noam_decay' or 'linear_warmup_decay'")
optimizer = fluid.optimizer.Adam(
learning_rate=scheduled_lr, epsilon=1e-6)
else:
optimizer = fluid.optimizer.Adam(
learning_rate=learning_rate, epsilon=1e-6)
scheduled_lr = learning_rate
clip_norm_thres = 1.0
# When using mixed precision training, scale the gradient clip threshold
# by loss_scaling
if use_fp16 and loss_scaling > 1.0:
clip_norm_thres *= loss_scaling
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=clip_norm_thres))
def exclude_from_weight_decay(name):
if name.find("layer_norm") > -1:
return True
bias_suffix = ["_bias", "_b", ".b_0"]
for suffix in bias_suffix:
if name.endswith(suffix):
return True
return False
param_list = dict()
if use_fp16:
param_grads = optimizer.backward(loss)
master_param_grads = create_master_params_grads(
param_grads, train_program, startup_prog, loss_scaling)
for param, _ in master_param_grads:
param_list[param.name] = param * 1.0
param_list[param.name].stop_gradient = True
optimizer.apply_gradients(master_param_grads)
if weight_decay > 0:
for param, grad in master_param_grads:
# if exclude_from_weight_decay(param.name.rstrip(".master")):
# continue
with param.block.program._optimized_guard(
[param, grad]), fluid.framework.name_scope("weight_decay"):
updated_param = param - param_list[
param.name] * weight_decay * scheduled_lr
fluid.layers.assign(output=param, input=updated_param)
master_param_to_train_param(master_param_grads, param_grads,
train_program)
else:
for param in train_program.global_block().all_parameters():
param_list[param.name] = param * 1.0
param_list[param.name].stop_gradient = True
_, param_grads = optimizer.minimize(loss)
if weight_decay > 0:
for param, grad in param_grads:
# if exclude_from_weight_decay(param.name):
# continue
with param.block.program._optimized_guard(
[param, grad]), fluid.framework.name_scope("weight_decay"):
updated_param = param - param_list[
param.name] * weight_decay * scheduled_lr
fluid.layers.assign(output=param, input=updated_param)
return scheduled_lr
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
data preprocess for pathquery datasets
"""
import os
import sys
import time
import logging
import argparse
from kbc_data_preprocess import write_vocab
from kbc_data_preprocess import load_vocab
from kbc_data_preprocess import generate_mask_type
from collections import defaultdict, Counter
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S')
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
inverted = lambda r: r[:2] == '**'
invert = lambda r: r[2:] if inverted(r) else '**' + r
class EvalDataset(object):
def __init__(self, train_file, test_file):
self.spo_train_fp = train_file
self.spo_test_fp = test_file
train_triples = self._load_spo_triples(self.spo_train_fp)
test_triples = self._load_spo_triples(self.spo_test_fp)
#logger.debug(">>train triples cnt:%d" % len(train_triples))
#logger.debug(">>test triples cnt:%d" % len(test_triples))
_train_cnt = len(train_triples)
all_triples = train_triples
all_triples.update(test_triples)
self.full_graph = Graph(all_triples)
logger.debug(self.full_graph)
def _load_spo_triples(self, spo_path):
"""
:param spo_path:
:return: set of (s,r,t) original triples
"""
logger.debug(">> Begin load base spo for %s at %s" %
(spo_path, time.ctime()))
triples = set()
for line in open(spo_path):
segs = line.strip().split("\t")
assert len(segs) == 3
s, p, o = segs
triples.add((s, p, o))
logger.debug(">> Loaded spo triples :%s cnt:%d" %
(spo_path, len(triples)))
logger.debug(">> End load spo for %s at %s" % (spo_path, time.ctime()))
return triples
class Graph(object):
def __init__(self, triples):
self.triples = triples
neighbors = defaultdict(lambda: defaultdict(set))
relation_args = defaultdict(lambda: defaultdict(set))
logger.info(">> Begin building graph at %s" % (time.ctime()))
self._node_set = set()
for s, r, t in triples:
relation_args[r]['s'].add(s)
relation_args[r]['t'].add(t)
neighbors[s][r].add(t)
neighbors[t][invert(r)].add(s)
self._node_set.add(t)
self._node_set.add(s)
def freeze(d):
frozen = {}
for key, subdict in d.iteritems():
frozen[key] = {}
for subkey, set_val in subdict.iteritems():
frozen[key][subkey] = tuple(set_val)
return frozen
self.neighbors = freeze(neighbors)
self.relation_args = freeze(relation_args)
logger.info(">> Done building graph at %s" % (time.ctime()))
def __repr__(self):
s = ""
s += "graph.relations_args cnt %d\t" % len(self.relation_args)
s += "graph.neighbors cnt %d\t" % len(self.neighbors)
s += "graph.neighbors node set cnt %d" % len(self._node_set)
return s
def walk_all(self, start, path):
"""
walk from start and collect all end entities reachable via the path
:param start: start entity
:param path: (r1, r2, ..., rk)
:return: set of entities reachable from start via path
"""
set_s = set()
set_t = set()
set_s.add(start)
for _, r in enumerate(path):
if len(set_s) == 0:
return set()
for _s in set_s:
if _s in self.neighbors and r in self.neighbors[_s]:
_tset = set(self.neighbors[_s][r])  # tuple to set
set_t.update(_tset)
set_s = set_t.copy()
set_t.clear()
return set_s
def repr_walk_all_ret(self, start, path, MAX_T=20):
cand_set = self.walk_all(start, path)
if len(cand_set) == 0:
return ">>start{} path:{} end: EMPTY!".format(
start, "->".join(list(path)))
_len = len(cand_set) if len(cand_set) < MAX_T else MAX_T
cand_node_str = ", ".join(cand_set[:_len])
return ">>start{} path:{} end: {}".format(
start, "->".join(list(path)), cand_node_str)
def type_matching_entities(self, path, position="t"):
assert (position == "t")
if position == "t":
r = path[-1]
elif position == "s":
r = path[0]
else:
logger.error(">>UNKNOWN position at type_matching_entities")
raise ValueError(position)
try:
if not inverted(r):
return r, self.relation_args[r][position]
else:
inv_pos = 's' if position == "t" else "t"
return r, self.relation_args[invert(r)][inv_pos]
except KeyError:
logger.error(
">>UNKNOWN path value at type_matching_entities :%s from path:%s"
% (r, path))
return None, tuple()
def is_trival_query(self, start, path):
"""
:param start: start entity
:param path: (r1, r2, ..., rk)
:return: True if all candidate entities are right answers, otherwise False
"""
#todo: check right again
r, cand_set = self.type_matching_entities(path, "t")
cand_set = set(cand_set)  # type_matching_entities returns (relation, entity tuple)
ans_set = self.walk_all(start, path)
_set = cand_set - ans_set
if len(_set) == 0:
return True
else:
return False
def get_unique_entities_relations(train_file, dev_file, test_file):
entity_lst = dict()
relation_lst = dict()
all_files = [train_file, dev_file, test_file]
for input_file in all_files:
with open(input_file, "r") as f:
for line in f.readlines():
tokens = line.strip().split("\t")
assert len(tokens) == 3
entity_lst[tokens[0]] = len(entity_lst)
entity_lst[tokens[2]] = len(entity_lst)
relations = tokens[1].split(",")
for relation in relations:
relation_lst[relation] = len(relation_lst)
print(">> Number of unique entities: %s" % len(entity_lst))
print(">> Number of unique relations: %s" % len(relation_lst))
return entity_lst, relation_lst
def filter_base_data(raw_train_file, raw_dev_file, raw_test_file,
train_base_file, dev_base_file, test_base_file):
def fil_base(input_file, output_file):
fout = open(output_file, "w")
base_n = 0
with open(input_file, "r") as f:
for line in f.readlines():
tokens = line.strip().split("\t")
assert len(tokens) == 3
relations = tokens[1].split(",")
if len(relations) == 1:
fout.write(line)
base_n += 1
fout.close()
return base_n
train_base_n = fil_base(raw_train_file, train_base_file)
dev_base_n = fil_base(raw_dev_file, dev_base_file)
test_base_n = fil_base(raw_test_file, test_base_file)
print(">> Train base cnt:%d" % train_base_n)
print(">> Valid base cnt:%d" % dev_base_n)
print(">> Test base cnt:%d" % test_base_n)
def generate_onlytail_mask_type(input_file, output_file):
with open(output_file, "w") as fw:
with open(input_file, "r") as fr:
for line in fr.readlines():
fw.write(line.strip('\r \n') + "\tMASK_TAIL\n")
def generate_eval_files(vocab_path, raw_test_file, train_base_file,
dev_base_file, test_base_file, sen_candli_file,
trivial_sen_file):
token2id = load_vocab(vocab_path)
eval_data = EvalDataset(train_base_file, test_base_file)
fout_sen_cand = open(sen_candli_file, "w")
fout_q_trival = open(trivial_sen_file, "w")
sen_candli_cnt = trivial_sen_cnt = 0
j = 0
for line in open(raw_test_file):
line = line.strip()
j += 1
segs = line.split("\t")
s = segs[0]
t = segs[2]
path = tuple(segs[1].split(","))
q_set = eval_data.full_graph.walk_all(s, path)
r, cand_set = eval_data.full_graph.type_matching_entities(path, "t")
cand_set = set(cand_set)
neg_set = cand_set - q_set
sen_tokens = []
sen_tokens.append(line.split("\t")[0])
sen_tokens.extend(line.split("\t")[1].split(","))
sen_tokens.append(line.split("\t")[2])
sen_id = [str(token2id[x]) for x in sen_tokens]
if len(neg_set) == 0:
trivial_sen_cnt += 1
#fout_q_trival.write(line + "\n")
fout_q_trival.write(" ".join(sen_id) + "\n")
else:
sen_candli_cnt += 1
candli_id_set = [str(token2id[x]) for x in neg_set]
sen_canli_str = "%s\t%s" % (" ".join(sen_id),
" ".join(list(candli_id_set)))
fout_sen_cand.write(sen_canli_str + "\n")
if len(cand_set) < len(q_set):
logger.error("ERROR! cand_set %d < q_set %d at line[%d]:%s" %
(len(cand_set), len(q_set), j, line))
if j % 100 == 0:
logger.debug(" ...processing %d at %s" % (j, time.ctime()))
if -100 > 0 and j >= 100:
break
logger.info(">> sen_canli_set count:%d " % sen_candli_cnt)
logger.info(">> trivial sen count:%d " % trivial_sen_cnt)
logger.info(">> Finish generate evaluation candidates for %s file at %s" %
(raw_test_file, time.ctime()))
def pathquery_data_preprocess(raw_train_file, raw_dev_file, raw_test_file,
vocab_path, sen_candli_file, trivial_sen_file,
new_train_file, new_dev_file, new_test_file,
train_base_file, dev_base_file, test_base_file):
entity_lst, relation_lst = get_unique_entities_relations(
raw_train_file, raw_dev_file, raw_test_file)
write_vocab(vocab_path, entity_lst, relation_lst)
filter_base_data(raw_train_file, raw_dev_file, raw_test_file,
train_base_file, dev_base_file, test_base_file)
generate_mask_type(raw_train_file, new_train_file)
generate_onlytail_mask_type(raw_dev_file, new_dev_file)
generate_onlytail_mask_type(raw_test_file, new_test_file)
vocab = load_vocab(vocab_path)
generate_eval_files(vocab_path, raw_test_file, train_base_file,
dev_base_file, test_base_file, sen_candli_file,
trivial_sen_file)
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--task",
type=str,
required=True,
default=None,
help="task name: fb15k, fb15k237, wn18rr, wn18, pathqueryFB, pathqueryWN"
)
parser.add_argument(
"--dir",
type=str,
required=True,
default=None,
help="task data directory")
parser.add_argument(
"--train",
type=str,
required=False,
default="train",
help="train file name, default train.txt")
parser.add_argument(
"--valid",
type=str,
required=False,
default="dev",
help="valid file name, default valid.txt")
parser.add_argument(
"--test",
type=str,
required=False,
default="test",
help="test file name, default test.txt")
args = parser.parse_args()
return args
if __name__ == "__main__":
args = get_args()
task = args.task.lower()
assert task in ["pathqueryfb", "pathquerywn"]
raw_train_file = os.path.join(args.dir, args.train)
raw_dev_file = os.path.join(args.dir, args.valid)
raw_test_file = os.path.join(args.dir, args.test)
new_train_file = os.path.join(args.dir, "train.coke.txt")
new_test_file = os.path.join(args.dir, "test.coke.txt")
new_dev_file = os.path.join(args.dir, "dev.coke.txt")
vocab_file = os.path.join(args.dir, "vocab.txt")
sen_candli_file = os.path.join(args.dir, "sen_candli.txt")
trivial_sen_file = os.path.join(args.dir, "trivial_sen.txt")
train_base_file = os.path.join(args.dir, "train.base.txt")
test_base_file = os.path.join(args.dir, "test.base.txt")
dev_base_file = os.path.join(args.dir, "dev.base.txt")
pathquery_data_preprocess(raw_train_file, raw_dev_file, raw_test_file,
vocab_file, sen_candli_file, trivial_sen_file,
new_train_file, new_dev_file, new_test_file,
train_base_file, dev_base_file, test_base_file)
"""Mask, padding and batching."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def mask(input_tokens, input_mask_type, max_len, mask_id):
"""
Apply masking to input_tokens; return output_tokens, mask_label, mask_pos.
Note: mask_pos refers to positions in the batch tokens after padding.
"""
output_tokens = []
mask_label = []
mask_pos = []
for sent_index, sent in enumerate(input_tokens):
mask_type = input_mask_type[sent_index]
if mask_type == "MASK_HEAD":
token_index = 0
mask_label.append(sent[token_index])
mask_pos.append(sent_index * max_len + token_index)
sent_out = sent[:]
sent_out[token_index] = mask_id
output_tokens.append(sent_out)
elif mask_type == "MASK_TAIL":
token_index = len(sent) - 1
mask_label.append(sent[token_index])
mask_pos.append(sent_index * max_len + token_index)
sent_out = sent[:]
sent_out[token_index] = mask_id
output_tokens.append(sent_out)
else:
raise ValueError(
"Unknown mask type, which should be in ['MASK_HEAD', 'MASK_TAIL']."
)
mask_label = np.array(mask_label).astype("int64").reshape([-1, 1])
mask_pos = np.array(mask_pos).astype("int64").reshape([-1, 1])
return output_tokens, mask_label, mask_pos
def pad_batch_data(insts,
max_len,
pad_idx=0,
return_pos=False,
return_input_mask=False):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and input mask.
"""
return_list = []
# Any token included in dict can be used to pad, since the paddings' loss
# will be masked out by weights and make no effect on parameter gradients.
inst_data = np.array([
list(inst) + list([pad_idx] * (max_len - len(inst))) for inst in insts
])
return_list += [inst_data.astype("int64").reshape([-1, max_len, 1])]
# position data
if return_pos:
inst_pos = np.array([
list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
for inst in insts
])
return_list += [inst_pos.astype("int64").reshape([-1, max_len, 1])]
if return_input_mask:
# This is used to avoid attention on paddings.
input_mask_data = np.array([[1] * len(inst) + [0] *
(max_len - len(inst)) for inst in insts])
input_mask_data = np.expand_dims(input_mask_data, axis=-1)
return_list += [input_mask_data.astype("float32")]
return return_list if len(return_list) > 1 else return_list[0]
def prepare_batch_data(insts, max_len, pad_id=None, mask_id=None):
""" masking, padding, turn list data into numpy arrays, for batch examples
"""
batch_src_ids = [inst[0] for inst in insts]
batch_mask_type = [inst[1] for inst in insts]
# First step: do mask without padding
if mask_id >= 0:
out, mask_label, mask_pos = mask(
input_tokens=batch_src_ids,
input_mask_type=batch_mask_type,
max_len=max_len,
mask_id=mask_id)
else:
out = batch_src_ids
# Second step: padding and turn into numpy arrays
src_id, pos_id, input_mask = pad_batch_data(
out,
max_len=max_len,
pad_idx=pad_id,
return_pos=True,
return_input_mask=True)
if mask_id >= 0:
return_list = [src_id, pos_id, input_mask, mask_label, mask_pos]
else:
return_list = [src_id, pos_id, input_mask]
return return_list if len(return_list) > 1 else return_list[0]
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" data reader for CoKE
"""
from __future__ import print_function
from __future__ import division
import numpy as np
import six
import collections
import logging
from reader.batching import prepare_batch_data
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S')
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
logger.info(logger.getEffectiveLevel())
RawExample = collections.namedtuple("RawExample", ["token_ids", "mask_type"])
def convert_to_unicode(text):
"""Converts `text` to Unicode (if it's not already), assuming utf-8 input."""
if six.PY3:
if isinstance(text, str):
return text
elif isinstance(text, bytes):
return text.decode("utf-8", "ignore")
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
elif six.PY2:
if isinstance(text, str):
return text.decode("utf-8", "ignore")
elif isinstance(text, unicode):
return text
else:
raise ValueError("Unsupported string type: %s" % (type(text)))
else:
raise ValueError("Not running on Python2 or Python 3?")
#def printable_text(text):
# """Returns text encoded in a way suitable for print or `tf.logging`."""
#
# # These functions want `str` for both Python2 and Python3, but in one case
# # it's a Unicode string and in the other it's a byte string.
# if six.PY3:
# if isinstance(text, str):
# return text
# elif isinstance(text, bytes):
# return text.decode("utf-8", "ignore")
# else:
# raise ValueError("Unsupported string type: %s" % (type(text)))
# elif six.PY2:
# if isinstance(text, str):
# return text
# elif isinstance(text, unicode):
# return text.encode("utf-8")
# else:
# raise ValueError("Unsupported string type: %s" % (type(text)))
# else:
# raise ValueError("Not running on Python2 or Python 3?")
def load_vocab(vocab_file):
"""Loads a vocabulary file into a dictionary."""
vocab = collections.OrderedDict()
fin = open(vocab_file)
for num, line in enumerate(fin):
items = line.strip().split("\t")
if len(items) > 2:
break
token = items[0]
index = items[1] if len(items) == 2 else num
token = token.strip()
vocab[token] = int(index)
return vocab
#def convert_by_vocab(vocab, items):
# """Converts a sequence of [tokens|ids] using the vocab."""
# output = []
# for item in items:
# output.append(vocab[item])
# return output
def convert_tokens_to_ids(vocab, tokens):
"""Converts a sequence of tokens into ids using the vocab."""
output = []
for item in tokens:
output.append(vocab[item])
return output
class KBCDataReader(object):
""" DataReader
"""
def __init__(self,
vocab_path,
data_path,
max_seq_len=3,
batch_size=4096,
is_training=True,
shuffle=True,
dev_count=1,
epoch=10,
vocab_size=-1):
self.vocab = load_vocab(vocab_path)
if vocab_size > 0:
assert len(self.vocab) == vocab_size, \
"Assert Error! Input vocab_size(%d) is not consistant with voab_file(%d)" % \
(vocab_size, len(self.vocab))
self.pad_id = self.vocab["[PAD]"]
self.mask_id = self.vocab["[MASK]"]
self.max_seq_len = max_seq_len
self.batch_size = batch_size
self.is_training = is_training
self.shuffle = shuffle
self.dev_count = dev_count
self.epoch = epoch
if not is_training:
self.shuffle = False
self.dev_count = 1
self.epoch = 1
self.examples = self.read_example(data_path)
self.total_instance = len(self.examples)
self.current_epoch = -1
self.current_instance_index = -1
def get_progress(self):
"""return current progress of traning data
"""
return self.current_instance_index, self.current_epoch
def line2tokens(self, line):
tokens = line.split("\t")
return tokens
def read_example(self, input_file):
"""Reads the input file into a list of examples."""
examples = []
with open(input_file, "r") as f:
for line in f.readlines():
line = convert_to_unicode(line.strip())
tokens = self.line2tokens(line)
assert len(tokens) <= (self.max_seq_len + 1), \
"Expecting at most [max_seq_len + 1]=%d tokens each line, current tokens %d" \
% (self.max_seq_len + 1, len(tokens))
token_ids = convert_tokens_to_ids(self.vocab, tokens[:-1])
if len(token_ids) <= 0:
continue
examples.append(
RawExample(
token_ids=token_ids, mask_type=tokens[-1]))
# if len(examples) <= 10:
# logger.info("*** Example ***")
# logger.info("tokens: %s" % " ".join([printable_text(x) for x in tokens]))
# logger.info("token_ids: %s" % " ".join([str(x) for x in token_ids]))
return examples
def data_generator(self):
""" wrap the batch data generator
"""
range_list = [i for i in range(self.total_instance)]
def wrapper():
""" wrapper batch data
"""
def reader():
for epoch_index in range(self.epoch):
self.current_epoch = epoch_index
if self.shuffle is True:
np.random.shuffle(range_list)
for idx, sample in enumerate(range_list):
self.current_instance_index = idx
yield self.examples[sample]
def batch_reader(reader, batch_size):
"""reader generator for batches of examples
:param reader: reader generator for one example
:param batch_size: int batch size
:return: a list of examples for batch data
"""
batch = []
for example in reader():
token_ids = example.token_ids
mask_type = example.mask_type
example_out = [token_ids] + [mask_type]
to_append = len(batch) < batch_size
if to_append is False:
yield batch
batch = [example_out]
else:
batch.append(example_out)
if len(batch) > 0:
yield batch
all_device_batches = []
for batch_data in batch_reader(reader, self.batch_size):
batch_data = prepare_batch_data(
batch_data,
max_len=self.max_seq_len,
pad_id=self.pad_id,
mask_id=self.mask_id)
if len(all_device_batches) < self.dev_count:
all_device_batches.append(batch_data)
if len(all_device_batches) == self.dev_count:
for batch in all_device_batches:
yield batch
all_device_batches = []
return wrapper
class PathqueryDataReader(KBCDataReader):
def __init__(self,
vocab_path,
data_path,
max_seq_len=3,
batch_size=4096,
is_training=True,
shuffle=True,
dev_count=1,
epoch=10,
vocab_size=-1):
KBCDataReader.__init__(self, vocab_path, data_path, max_seq_len,
batch_size, is_training, shuffle, dev_count,
epoch, vocab_size)
def line2tokens(self, line):
tokens = []
s, path, o, mask_type = line.split("\t")
path_tokens = path.split(",")
tokens.append(s)
tokens.extend(path_tokens)
tokens.append(o)
tokens.append(mask_type)
return tokens
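# Illustrative note (not part of the original file): a path query line is expected to look like
# "<source entity>\t<comma-separated relation path>\t<target entity>\t<mask type>". For a
# hypothetical line "e1\tr1,r2\te2\tMASK", line2tokens() returns
# ["e1", "r1", "r2", "e2", "MASK"]; read_example() then converts all tokens except the last
# one into ids and keeps the last token as the RawExample mask_type.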
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Arguments for configuration."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
import argparse
import logging
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)
def str2bool(v):
# because argparse does not support parsing strings like "True"/"False" as
# Python booleans directly
return v.lower() in ("true", "t", "1")
class ArgumentGroup(object):
def __init__(self, parser, title, des):
self._group = parser.add_argument_group(title=title, description=des)
def add_arg(self, name, type, default, help, **kwargs):
type = str2bool if type == bool else type
self._group.add_argument(
"--" + name,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
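# Illustrative usage sketch (not part of the original file): boolean arguments are parsed
# via str2bool, so command-line values such as "false" behave as expected.
#   parser = argparse.ArgumentParser()
#   model_g = ArgumentGroup(parser, "model", "model configuration")
#   model_g.add_arg("use_cuda", bool, True, "Whether to use GPU.")
#   args = parser.parse_args(["--use_cuda", "false"])  # args.use_cuda == False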
def print_arguments(args):
logger.info('----------- Configuration Arguments -----------')
for arg, value in sorted(six.iteritems(vars(args))):
logger.info('%s: %s' % (arg, value))
logger.info('------------------------------------------------')
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import paddle.fluid as fluid
def cast_fp16_to_fp32(i, o, prog):
prog.global_block().append_op(
type="cast",
inputs={"X": i},
outputs={"Out": o},
attrs={
"in_dtype": fluid.core.VarDesc.VarType.FP16,
"out_dtype": fluid.core.VarDesc.VarType.FP32
})
def cast_fp32_to_fp16(i, o, prog):
prog.global_block().append_op(
type="cast",
inputs={"X": i},
outputs={"Out": o},
attrs={
"in_dtype": fluid.core.VarDesc.VarType.FP32,
"out_dtype": fluid.core.VarDesc.VarType.FP16
})
def copy_to_master_param(p, block):
v = block.vars.get(p.name, None)
if v is None:
raise ValueError("no param name %s found!" % p.name)
new_p = fluid.framework.Parameter(
block=block,
shape=v.shape,
dtype=fluid.core.VarDesc.VarType.FP32,
type=v.type,
lod_level=v.lod_level,
stop_gradient=p.stop_gradient,
trainable=p.trainable,
optimize_attr=p.optimize_attr,
regularizer=p.regularizer,
gradient_clip_attr=p.gradient_clip_attr,
error_clip=p.error_clip,
name=v.name + ".master")
return new_p
def create_master_params_grads(params_grads, main_prog, startup_prog,
loss_scaling):
master_params_grads = []
tmp_role = main_prog._current_role
OpRole = fluid.core.op_proto_and_checker_maker.OpRole
main_prog._current_role = OpRole.Backward
for p, g in params_grads:
# create master parameters
master_param = copy_to_master_param(p, main_prog.global_block())
startup_master_param = startup_prog.global_block()._clone_variable(
master_param)
startup_p = startup_prog.global_block().var(p.name)
cast_fp16_to_fp32(startup_p, startup_master_param, startup_prog)
# cast fp16 gradients to fp32 before apply gradients
if g.name.find("layer_norm") > -1:
if loss_scaling > 1:
scaled_g = g / float(loss_scaling)
else:
scaled_g = g
master_params_grads.append([p, scaled_g])
continue
master_grad = fluid.layers.cast(g, "float32")
if loss_scaling > 1:
master_grad = master_grad / float(loss_scaling)
master_params_grads.append([master_param, master_grad])
main_prog._current_role = tmp_role
return master_params_grads
def master_param_to_train_param(master_params_grads, params_grads, main_prog):
for idx, m_p_g in enumerate(master_params_grads):
train_p, _ = params_grads[idx]
if train_p.name.find("layer_norm") > -1:
continue
with main_prog._optimized_guard([m_p_g[0], m_p_g[1]]):
cast_fp32_to_fp16(m_p_g[0], train_p, main_prog)
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import six
import ast
import copy
import logging
import numpy as np
import paddle.fluid as fluid
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
level=logging.INFO)
logger = logging.getLogger(__name__)
def cast_fp32_to_fp16(exe, main_program):
logger.info("Cast parameters to float16 data format.")
for param in main_program.global_block().all_parameters():
if not param.name.endswith(".master"):
param_t = fluid.global_scope().find_var(param.name).get_tensor()
data = np.array(param_t)
if param.name.find("layer_norm") == -1:
param_t.set(np.float16(data).view(np.uint16), exe.place)
master_param_var = fluid.global_scope().find_var(param.name +
".master")
if master_param_var is not None:
master_param_var.get_tensor().set(data, exe.place)
def init_checkpoint(exe,
init_checkpoint_path,
main_program,
use_fp16=False,
print_var_verbose=False):
assert os.path.exists(
init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path
def existed_persistables(var):
if not fluid.io.is_persistable(var):
return False
return os.path.exists(os.path.join(init_checkpoint_path, var.name))
fluid.io.load_vars(
exe,
init_checkpoint_path,
main_program=main_program,
predicate=existed_persistables)
logger.info("Load model from {}".format(init_checkpoint_path))
if use_fp16:
cast_fp32_to_fp16(exe, main_program)
# Used for debug on parameters
if print_var_verbose is True:
def params(var):
if not isinstance(var, fluid.framework.Parameter):
return False
return True
existed_vars = list(filter(params, main_program.list_vars()))
existed_vars = sorted(existed_vars, key=lambda x: x.name)
for var in existed_vars:
logger.info("var name:{} shape:{}".format(var.name, var.shape))
def init_pretraining_params(exe,
pretraining_params_path,
main_program,
use_fp16=False):
assert os.path.exists(pretraining_params_path
), "[%s] cann't be found." % pretraining_params_path
def existed_params(var):
if not isinstance(var, fluid.framework.Parameter):
return False
return os.path.exists(os.path.join(pretraining_params_path, var.name))
fluid.io.load_vars(
exe,
pretraining_params_path,
main_program=main_program,
predicate=existed_params)
logger.info("Load pretraining parameters from {}.".format(
pretraining_params_path))
if use_fp16:
cast_fp32_to_fp16(exe, main_program)
TASK=fb15k237
NUM_VOCAB=14878 #NUM_VOCAB and NUM_RELATIONS must be consistent with vocab.txt file
NUM_RELATIONS=237
# training hyper-parameters
BATCH_SIZE=512
LEARNING_RATE=5e-4
EPOCH=800
SOFT_LABEL=0.25
SKIP_STEPS=1000
MAX_SEQ_LEN=3
HIDDEN_DROPOUT_PROB=0.5
ATTENTION_PROBS_DROPOUT_PROB=0.5
# file paths for training and evaluation
DATA="./data"
OUTPUT="./output_${TASK}"
TRAIN_FILE="$DATA/${TASK}/train.coke.txt"
VALID_FILE="$DATA/${TASK}/valid.coke.txt"
TEST_FILE="$DATA/${TASK}/test.coke.txt"
VOCAB_PATH="$DATA/${TASK}/vocab.txt"
TRUE_TRIPLE_PATH="${DATA}/${TASK}/all.txt"
CHECKPOINTS="$OUTPUT/models"
INIT_CHECKPOINTS=$CHECKPOINTS
LOG_FILE="$OUTPUT/train.log"
LOG_EVAL_FILE="$OUTPUT/test.log"
# transformer net config; the following are default configs for all tasks
HIDDEN_SIZE=256
NUM_HIDDEN_LAYERS=6
NUM_ATTENTION_HEADS=4
MAX_POSITION_EMBEDDINS=3
TASK=fb15k
NUM_VOCAB=16396 #NUM_VOCAB and NUM_RELATIONS must be consistent with vocab.txt file
NUM_RELATIONS=1345
# training hyper-parameters
BATCH_SIZE=512
LEARNING_RATE=5e-4
EPOCH=300
SOFT_LABEL=0.8
SKIP_STEPS=1000
MAX_SEQ_LEN=3
HIDDEN_DROPOUT_PROB=0.1
ATTENTION_PROBS_DROPOUT_PROB=0.1
# file paths for training and evaluation
DATA="./data"
OUTPUT="./output_${TASK}"
TRAIN_FILE="$DATA/${TASK}/train.coke.txt"
VALID_FILE="$DATA/${TASK}/valid.coke.txt"
TEST_FILE="$DATA/${TASK}/test.coke.txt"
VOCAB_PATH="$DATA/${TASK}/vocab.txt"
TRUE_TRIPLE_PATH="${DATA}/${TASK}/all.txt"
CHECKPOINTS="$OUTPUT/models"
INIT_CHECKPOINTS=$CHECKPOINTS
LOG_FILE="$OUTPUT/train.log"
LOG_EVAL_FILE="$OUTPUT/test.log"
# transformer net config; the following are default configs for all tasks
HIDDEN_SIZE=256
NUM_HIDDEN_LAYERS=6
NUM_ATTENTION_HEADS=4
MAX_POSITION_EMBEDDINS=3
TASK=pathqueryFB
NUM_VOCAB=75169 #NUM_VOCAB and NUM_RELATIONS must be consistent with vocab.txt file
NUM_RELATIONS=26
# training hyper-parameters
BATCH_SIZE=2048
LEARNING_RATE=3e-4
EPOCH=10
SOFT_LABEL=1.0
SKIP_STEPS=200
MAX_SEQ_LEN=7
HIDDEN_DROPOUT_PROB=0.1
ATTENTION_PROBS_DROPOUT_PROB=0.1
# file paths for training and evaluation
DATA="./data"
OUTPUT="./output_$TASK"
TRAIN_FILE="$DATA/${TASK}/train.coke.txt"
VALID_FILE="$DATA/${TASK}/valid.coke.txt"
TEST_FILE="$DATA/${TASK}/test.coke.txt"
VOCAB_PATH="$DATA/${TASK}/vocab.txt"
TRUE_TRIPLE_PATH="${DATA}/${TASK}/all.txt"
SEN_CANDLI_PATH="$DATA/${TASK}/sen_candli.txt"
TRIVAL_SEN_PATH="$DATA/${TASK}/trivial_sen.txt"
CHECKPOINTS="$OUTPUT/models"
LOG_FILE="$OUTPUT/train.log"
LOG_EVAL_FILE="$OUTPUT/test.log"
# transformer net config; the following are default configs for all tasks
HIDDEN_SIZE=256
NUM_HIDDEN_LAYERS=6
NUM_ATTENTION_HEADS=4
MAX_POSITION_EMBEDDINS=7
TASK=pathqueryWN
NUM_VOCAB=38673 #NUM_VOCAB and NUM_RELATIONS must be consistent with vocab.txt file
NUM_RELATIONS=22
# training hyper-parameters
BATCH_SIZE=2048
LEARNING_RATE=3e-4
EPOCH=5
SOFT_LABEL=1.0
SKIP_STEPS=1
MAX_SEQ_LEN=7
HIDDEN_DROPOUT_PROB=0.1
ATTENTION_PROBS_DROPOUT_PROB=0.1
# file paths for training and evaluation
DATA="./data"
OUTPUT="./output_${TASK}_debug"
TRAIN_FILE="$DATA/${TASK}/train.coke.txt"
VALID_FILE="$DATA/${TASK}/valid.txt"
TEST_FILE="$DATA/${TASK}/test.coke.txt"
VOCAB_PATH="$DATA/${TASK}/vocab.txt"
TRUE_TRIPLE_PATH="${DATA}/${TASK}/all.txt"
SEN_CANDLI_PATH="$DATA/${TASK}/sen_candli.txt"
TRIVAL_SEN_PATH="$DATA/${TASK}/trivial_sen.txt"
CHECKPOINTS="$OUTPUT/models"
LOG_FILE="$OUTPUT/train.log"
LOG_EVAL_FILE="$OUTPUT/test.log"
# transformer net config; the following are default configs for all tasks
HIDDEN_SIZE=256
NUM_HIDDEN_LAYERS=6
NUM_ATTENTION_HEADS=4
MAX_POSITION_EMBEDDINS=7
TASK=wn18
NUM_VOCAB=41061 #NUM_VOCAB/NUM_RELATIONS must be consistent with vocab.txt file
NUM_RELATIONS=18
# training hyper-parameters
BATCH_SIZE=512
LEARNING_RATE=5e-4
EPOCH=500
SOFT_LABEL=0.2
SKIP_STEPS=1000
MAX_SEQ_LEN=3
HIDDEN_DROPOUT_PROB=0.1
ATTENTION_PROBS_DROPOUT_PROB=0.1
# file paths for training and evaluation
DATA="./data"
OUTPUT="./output_${TASK}"
TRAIN_FILE="$DATA/${TASK}/train.coke.txt"
VALID_FILE="$DATA/${TASK}/valid.coke.txt"
TEST_FILE="$DATA/${TASK}/test.coke.txt"
VOCAB_PATH="$DATA/${TASK}/vocab.txt"
TRUE_TRIPLE_PATH="${DATA}/${TASK}/all.txt"
CHECKPOINTS="$OUTPUT/models"
INIT_CHECKPOINTS=$CHECKPOINTS
LOG_FILE="$OUTPUT/train.log"
LOG_EVAL_FILE="$OUTPUT/test.log"
# transformer net config; the following are default configs for all tasks
HIDDEN_SIZE=256
NUM_HIDDEN_LAYERS=6
NUM_ATTENTION_HEADS=4
MAX_POSITION_EMBEDDINS=3
TASK=wn18rr
NUM_VOCAB=41054 #NUM_VOCAB/NUM_RELATIONS must be consistent with vocab.txt file
NUM_RELATIONS=11
# training hyper-parameters
BATCH_SIZE=1024
LEARNING_RATE=3e-4
EPOCH=800
SOFT_LABEL=0.15
SKIP_STEPS=1000
MAX_SEQ_LEN=3
HIDDEN_DROPOUT_PROB=0.1
ATTENTION_PROBS_DROPOUT_PROB=0.1
# file paths for training and evaluation
DATA="./data"
OUTPUT="./output_${TASK}"
TRAIN_FILE="$DATA/${TASK}/train.coke.txt"
VALID_FILE="$DATA/${TASK}/valid.coke.txt"
TEST_FILE="$DATA/${TASK}/test.coke.txt"
VOCAB_PATH="$DATA/${TASK}/vocab.txt"
TRUE_TRIPLE_PATH="${DATA}/${TASK}/all.txt"
CHECKPOINTS="$OUTPUT/models"
INIT_CHECKPOINTS=$CHECKPOINTS
LOG_FILE="$OUTPUT/train.log"
LOG_EVAL_FILE="$OUTPUT/test.log"
# transformer net config; the following are default configs for all tasks
HIDDEN_SIZE=256
NUM_HIDDEN_LAYERS=6
NUM_ATTENTION_HEADS=4
MAX_POSITION_EMBEDDINS=3
set -eu
set -o pipefail
# Attention! Python 2.7.14 and Python 3 give different vocabulary orders. We use Python 2.7.14 to preprocess files.
# input files: train.txt valid.txt test.txt
# (these are the default filenames; change them with the following arguments: --train $trainname --valid $validname --test $testname)
# output files: vocab.txt train.coke.txt valid.coke.txt test.coke.txt
python ./bin/kbc_data_preprocess.py --task fb15k --dir ./data/fb15k
python ./bin/kbc_data_preprocess.py --task wn18 --dir ./data/wn18
python ./bin/kbc_data_preprocess.py --task fb15k237 --dir ./data/fb15k237
python ./bin/kbc_data_preprocess.py --task wn18rr --dir ./data/wn18rr
# input files: train dev test
# (these are the default filenames; change them with the following arguments: --train $trainname --valid $validname --test $testname)
# output files: vocab.txt train.coke.txt valid.coke.txt test.coke.txt sen_candli.txt trivial_sen.txt
python ./bin/pathquery_data_preprocess.py --task pathqueryFB --dir ./data/pathqueryFB
python ./bin/pathquery_data_preprocess.py --task pathqueryWN --dir ./data/pathqueryWN
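# Illustrative example (not part of the original script): if your raw files use non-default
# names, pass them explicitly with the --train/--valid/--test arguments mentioned above,
# e.g. for hypothetical files my_train.txt/my_valid.txt/my_test.txt under ./data/fb15k:
# python ./bin/kbc_data_preprocess.py --task fb15k --dir ./data/fb15k \
#     --train my_train.txt --valid my_valid.txt --test my_test.txt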
#! /bin/bash
#==========
set -e
set -x
set -u
set -o pipefail
#==========
#==========configs
conf_fp=$1
CUDA=$2
source $conf_fp
#=========init env
export CUDA_VISIBLE_DEVICES=$CUDA
export FLAGS_sync_nccl_allreduce=1
#modify to your own path
export LD_LIBRARY_PATH=$(pwd)/env/lib/nccl2.3.7_cuda9.0/lib:/home/work/cudnn/cudnn_v7/cuda/lib64:/home/work/cuda-9.0/extras/CUPTI/lib64/:/home/work/cuda-9.0/lib64/:$LD_LIBRARY_PATH
#====== begin training
if [ -d $OUTPUT ]; then
rm -rf $OUTPUT
fi
mkdir $OUTPUT
max_step_id=`ls $INIT_CHECKPOINTS | grep "step" | awk -F"_" '{print $NF}' | grep -v "Found" |sort -n |tail -1`
INIT_CHECKPOINT_STEP=${INIT_CHECKPOINTS}/step_${max_step_id}
echo "init_checkpoints_steps: $max_step_id"
#--init_checkpoint ${INIT_CHECKPOINT}
echo ">> Begin kbc test now, log file: $LOG_EVAL_FILE"
python3 -u ./bin/run.py \
--dataset $TASK \
--vocab_size $NUM_VOCAB \
--num_relations $NUM_RELATIONS \
--use_cuda true \
--do_train false \
--train_file $TRAIN_FILE \
--checkpoints $CHECKPOINTS \
--init_checkpoint ${INIT_CHECKPOINT_STEP} \
--true_triple_path $TRUE_TRIPLE_PATH \
--max_seq_len $MAX_SEQ_LEN \
--soft_label $SOFT_LABEL \
--batch_size $BATCH_SIZE \
--epoch $EPOCH \
--learning_rate $LEARNING_RATE \
--hidden_dropout_prob $HIDDEN_DROPOUT_PROB \
--attention_probs_dropout_prob $ATTENTION_PROBS_DROPOUT_PROB \
--skip_steps $SKIP_STEPS \
--do_predict true \
--predict_file $TEST_FILE \
--vocab_path $VOCAB_PATH \
--hidden_size $HIDDEN_SIZE \
--num_hidden_layers $NUM_HIDDEN_LAYERS \
--num_attention_heads $NUM_ATTENTION_HEADS \
--max_position_embeddings $MAX_POSITION_EMBEDDINS \
--use_ema false > $LOG_EVAL_FILE 2>&1
echo ">> Finish kbc test, log file: $LOG_EVAL_FILE"
#! /bin/bash
#==========
set -exu
set -o pipefail
#==========
#==========configs
CONF_FP=$1
CUDA_ID=$2
#=========init env
source $CONF_FP
export CUDA_VISIBLE_DEVICES=$CUDA_ID
export FLAGS_sync_nccl_allreduce=1
#modify to your own path
export LD_LIBRARY_PATH=$(pwd)/env/lib/nccl2.3.7_cuda9.0/lib:/home/work/cudnn/cudnn_v7/cuda/lib64:/home/work/cuda-9.0/extras/CUPTI/lib64/:/home/work/cuda-9.0/lib64/:$LD_LIBRARY_PATH
#=========running paths
if [ -d $OUTPUT ]; then
rm -rf $OUTPUT
fi
mkdir $OUTPUT
#====== begin training
echo ">> Begin kbc train now"
python3 -u ./bin/run.py \
--dataset $TASK \
--vocab_size $NUM_VOCAB \
--num_relations $NUM_RELATIONS \
--use_cuda true \
--do_train true \
--train_file $TRAIN_FILE \
--true_triple_path $TRUE_TRIPLE_PATH \
--max_seq_len $MAX_SEQ_LEN \
--checkpoints $CHECKPOINTS \
--soft_label $SOFT_LABEL \
--batch_size $BATCH_SIZE \
--epoch $EPOCH \
--learning_rate $LEARNING_RATE \
--hidden_dropout_prob $HIDDEN_DROPOUT_PROB \
--attention_probs_dropout_prob $ATTENTION_PROBS_DROPOUT_PROB \
--skip_steps $SKIP_STEPS \
--do_predict false \
--vocab_path $VOCAB_PATH \
--hidden_size $HIDDEN_SIZE \
--num_hidden_layers $NUM_HIDDEN_LAYERS \
--num_attention_heads $NUM_ATTENTION_HEADS \
--max_position_embeddings $MAX_POSITION_EMBEDDINS \
--use_ema false > $LOG_FILE 2>&1
#! /bin/bash
set -exu
set -o pipefail
#configs
CONF_FP=$1
CUDA_ID=$2
source $CONF_FP
export CUDA_VISIBLE_DEVICES=$CUDA_ID
export FLAGS_sync_nccl_allreduce=1
# todo: modify to your own path
export LD_LIBRARY_PATH=$(pwd)/env/lib/nccl2.3.7_cuda9.0/lib:/home/work/cudnn/cudnn_v7/cuda/lib64:/home/work/cuda-9.0/extras/CUPTI/lib64/:/home/work/cuda-9.0/lib64/:$LD_LIBRARY_PATH
max_step_id=`ls $CHECKPOINTS | grep "step" | awk -F"_" '{print $NF}' | grep -v "Found" |sort -n |tail -1`
INIT_CHECKPOINT_STEP=${CHECKPOINTS}/step_${max_step_id}
echo "max_step_id: $max_step_id"
echo ">> Begin predict now"
python3 -u ./bin/run.py \
--dataset $TASK \
--vocab_size $NUM_VOCAB \
--num_relations $NUM_RELATIONS \
--use_cuda true \
--do_train false \
--do_predict true \
--predict_file $TEST_FILE \
--init_checkpoint ${INIT_CHECKPOINT_STEP} \
--batch_size $BATCH_SIZE \
--vocab_path $VOCAB_PATH \
--sen_candli_file $SEN_CANDLI_PATH \
--sen_trivial_file $TRIVAL_SEN_PATH \
--max_seq_len $MAX_SEQ_LEN \
--learning_rate $LEARNING_RATE \
--use_ema false > $LOG_EVAL_FILE 2>&1
#! /bin/bash
set -exu
set -o pipefail
#configs
CONF_FP=$1
CUDA_ID=$2
source $CONF_FP
export CUDA_VISIBLE_DEVICES=$CUDA_ID
export FLAGS_sync_nccl_allreduce=1
# todo: modify to your own path
export LD_LIBRARY_PATH=$(pwd)/env/lib/nccl2.3.7_cuda9.0/lib:/home/work/cudnn/cudnn_v7/cuda/lib64:/home/work/cuda-9.0/extras/CUPTI/lib64/:/home/work/cuda-9.0/lib64/:$LD_LIBRARY_PATH
# prepare output directory
if [ -d $OUTPUT ]; then
rm -rf $OUTPUT
fi
mkdir $OUTPUT
# begin training
echo ">> Begin train now"
python3 -u ./bin/run.py \
--dataset $TASK \
--vocab_size $NUM_VOCAB \
--num_relations $NUM_RELATIONS \
--use_cuda true \
--do_train true \
--do_predict false \
--train_file $TRAIN_FILE \
--predict_file $TEST_FILE \
--max_seq_len $MAX_SEQ_LEN \
--checkpoints $CHECKPOINTS \
--soft_label $SOFT_LABEL \
--batch_size $BATCH_SIZE \
--epoch $EPOCH \
--learning_rate $LEARNING_RATE \
--hidden_dropout_prob $HIDDEN_DROPOUT_PROB \
--attention_probs_dropout_prob $ATTENTION_PROBS_DROPOUT_PROB \
--skip_steps $SKIP_STEPS \
--vocab_path $VOCAB_PATH \
--hidden_size $HIDDEN_SIZE \
--sen_candli_file $SEN_CANDLI_PATH \
--sen_trivial_file $TRIVAL_SEN_PATH \
--num_hidden_layers $NUM_HIDDEN_LAYERS \
--num_attention_heads $NUM_ATTENTION_HEADS \
--max_position_embeddings $MAX_POSITION_EMBEDDINS \
--use_ema false > $LOG_FILE 2>&1
#!/bin/bash
mkdir data
pushd ./ && cd ./data
## download the 4 widely used KBC datasets
wget --no-check-certificate https://everest.hds.utc.fr/lib/exe/fetch.php?media=en:fb15k.tgz -O fb15k.tgz
wget --no-check-certificate https://everest.hds.utc.fr/lib/exe/fetch.php?media=en:wordnet-mlj12.tar.gz -O wordnet-mlj12.tar.gz
wget --no-check-certificate https://download.microsoft.com/download/8/7/0/8700516A-AB3D-4850-B4BB-805C515AECE1/FB15K-237.2.zip -O FB15K-237.2.zip
wget --no-check-certificate https://raw.githubusercontent.com/TimDettmers/ConvE/master/WN18RR.tar.gz -O WN18RR.tar.gz
## download the 2 path query datasets
wget --no-check-certificate https://worksheets.codalab.org/rest/bundles/0xdb6b691c2907435b974850e8eb9a5fc2/contents/blob/ -O freebase_paths.tar.gz
wget --no-check-certificate https://worksheets.codalab.org/rest/bundles/0xf91669f6c6d74987808aeb79bf716bd0/contents/blob/ -O wordnet_paths.tar.gz
## organize the train/valid/test files by renaming
#fb15k
tar -xvf fb15k.tgz
mv FB15k fb15k
mv ./fb15k/freebase_mtr100_mte100-train.txt ./fb15k/train.txt
mv ./fb15k/freebase_mtr100_mte100-test.txt ./fb15k/test.txt
mv ./fb15k/freebase_mtr100_mte100-valid.txt ./fb15k/valid.txt
#wn18
tar -zxvf wordnet-mlj12.tar.gz && mv wordnet-mlj12 wn18
mv wn18/wordnet-mlj12-train.txt wn18/train.txt
mv wn18/wordnet-mlj12-test.txt wn18/test.txt
mv wn18/wordnet-mlj12-valid.txt wn18/valid.txt
#fb15k237
unzip FB15K-237.2.zip && mv Release fb15k237
#wn18rr
mkdir wn18rr && tar -zxvf WN18RR.tar.gz -C wn18rr
#pathqueryWN
mkdir pathqueryWN && tar -zxvf wordnet_paths.tar.gz -C pathqueryWN
#pathqueryFB
mkdir pathqueryFB && tar -zxvf freebase_paths.tar.gz -C pathqueryFB
##rm tmp zip files
# rm ./*.gz
# rm ./*.tgz
# rm ./*.zip
popd
wget --no-check-certificate https://baidu-kg.bj.bcebos.com/CoKE/kbc_models.tar.gz
tar -zxvf kbc_models.tar.gz
rm kbc_models.tar.gz
This work has been moved to a new address: [PaddleKG](https://github.com/PaddlePaddle/Research/tree/master/KG)
# __Deep Attention Matching Network__
## Introduction
### Task Description
The Deep Attention Matching Network (DAM) is a response matching model for open-domain multi-turn dialogue. Given the multi-turn dialogue history and candidate responses, it ranks the candidates and selects the most suitable response.
The network structure is shown below; for more details, please refer to the paper: [http://aclweb.org/anthology/P18-1103](http://aclweb.org/anthology/P18-1103).
<p align="center">
<img src="images/Figure1.png"/> <br />
Overview of Deep Attention Matching Network
</p>
### Performance
Results of the model on two public datasets are shown below:
<p align="center">
<img src="images/Figure2.png"/> <br />
</p>
We also recommend the [IPython Notebook demo](https://aistudio.baidu.com/aistudio/projectDetail/122287).
## Quick Start
### Installation
1. Install PaddlePaddle
This project requires Paddle Fluid 1.3.1 or later; please follow the [installation guide](http://www.paddlepaddle.org/#quick-start) to install it.
2. Download the code
Clone the repository to your local machine:
```
git clone https://github.com/PaddlePaddle/models.git
cd models/PaddleNLP/dialogue_model_toolkit/deep_attention_matching
```
3. Environment dependencies
Python 2.7 is required.
### Running the model for the first time
1. Data preparation
Download the preprocessed data. After running the script below, the data directory will contain two folders: ubuntu and douban.
```
cd data
sh download_data.sh
```
2. Model training
```
python -u main.py \
--do_train True \
--use_cuda \
--data_path ./data/ubuntu/data_small.pkl \
--save_path ./model_files/ubuntu \
--use_pyreader \
--vocab_size 434512 \
--_EOS_ 28270 \
--batch_size 32
```
3. Model evaluation
```
python -u main.py \
--do_test True \
--use_cuda \
--data_path ./data/ubuntu/data_small.pkl \
--save_path ./model_files/ubuntu/step_372 \
--model_path ./model_files/ubuntu/step_372 \
--vocab_size 434512 \
--_EOS_ 28270 \
--batch_size 100
```
## Advanced Usage
### Task definition and modeling
The input of the multi-turn response matching task is the multi-turn dialogue history and a candidate response; the output is a matching score, and candidate responses are ranked by this score.
### Model overview
Please refer to the paper: [http://aclweb.org/anthology/P18-1103](http://aclweb.org/anthology/P18-1103).
### Data format
The data used for training, prediction, and evaluation looks like the example below. Each sample consists of three tab-separated ('\t') columns: the first column is the space-separated context token ids, the second column is the space-separated response token ids, and the third column is the label.
```
286 642 865 36 87 25 693 0
17 54 975 512 775 54 6 1
```
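A minimal parsing sketch (ours, for illustration only; `parse_line` is not part of this toolkit) that splits one such tab-separated sample into context ids, response ids, and the label:
```
def parse_line(line):
    """Split one tab-separated sample into (context_ids, response_ids, label)."""
    context, response, label = line.rstrip("\n").split("\t")
    return ([int(t) for t in context.split()],
            [int(t) for t in response.split()],
            int(label))

# hypothetical line: parse_line("1 2 3 4\t5 6 7\t0") -> ([1, 2, 3, 4], [5, 6, 7], 0)
```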
Note: this project also provides a tokenization preprocessing script (in the preprocess directory). It can be used as follows:
```
python tokenizer.py \
--test_data_dir ./test.txt.utf8 \
--batch_size 1 > test.txt.utf8.seg
```
### Code structure
main.py: the main entry of the project, wrapping training and prediction
config.py: model configuration, including the model type and hyper-parameters
reader.py: data reading and vocabulary loading
evaluation.py: evaluation functions
run.sh: script for training and prediction
## Others
How to contribute
If you can fix an issue or add a new feature, you are welcome to submit a PR. If the PR is accepted, we will score the contribution (0-5 points, higher is better) based on its quality and difficulty. Once you accumulate 10 points, you can contact us for an interview opportunity or a recommendation letter.
This work has been moved to a new address: [NLP](https://github.com/PaddlePaddle/Research/tree/master/NLP)
# this file is only used for continuous evaluation test!
import os
import sys
sys.path.append(os.environ['ceroot'])
from kpi import CostKpi
from kpi import DurationKpi
train_cost_card1 = CostKpi('train_cost_card1', 0.02, 0, actived=True)
train_cost_card4 = CostKpi('train_cost_card4', 0.06, 0, actived=True)
train_duration_card1 = DurationKpi('train_duration_card1', 0.01, 0, actived=True)
train_duration_card4 = DurationKpi('train_duration_card4', 0.01, 0, actived=True)
tracking_kpis = [
train_cost_card1,
train_cost_card4,
train_duration_card1,
train_duration_card4,
]
def parse_log(log):
'''
This method should be implemented by model developers.
The suggestion: each KPI line in the log should contain three tab-separated
fields, "kpis\t<kpi_name>\t<kpi_value>", for example:
"
kpis\ttrain_cost_card1\t1.0
kpis\ttrain_duration_card1\t0.5
"
'''
for line in log.split('\n'):
fs = line.strip().split('\t')
print(fs)
if len(fs) == 3 and fs[0] == 'kpis':
kpi_name = fs[1]
kpi_value = float(fs[2])
yield kpi_name, kpi_value
def log_to_ce(log):
kpi_tracker = {}
for kpi in tracking_kpis:
kpi_tracker[kpi.name] = kpi
for (kpi_name, kpi_value) in parse_log(log):
print(kpi_name, kpi_value)
kpi_tracker[kpi_name].add_record(kpi_value)
kpi_tracker[kpi_name].persist()
if __name__ == '__main__':
log = sys.stdin.read()
log_to_ce(log)
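# Illustrative usage (assumption, not part of the original file): pipe a training log whose
# KPI lines look like "kpis\t<kpi_name>\t<kpi_value>" into this script via stdin, e.g.
#   cat train.log | python <this_script>.py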
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Deep Attention Matching Network
"""
import argparse
import six
def parse_args():
"""
Deep Attention Matching Network Config
"""
parser = argparse.ArgumentParser("DAM Config")
parser.add_argument(
'--do_train',
type=bool,
default=False,
help='Whether to perform training.')
parser.add_argument(
'--do_test',
type=bool,
default=False,
help='Whether to perform testing.')
parser.add_argument(
'--batch_size',
type=int,
default=256,
help='Batch size for training. (default: %(default)d)')
parser.add_argument(
'--num_scan_data',
type=int,
default=2,
help='Number of pass for training. (default: %(default)d)')
parser.add_argument(
'--learning_rate',
type=float,
default=1e-3,
help='Learning rate used to train. (default: %(default)f)')
parser.add_argument(
'--data_path',
type=str,
default="data/data_small.pkl",
help='Path to training data. (default: %(default)s)')
parser.add_argument(
'--save_path',
type=str,
default="saved_models",
help='Path to save trained models. (default: %(default)s)')
parser.add_argument(
'--model_path',
type=str,
default=None,
help='Path to load well-trained models. (default: %(default)s)')
parser.add_argument(
'--use_cuda',
action='store_true',
help='If set, use cuda for training.')
parser.add_argument(
'--use_pyreader',
action='store_true',
help='If set, use pyreader for reading data.')
parser.add_argument(
'--ext_eval',
action='store_true',
help='If set, use MAP, MRR, etc. for evaluation.')
parser.add_argument(
'--max_turn_num',
type=int,
default=9,
help='Maximum number of utterances in context.')
parser.add_argument(
'--max_turn_len',
type=int,
default=50,
help='Maximum length of sentences in turns.')
parser.add_argument(
'--word_emb_init',
type=str,
default=None,
help='Path to the initial word embedding.')
parser.add_argument(
'--vocab_size',
type=int,
default=434512,
help='The size of vocabulary.')
parser.add_argument(
'--emb_size',
type=int,
default=200,
help='The dimension of word embedding.')
parser.add_argument(
'--_EOS_',
type=int,
default=28270,
help='The id for the end of sentence in vocabulary.')
parser.add_argument(
'--stack_num',
type=int,
default=5,
help='The number of stacked attentive modules in network.')
parser.add_argument(
'--channel1_num',
type=int,
default=32,
help="The channels' number of the 1st conv3d layer's output.")
parser.add_argument(
'--channel2_num',
type=int,
default=16,
help="The channels' number of the 2nd conv3d layer's output.")
args = parser.parse_args()
return args
def print_arguments(args):
"""
Print Config
"""
print('----------- Configuration Arguments -----------')
for arg, value in sorted(six.iteritems(vars(args))):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
ubuntu_url=http://dam-data.cdn.bcebos.com/ubuntu.tar.gz
ubuntu_md5=9d7db116a040530a16f68dc0ab44e4b6
if [ ! -e ubuntu.tar.gz ]; then
wget -c $ubuntu_url
fi
echo "Checking md5 sum ..."
md5sum_tmp=`md5sum ubuntu.tar.gz | cut -d ' ' -f1`
if [ $md5sum_tmp != $ubuntu_md5 ]; then
echo "Md5sum check failed, please remove and redownload ubuntu.tar.gz"
exit 1
fi
echo "Untar ubuntu.tar.gz ..."
tar -xzvf ubuntu.tar.gz
mv data ubuntu
douban_url=http://dam-data.cdn.bcebos.com/douban.tar.gz
douban_md5=e07ca68f21c20e09efb3e8b247194405
if [ ! -e douban.tar.gz ]; then
wget -c $douban_url
fi
echo "Checking md5 sum ..."
md5sum_tmp=`md5sum douban.tar.gz | cut -d ' ' -f1`
if [ $md5sum_tmp != $douban_md5 ]; then
echo "Md5sum check failed, please remove and redownload douban.tar.gz"
exit 1
fi
echo "Untar douban.tar.gz ..."
tar -xzvf douban.tar.gz
mv data douban
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Evaluation
"""
import sys
import six
import numpy as np
def evaluate_ubuntu(file_path):
"""
Evaluate on ubuntu data
"""
def get_p_at_n_in_m(data, n, m, ind):
"""
Recall n at m
"""
pos_score = data[ind][0]
curr = data[ind:ind + m]
curr = sorted(curr, key=lambda x: x[0], reverse=True)
if curr[n - 1][0] <= pos_score:
return 1
return 0
data = []
with open(file_path, 'r') as file:
for line in file:
line = line.strip()
tokens = line.split("\t")
if len(tokens) != 2:
continue
data.append((float(tokens[0]), int(tokens[1])))
#assert len(data) % 10 == 0
p_at_1_in_2 = 0.0
p_at_1_in_10 = 0.0
p_at_2_in_10 = 0.0
p_at_5_in_10 = 0.0
length = len(data) // 10
for i in six.moves.xrange(0, length):
ind = i * 10
assert data[ind][1] == 1
p_at_1_in_2 += get_p_at_n_in_m(data, 1, 2, ind)
p_at_1_in_10 += get_p_at_n_in_m(data, 1, 10, ind)
p_at_2_in_10 += get_p_at_n_in_m(data, 2, 10, ind)
p_at_5_in_10 += get_p_at_n_in_m(data, 5, 10, ind)
result_dict = {
"1_in_2": p_at_1_in_2 / length,
"1_in_10": p_at_1_in_10 / length,
"2_in_10": p_at_2_in_10 / length,
"5_in_10": p_at_5_in_10 / length
}
return result_dict
def evaluate_douban(file_path):
"""
Evaluate douban data
"""
def mean_average_precision(sort_data):
"""
Evaluate mean average precision
"""
count_1 = 0
sum_precision = 0
for index in six.moves.xrange(len(sort_data)):
if sort_data[index][1] == 1:
count_1 += 1
sum_precision += 1.0 * count_1 / (index + 1)
return sum_precision / count_1
def mean_reciprocal_rank(sort_data):
"""
Evaluate MRR
"""
sort_label = [s_d[1] for s_d in sort_data]
assert 1 in sort_label
return 1.0 / (1 + sort_label.index(1))
def precision_at_position_1(sort_data):
"""
Evaluate precision
"""
if sort_data[0][1] == 1:
return 1
else:
return 0
def recall_at_position_k_in_10(sort_data, k):
""""
Evaluate recall
"""
sort_lable = [s_d[1] for s_d in sort_data]
select_lable = sort_lable[:k]
return 1.0 * select_lable.count(1) / sort_lable.count(1)
def evaluation_one_session(data):
"""
Evaluate one session
"""
sort_data = sorted(data, key=lambda x: x[0], reverse=True)
m_a_p = mean_average_precision(sort_data)
m_r_r = mean_reciprocal_rank(sort_data)
p_1 = precision_at_position_1(sort_data)
r_1 = recall_at_position_k_in_10(sort_data, 1)
r_2 = recall_at_position_k_in_10(sort_data, 2)
r_5 = recall_at_position_k_in_10(sort_data, 5)
return m_a_p, m_r_r, p_1, r_1, r_2, r_5
sum_m_a_p = 0
sum_m_r_r = 0
sum_p_1 = 0
sum_r_1 = 0
sum_r_2 = 0
sum_r_5 = 0
i = 0
total_num = 0
with open(file_path, 'r') as infile:
for line in infile:
if i % 10 == 0:
data = []
tokens = line.strip().split('\t')
data.append((float(tokens[0]), int(tokens[1])))
if i % 10 == 9:
total_num += 1
m_a_p, m_r_r, p_1, r_1, r_2, r_5 = evaluation_one_session(data)
sum_m_a_p += m_a_p
sum_m_r_r += m_r_r
sum_p_1 += p_1
sum_r_1 += r_1
sum_r_2 += r_2
sum_r_5 += r_5
i += 1
result_dict = {
"MAP": 1.0 * sum_m_a_p / total_num,
"MRR": 1.0 * sum_m_r_r / total_num,
"P_1": 1.0 * sum_p_1 / total_num,
"1_in_10": 1.0 * sum_r_1 / total_num,
"2_in_10": 1.0 * sum_r_2 / total_num,
"5_in_10": 1.0 * sum_r_5 / total_num
}
return result_dict
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Layers
"""
import paddle.fluid as fluid
def loss(x, y, clip_value=10.0):
"""Calculate the sigmoid cross entropy with logits for input(x).
Args:
x: Variable with shape [batch, dim]
y: Input label
Returns:
loss: cross entropy
logits: prediction
"""
logits = fluid.layers.fc(
input=x,
size=1,
bias_attr=fluid.ParamAttr(initializer=fluid.initializer.Constant(0.)))
loss = fluid.layers.sigmoid_cross_entropy_with_logits(x=logits, label=y)
loss = fluid.layers.reduce_mean(
fluid.layers.clip(
loss, min=-clip_value, max=clip_value))
return loss, logits
def ffn(input, d_inner_hid, d_hid, name=None):
"""Position-wise Feed-Forward Network
"""
hidden = fluid.layers.fc(input=input,
size=d_inner_hid,
num_flatten_dims=2,
param_attr=fluid.ParamAttr(name=name + '_fc.w_0'),
bias_attr=fluid.ParamAttr(
name=name + '_fc.b_0',
initializer=fluid.initializer.Constant(0.)),
act="relu")
out = fluid.layers.fc(input=hidden,
size=d_hid,
num_flatten_dims=2,
param_attr=fluid.ParamAttr(name=name + '_fc.w_1'),
bias_attr=fluid.ParamAttr(
name=name + '_fc.b_1',
initializer=fluid.initializer.Constant(0.)))
return out
def dot_product_attention(query,
key,
value,
d_key,
q_mask=None,
k_mask=None,
dropout_rate=None,
mask_cache=None):
"""Dot product layer.
Args:
query: a tensor with shape [batch, Q_time, Q_dimension]
key: a tensor with shape [batch, time, K_dimension]
value: a tensor with shape [batch, time, V_dimension]
q_mask: a tensor with shape [batch, Q_time, 1]
k_mask: a tensor with shape [batch, time, 1]
Returns:
a tensor with shape [batch, query_time, value_dimension]
Raises:
AssertionError: if Q_dimension not equal to K_dimension when attention
type is dot.
"""
logits = fluid.layers.matmul(
x=query, y=key, transpose_y=True, alpha=d_key**(-0.5))
if (q_mask is not None) and (k_mask is not None):
if mask_cache is not None and q_mask.name in mask_cache and k_mask.name in mask_cache[
q_mask.name]:
mask, another_mask = mask_cache[q_mask.name][k_mask.name]
else:
mask = fluid.layers.matmul(x=q_mask, y=k_mask, transpose_y=True)
another_mask = fluid.layers.scale(
mask,
scale=float(2**32 - 1),
bias=float(-1),
bias_after_scale=False)
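# note (added comment): another_mask equals (mask - 1) * (2**32 - 1), i.e. 0 for valid
# query/key pairs and a very large negative value for masked pairs, so that
# "mask * logits + another_mask" below pushes masked positions towards -inf before softmax.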
if mask_cache is not None:
if q_mask.name not in mask_cache:
mask_cache[q_mask.name] = dict()
mask_cache[q_mask.name][k_mask.name] = [mask, another_mask]
logits = mask * logits + another_mask
attention = fluid.layers.softmax(logits)
if dropout_rate:
attention = fluid.layers.dropout(
input=attention, dropout_prob=dropout_rate, is_test=False, seed=2)
atten_out = fluid.layers.matmul(x=attention, y=value)
return atten_out
def block(name,
query,
key,
value,
d_key,
q_mask=None,
k_mask=None,
is_layer_norm=True,
dropout_rate=None,
mask_cache=None):
"""
Block
"""
att_out = dot_product_attention(
query,
key,
value,
d_key,
q_mask,
k_mask,
dropout_rate,
mask_cache=mask_cache)
y = query + att_out
if is_layer_norm:
y = fluid.layers.layer_norm(
input=y,
begin_norm_axis=len(y.shape) - 1,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(1.),
name=name + '_layer_norm.w_0'),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.),
name=name + '_layer_norm.b_0'))
z = ffn(y, d_key, d_key, name)
w = y + z
if is_layer_norm:
w = fluid.layers.layer_norm(
input=w,
begin_norm_axis=len(w.shape) - 1,
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(1.),
name=name + '_layer_norm.w_1'),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(0.),
name=name + '_layer_norm.b_1'))
return w
def cnn_3d(input, out_channels_0, out_channels_1, add_relu=True):
"""
CNN-3d
"""
# same padding
conv_0 = fluid.layers.conv3d(
name="conv3d_0",
input=input,
num_filters=out_channels_0,
filter_size=[3, 3, 3],
padding=[1, 1, 1],
act="elu" if add_relu else None,
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(
low=-0.01, high=0.01)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.0)))
# same padding
pooling_0 = fluid.layers.pool3d(
input=conv_0,
pool_type="max",
pool_size=3,
pool_padding=1,
pool_stride=3)
conv_1 = fluid.layers.conv3d(
name="conv3d_1",
input=pooling_0,
num_filters=out_channels_1,
filter_size=[3, 3, 3],
padding=[1, 1, 1],
act="elu" if add_relu else None,
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(
low=-0.01, high=0.01)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.0)))
# same padding
pooling_1 = fluid.layers.pool3d(
input=conv_1,
pool_type="max",
pool_size=3,
pool_padding=1,
pool_stride=3)
return pooling_1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Deep Attention Matching Network
"""
import sys
import os
import six
import numpy as np
import time
import multiprocessing
import paddle
import paddle.fluid as fluid
import reader as reader
from util import mkdir
import evaluation as eva
import config
try:
import cPickle as pickle #python 2
except ImportError as e:
import pickle #python 3
from model_check import check_cuda
from net import Net
def evaluate(score_path, result_file_path):
"""
Evaluate both douban and ubuntu dataset
"""
if args.ext_eval:
result = eva.evaluate_douban(score_path)
else:
result = eva.evaluate_ubuntu(score_path)
#write evaluation result
with open(result_file_path, 'w') as out_file:
for p_at in result:
out_file.write(p_at + '\t' + str(result[p_at]) + '\n')
print('finish evaluation')
print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
def test_with_feed(exe, program, feed_names, fetch_list, score_path, batches,
batch_num, dev_count):
"""
Test with feed
"""
score_file = open(score_path, 'w')
for it in six.moves.xrange(batch_num // dev_count):
feed_list = []
for dev in six.moves.xrange(dev_count):
val_index = it * dev_count + dev
batch_data = reader.make_one_batch_input(batches, val_index)
feed_dict = dict(zip(feed_names, batch_data))
feed_list.append(feed_dict)
predicts = exe.run(feed=feed_list, fetch_list=fetch_list)
scores = np.array(predicts[0])
for dev in six.moves.xrange(dev_count):
val_index = it * dev_count + dev
for i in six.moves.xrange(args.batch_size):
score_file.write(
str(scores[args.batch_size * dev + i][0]) + '\t' + str(
batches["label"][val_index][i]) + '\n')
score_file.close()
def test_with_pyreader(exe, program, pyreader, fetch_list, score_path, batches,
batch_num, dev_count):
"""
Test with pyreader
"""
def data_provider():
"""
Data reader
"""
for index in six.moves.xrange(batch_num):
yield reader.make_one_batch_input(batches, index)
score_file = open(score_path, 'w')
pyreader.decorate_tensor_provider(data_provider)
it = 0
pyreader.start()
while True:
try:
predicts = exe.run(fetch_list=fetch_list)
scores = np.array(predicts[0])
for dev in six.moves.xrange(dev_count):
val_index = it * dev_count + dev
for i in six.moves.xrange(args.batch_size):
score_file.write(
str(scores[args.batch_size * dev + i][0]) + '\t' + str(
batches["label"][val_index][i]) + '\n')
it += 1
except fluid.core.EOFException:
pyreader.reset()
break
score_file.close()
def train(args):
"""
Train Program
"""
if not os.path.exists(args.save_path):
os.makedirs(args.save_path)
# data data_config
data_conf = {
"batch_size": args.batch_size,
"max_turn_num": args.max_turn_num,
"max_turn_len": args.max_turn_len,
"_EOS_": args._EOS_,
}
dam = Net(args.max_turn_num, args.max_turn_len, args.vocab_size,
args.emb_size, args.stack_num, args.channel1_num,
args.channel2_num)
train_program = fluid.Program()
train_startup = fluid.Program()
if "CE_MODE_X" in os.environ:
train_program.random_seed = 110
train_startup.random_seed = 110
with fluid.program_guard(train_program, train_startup):
with fluid.unique_name.guard():
if args.use_pyreader:
train_pyreader = dam.create_py_reader(
capacity=10, name='train_reader')
else:
dam.create_data_layers()
loss, logits = dam.create_network()
loss.persistable = True
logits.persistable = True
# gradient clipping
fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByValue(
max=1.0, min=-1.0))
optimizer = fluid.optimizer.Adam(
learning_rate=fluid.layers.exponential_decay(
learning_rate=args.learning_rate,
decay_steps=400,
decay_rate=0.9,
staircase=True))
optimizer.minimize(loss)
test_program = fluid.Program()
test_startup = fluid.Program()
if "CE_MODE_X" in os.environ:
test_program.random_seed = 110
test_startup.random_seed = 110
with fluid.program_guard(test_program, test_startup):
with fluid.unique_name.guard():
if args.use_pyreader:
test_pyreader = dam.create_py_reader(
capacity=10, name='test_reader')
else:
dam.create_data_layers()
loss, logits = dam.create_network()
loss.persistable = True
logits.persistable = True
test_program = test_program.clone(for_test=True)
if args.use_cuda:
place = fluid.CUDAPlace(0)
dev_count = fluid.core.get_cuda_device_count()
else:
place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
print("device count %d" % dev_count)
print("theoretical memory usage: ")
print(fluid.contrib.memory_usage(
program=train_program, batch_size=args.batch_size))
exe = fluid.Executor(place)
exe.run(train_startup)
exe.run(test_startup)
train_exe = fluid.ParallelExecutor(
use_cuda=args.use_cuda, loss_name=loss.name, main_program=train_program)
test_exe = fluid.ParallelExecutor(
use_cuda=args.use_cuda,
main_program=test_program,
share_vars_from=train_exe)
if args.word_emb_init is not None:
print("start loading word embedding init ...")
if six.PY2:
word_emb = np.array(pickle.load(open(args.word_emb_init,
'rb'))).astype('float32')
else:
word_emb = np.array(
pickle.load(
open(args.word_emb_init, 'rb'), encoding="bytes")).astype(
'float32')
dam.set_word_embedding(word_emb, place)
print("finish init word embedding ...")
print("start loading data ...")
with open(args.data_path, 'rb') as f:
if six.PY2:
train_data, val_data, test_data = pickle.load(f)
else:
train_data, val_data, test_data = pickle.load(f, encoding="bytes")
print("finish loading data ...")
val_batches = reader.build_batches(val_data, data_conf)
batch_num = len(train_data[six.b('y')]) // args.batch_size
val_batch_num = len(val_batches["response"])
print_step = max(1, batch_num // (dev_count * 100))
save_step = max(1, batch_num // (dev_count * 10))
print("begin model training ...")
print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
def train_with_feed(step):
"""
Train on one epoch data by feeding
"""
ave_cost = 0.0
for it in six.moves.xrange(batch_num // dev_count):
feed_list = []
for dev in six.moves.xrange(dev_count):
index = it * dev_count + dev
batch_data = reader.make_one_batch_input(train_batches, index)
feed_dict = dict(zip(dam.get_feed_names(), batch_data))
feed_list.append(feed_dict)
cost = train_exe.run(feed=feed_list, fetch_list=[loss.name])
ave_cost += np.array(cost[0]).mean()
step = step + 1
if step % print_step == 0:
print("processed: [" + str(step * dev_count * 1.0 / batch_num) +
"] ave loss: [" + str(ave_cost / print_step) + "]")
ave_cost = 0.0
if (args.save_path is not None) and (step % save_step == 0):
save_path = os.path.join(args.save_path, "step_" + str(step))
print("Save model at step %d ... " % step)
print(time.strftime('%Y-%m-%d %H:%M:%S',
time.localtime(time.time())))
fluid.io.save_persistables(exe, save_path, train_program)
score_path = os.path.join(args.save_path, 'score.' + str(step))
test_with_feed(test_exe, test_program,
dam.get_feed_names(), [logits.name], score_path,
val_batches, val_batch_num, dev_count)
result_file_path = os.path.join(args.save_path,
'result.' + str(step))
evaluate(score_path, result_file_path)
return step, np.array(cost[0]).mean()
def train_with_pyreader(step):
"""
Train on one epoch with pyreader
"""
def data_provider():
"""
Data reader
"""
for index in six.moves.xrange(batch_num):
yield reader.make_one_batch_input(train_batches, index)
train_pyreader.decorate_tensor_provider(data_provider)
ave_cost = 0.0
train_pyreader.start()
while True:
try:
cost = train_exe.run(fetch_list=[loss.name])
ave_cost += np.array(cost[0]).mean()
step = step + 1
if step % print_step == 0:
print("processed: [" + str(step * dev_count * 1.0 /
batch_num) + "] ave loss: [" +
str(ave_cost / print_step) + "]")
ave_cost = 0.0
if (args.save_path is not None) and (step % save_step == 0):
save_path = os.path.join(args.save_path,
"step_" + str(step))
print("Save model at step %d ... " % step)
print(time.strftime('%Y-%m-%d %H:%M:%S',
time.localtime(time.time())))
fluid.io.save_persistables(exe, save_path, train_program)
score_path = os.path.join(args.save_path,
'score.' + str(step))
test_with_pyreader(test_exe, test_program, test_pyreader,
[logits.name], score_path, val_batches,
val_batch_num, dev_count)
result_file_path = os.path.join(args.save_path,
'result.' + str(step))
evaluate(score_path, result_file_path)
except fluid.core.EOFException:
train_pyreader.reset()
break
return step, np.array(cost[0]).mean()
# train over different epoches
global_step, train_time = 0, 0.0
for epoch in six.moves.xrange(args.num_scan_data):
shuffle_train = reader.unison_shuffle(
train_data, seed=110 if ("CE_MODE_X" in os.environ) else None)
train_batches = reader.build_batches(shuffle_train, data_conf)
begin_time = time.time()
if args.use_pyreader:
global_step, last_cost = train_with_pyreader(global_step)
else:
global_step, last_cost = train_with_feed(global_step)
pass_time_cost = time.time() - begin_time
train_time += pass_time_cost
print("Pass {0}, pass_time_cost {1}"
.format(epoch, "%2.2f sec" % pass_time_cost))
# For internal continuous evaluation
if "CE_MODE_X" in os.environ:
card_num = get_cards()
print("kpis\ttrain_cost_card%d\t%f" % (card_num, last_cost))
print("kpis\ttrain_duration_card%d\t%f" % (card_num, train_time))
def test(args):
"""
Test
"""
if not os.path.exists(args.save_path):
mkdir(args.save_path)
if not os.path.exists(args.model_path):
raise ValueError("Invalid model init path %s" % args.model_path)
# data data_config
data_conf = {
"batch_size": args.batch_size,
"max_turn_num": args.max_turn_num,
"max_turn_len": args.max_turn_len,
"_EOS_": args._EOS_,
}
dam = Net(args.max_turn_num, args.max_turn_len, args.vocab_size,
args.emb_size, args.stack_num, args.channel1_num,
args.channel2_num)
dam.create_data_layers()
loss, logits = dam.create_network()
loss.persistable = True
logits.persistable = True
# gradient clipping
fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByValue(
max=1.0, min=-1.0))
test_program = fluid.default_main_program().clone(for_test=True)
optimizer = fluid.optimizer.Adam(
learning_rate=fluid.layers.exponential_decay(
learning_rate=args.learning_rate,
decay_steps=400,
decay_rate=0.9,
staircase=True))
optimizer.minimize(loss)
if args.use_cuda:
place = fluid.CUDAPlace(0)
dev_count = fluid.core.get_cuda_device_count()
else:
place = fluid.CPUPlace()
#dev_count = multiprocessing.cpu_count()
dev_count = 1
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
fluid.io.load_persistables(exe, args.model_path)
test_exe = fluid.ParallelExecutor(
use_cuda=args.use_cuda, main_program=test_program)
print("start loading data ...")
with open(args.data_path, 'rb') as f:
if six.PY2:
train_data, val_data, test_data = pickle.load(f)
else:
train_data, val_data, test_data = pickle.load(f, encoding="bytes")
print("finish loading data ...")
test_batches = reader.build_batches(test_data, data_conf)
test_batch_num = len(test_batches["response"])
print("test batch num: %d" % test_batch_num)
print("begin inference ...")
print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
score_path = os.path.join(args.save_path, 'score.txt')
score_file = open(score_path, 'w')
for it in six.moves.xrange(test_batch_num // dev_count):
feed_list = []
for dev in six.moves.xrange(dev_count):
index = it * dev_count + dev
batch_data = reader.make_one_batch_input(test_batches, index)
feed_dict = dict(zip(dam.get_feed_names(), batch_data))
feed_list.append(feed_dict)
predicts = test_exe.run(feed=feed_list, fetch_list=[logits.name])
scores = np.array(predicts[0])
print("step = %d" % it)
for dev in six.moves.xrange(dev_count):
index = it * dev_count + dev
for i in six.moves.xrange(args.batch_size):
score_file.write(
str(scores[args.batch_size * dev + i][0]) + '\t' + str(
test_batches["label"][index][i]) + '\n')
score_file.close()
#write evaluation result
if args.ext_eval:
result = eva.evaluate_douban(score_path)
else:
result = eva.evaluate_ubuntu(score_path)
result_file_path = os.path.join(args.save_path, 'result.txt')
with open(result_file_path, 'w') as out_file:
for metric in result:
out_file.write(metric + '\t' + str(result[metric]) + '\n')
print('finish test')
print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())))
def get_cards():
num = 0
cards = os.environ.get('CUDA_VISIBLE_DEVICES', '')
if cards != '':
num = len(cards.split(","))
return num
if __name__ == '__main__':
args = config.parse_args()
config.print_arguments(args)
check_cuda(args.use_cuda)
if args.do_train:
train(args)
if args.do_test:
test(args)
#encoding=utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import paddle
import paddle.fluid as fluid
def check_cuda(use_cuda, err = \
"\nYou can not set use_cuda = True in the model because you are using paddlepaddle-cpu.\n \
Please: 1. Install paddlepaddle-gpu to run your models on GPU or 2. Set use_cuda = False to run models on CPU.\n"
):
try:
if use_cuda == True and fluid.is_compiled_with_cuda() == False:
print(err)
sys.exit(1)
except Exception as e:
pass
if __name__ == "__main__":
check_cuda(True)
check_cuda(False)
check_cuda(True, "This is only for testing.")
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Deep Attention Matching Network
"""
import six
import numpy as np
import paddle.fluid as fluid
import layers
class Net(object):
"""
Deep attention matching network
"""
def __init__(self, max_turn_num, max_turn_len, vocab_size, emb_size,
stack_num, channel1_num, channel2_num):
"""
Init
"""
self._max_turn_num = max_turn_num
self._max_turn_len = max_turn_len
self._vocab_size = vocab_size
self._emb_size = emb_size
self._stack_num = stack_num
self._channel1_num = channel1_num
self._channel2_num = channel2_num
self._feed_names = []
self.word_emb_name = "shared_word_emb"
self.use_stack_op = True
self.use_mask_cache = True
self.use_sparse_embedding = True
def create_py_reader(self, capacity, name):
"""
Create py reader
"""
# turns ids
shapes = [[-1, self._max_turn_len, 1]
for i in six.moves.xrange(self._max_turn_num)]
dtypes = ["int64" for i in six.moves.xrange(self._max_turn_num)]
# turns mask
shapes += [[-1, self._max_turn_len, 1]
for i in six.moves.xrange(self._max_turn_num)]
dtypes += ["float32" for i in six.moves.xrange(self._max_turn_num)]
# response ids, response mask, label
shapes += [[-1, self._max_turn_len, 1], [-1, self._max_turn_len, 1],
[-1, 1]]
dtypes += ["int64", "float32", "float32"]
py_reader = fluid.layers.py_reader(
capacity=capacity,
shapes=shapes,
lod_levels=[0] * (2 * self._max_turn_num + 3),
dtypes=dtypes,
name=name,
use_double_buffer=True)
data_vars = fluid.layers.read_file(py_reader)
self.turns_data = data_vars[0:self._max_turn_num]
self.turns_mask = data_vars[self._max_turn_num:2 * self._max_turn_num]
self.response = data_vars[-3]
self.response_mask = data_vars[-2]
self.label = data_vars[-1]
return py_reader
def create_data_layers(self):
"""
Create data layer
"""
self._feed_names = []
self.turns_data = []
for i in six.moves.xrange(self._max_turn_num):
name = "turn_%d" % i
turn = fluid.layers.data(
name=name, shape=[self._max_turn_len, 1], dtype="int64")
self.turns_data.append(turn)
self._feed_names.append(name)
self.turns_mask = []
for i in six.moves.xrange(self._max_turn_num):
name = "turn_mask_%d" % i
turn_mask = fluid.layers.data(
name=name, shape=[self._max_turn_len, 1], dtype="float32")
self.turns_mask.append(turn_mask)
self._feed_names.append(name)
self.response = fluid.layers.data(
name="response", shape=[self._max_turn_len, 1], dtype="int64")
self.response_mask = fluid.layers.data(
name="response_mask",
shape=[self._max_turn_len, 1],
dtype="float32")
self.label = fluid.layers.data(name="label", shape=[1], dtype="float32")
self._feed_names += ["response", "response_mask", "label"]
def get_feed_names(self):
"""
Return feed names
"""
return self._feed_names
def set_word_embedding(self, word_emb, place):
"""
Set word embedding
"""
word_emb_param = fluid.global_scope().find_var(
self.word_emb_name).get_tensor()
word_emb_param.set(word_emb, place)
def create_network(self):
"""
Create network
"""
mask_cache = dict() if self.use_mask_cache else None
response_emb = fluid.layers.embedding(
input=self.response,
size=[self._vocab_size + 1, self._emb_size],
is_sparse=self.use_sparse_embedding,
param_attr=fluid.ParamAttr(
name=self.word_emb_name,
initializer=fluid.initializer.Normal(scale=0.1)))
# response part
Hr = response_emb
Hr_stack = [Hr]
for index in six.moves.xrange(self._stack_num):
Hr = layers.block(
name="response_self_stack" + str(index),
query=Hr,
key=Hr,
value=Hr,
d_key=self._emb_size,
q_mask=self.response_mask,
k_mask=self.response_mask,
mask_cache=mask_cache)
Hr_stack.append(Hr)
# context part
sim_turns = []
for t in six.moves.xrange(self._max_turn_num):
Hu = fluid.layers.embedding(
input=self.turns_data[t],
size=[self._vocab_size + 1, self._emb_size],
is_sparse=self.use_sparse_embedding,
param_attr=fluid.ParamAttr(
name=self.word_emb_name,
initializer=fluid.initializer.Normal(scale=0.1)))
Hu_stack = [Hu]
for index in six.moves.xrange(self._stack_num):
# share parameters
Hu = layers.block(
name="turn_self_stack" + str(index),
query=Hu,
key=Hu,
value=Hu,
d_key=self._emb_size,
q_mask=self.turns_mask[t],
k_mask=self.turns_mask[t],
mask_cache=mask_cache)
Hu_stack.append(Hu)
# cross attention
r_a_t_stack = []
t_a_r_stack = []
for index in six.moves.xrange(self._stack_num + 1):
t_a_r = layers.block(
name="t_attend_r_" + str(index),
query=Hu_stack[index],
key=Hr_stack[index],
value=Hr_stack[index],
d_key=self._emb_size,
q_mask=self.turns_mask[t],
k_mask=self.response_mask,
mask_cache=mask_cache)
r_a_t = layers.block(
name="r_attend_t_" + str(index),
query=Hr_stack[index],
key=Hu_stack[index],
value=Hu_stack[index],
d_key=self._emb_size,
q_mask=self.response_mask,
k_mask=self.turns_mask[t],
mask_cache=mask_cache)
t_a_r_stack.append(t_a_r)
r_a_t_stack.append(r_a_t)
t_a_r_stack.extend(Hu_stack)
r_a_t_stack.extend(Hr_stack)
if self.use_stack_op:
t_a_r = fluid.layers.stack(t_a_r_stack, axis=1)
r_a_t = fluid.layers.stack(r_a_t_stack, axis=1)
else:
for index in six.moves.xrange(len(t_a_r_stack)):
t_a_r_stack[index] = fluid.layers.unsqueeze(
input=t_a_r_stack[index], axes=[1])
r_a_t_stack[index] = fluid.layers.unsqueeze(
input=r_a_t_stack[index], axes=[1])
t_a_r = fluid.layers.concat(input=t_a_r_stack, axis=1)
r_a_t = fluid.layers.concat(input=r_a_t_stack, axis=1)
# sim shape: [batch_size, 2*(stack_num+1), max_turn_len, max_turn_len]
sim = fluid.layers.matmul(
x=t_a_r, y=r_a_t, transpose_y=True, alpha=1 / np.sqrt(200.0))
sim_turns.append(sim)
if self.use_stack_op:
sim = fluid.layers.stack(sim_turns, axis=2)
else:
for index in six.moves.xrange(len(sim_turns)):
sim_turns[index] = fluid.layers.unsqueeze(
input=sim_turns[index], axes=[2])
# sim shape: [batch_size, 2*(stack_num+1), max_turn_num, max_turn_len, max_turn_len]
sim = fluid.layers.concat(input=sim_turns, axis=2)
final_info = layers.cnn_3d(sim, self._channel1_num, self._channel2_num)
loss, logits = layers.loss(final_info, self.label)
return loss, logits
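For orientation, here is a minimal sketch of how a `Net` class with the methods above might be wired for single-device training. The constructor arguments and hyper-parameter values are assumptions inferred from the attribute names used in the class (`_max_turn_num`, `_emb_size`, ...), not taken from this diff:

```
import paddle.fluid as fluid

# Hypothetical wiring; the Net constructor signature is an assumption.
net = Net(max_turn_num=9, max_turn_len=50, vocab_size=434512, emb_size=200,
          stack_num=5, channel1_num=32, channel2_num=16)

train_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
    net.create_data_layers()                # declare feed variables
    loss, logits = net.create_network()     # build the matching network
    fluid.optimizer.Adam(learning_rate=1e-3).minimize(loss)

place = fluid.CPUPlace()                    # or fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup_program)

feeder = fluid.DataFeeder(feed_list=net.get_feed_names(),
                          place=place, program=train_program)
# One step; the batch must follow the order returned by net.get_feed_names():
# loss_val, = exe.run(train_program, feed=feeder.feed(batch), fetch_list=[loss])
```

Note that `get_feed_names()` returns the variables in the same order they are created in `create_data_layers` (turn ids, turn masks, response, response mask, label), which is also the order expected by the `py_reader` path when `use_pyreader` is set.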
export CUDA_VISIBLE_DEVICES=3
export FLAGS_eager_delete_tensor_gb=0.0
#train on ubuntu
python -u main.py \
--do_train True \
--use_cuda \
--data_path ./data/ubuntu/data_small.pkl \
--save_path ./model_files/ubuntu \
--use_pyreader \
--vocab_size 434512 \
--_EOS_ 28270 \
--batch_size 32
#test on ubuntu
python -u main.py \
--do_test True \
--use_cuda \
--data_path ./data/ubuntu/data_small.pkl \
--save_path ./model_files/ubuntu/step_31 \
--model_path ./model_files/ubuntu/step_31 \
--vocab_size 434512 \
--_EOS_ 28270 \
--batch_size 100
#train on douban
python -u main.py \
--do_train True \
--use_cuda \
--data_path ./data/douban/data_small.pkl \
--save_path ./model_files/douban \
--use_pyreader \
--vocab_size 172130 \
--_EOS_ 1 \
--channel1_num 16 \
--batch_size 32
#test on douban
python -u main.py \
--do_test True \
--use_cuda \
--ext_eval \
--data_path ./data/douban/data_small.pkl \
--save_path ./model_files/douban/step_31 \
--model_path ./model_files/douban/step_31 \
--vocab_size 172130 \
--_EOS_ 1 \
--channel1_num 16 \
--batch_size 32
export CPU_NUM=1
export FLAGS_eager_delete_tensor_gb=0.0
#train on ubuntu
python -u main.py \
--do_train True \
--data_path ./data/ubuntu/data_small.pkl \
--save_path ./model_files_cpu/ubuntu \
--use_pyreader \
--stack_num 2 \
--vocab_size 434512 \
--_EOS_ 28270 \
--batch_size 32
#test on ubuntu
python -u main.py \
--do_test True \
--data_path ./data/ubuntu/data_small.pkl \
--save_path ./model_files_cpu/ubuntu/step_31 \
--model_path ./model_files_cpu/ubuntu/step_31 \
--stack_num 2 \
--vocab_size 434512 \
--_EOS_ 28270 \
--batch_size 40
#train on douban
python -u main.py \
--do_train True \
--data_path ./data/douban/data_small.pkl \
--save_path ./model_files_cpu/douban \
--use_pyreader \
--stack_num 2 \
--vocab_size 172130 \
--_EOS_ 1 \
--channel1_num 16 \
--batch_size 32
#test on douban
python -u main.py \
--do_test True \
--ext_eval \
--data_path ./data/douban/data_small.pkl \
--save_path ./model_files_cpu/douban/step_31 \
--model_path ./model_files_cpu/douban/step_31 \
--stack_num 2 \
--vocab_size 172130 \
--_EOS_ 1 \
--channel1_num 16 \
--batch_size 40
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Utils
"""
import six
import os
def print_arguments(args):
"""
Print arguments
"""
print('----------- Configuration Arguments -----------')
for arg, value in sorted(six.iteritems(vars(args))):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
def mkdir(path):
"""
Mkdir
"""
if not os.path.isdir(path):
if os.path.split(path)[0]:
mkdir(os.path.split(path)[0])
else:
return
os.mkdir(path)
def pos_encoding_init():
"""
Pos encoding init
"""
pass
def scaled_dot_product_attention():
"""
    Scaled dot product attention
"""
pass
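The two helpers above are left as stubs in this file. For reference, below is a minimal NumPy sketch of what such helpers conventionally compute in Transformer-style models: the standard sinusoidal position encoding and softmax(QK^T / sqrt(d_k)) V attention. This is an illustration only, not necessarily what the removed implementation did:

```
import numpy as np

def pos_encoding_init(max_len, d_model):
    """Sinusoidal position encoding as in 'Attention Is All You Need'."""
    pos = np.arange(max_len)[:, None]        # [max_len, 1]
    dim = np.arange(d_model)[None, :]        # [1, d_model]
    angle = pos / np.power(10000.0, (2 * (dim // 2)) / d_model)
    enc = np.zeros((max_len, d_model), dtype="float32")
    enc[:, 0::2] = np.sin(angle[:, 0::2])    # even dimensions
    enc[:, 1::2] = np.cos(angle[:, 1::2])    # odd dimensions
    return enc

def scaled_dot_product_attention(q, k, v, mask=None):
    """softmax(q k^T / sqrt(d_k)) v for 2-D arrays q:[Lq,d], k:[Lk,d], v:[Lk,dv]."""
    scores = q @ k.T / np.sqrt(k.shape[-1])
    if mask is not None:
        # mask is a boolean [Lq, Lk] array; True means the position may be attended
        scores = np.where(mask, scores, -1e9)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v
```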
#!/bin/bash
# ==============================================================================
# Copyright 2017 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# download preprocessed data
wget -c --no-check-certificate https://baidu-nlp.bj.bcebos.com/dureader_machine_reading-dataset-2.0.0.tar.gz
# download trained model parameters and vocabulary
wget -c --no-check-certificate https://baidu-nlp.bj.bcebos.com/dureader_machine_reading-bidaf-1.0.0.tar.gz
# decompression
tar -zxvf dureader_machine_reading-dataset-2.0.0.tar.gz
tar -zxvf dureader_machine_reading-bidaf-1.0.0.tar.gz
ln -s trained_model_para/vocab ./
ln -s trained_model_para/saved_model ./
02065c4adec24860c52b0192a3615d8d dureader_machine_reading-bidaf-1.0.0.tar.gz
a6b4678ccf319a0c8812b80aed71c24b dureader_machine_reading-dataset-2.0.0.tar.gz
#!/bin/bash
train(){
python -u run.py \
--pass_num 1 \
--learning_rate 0.001 \
--batch_size 8 \
--embed_size 300 \
--hidden_size 150 \
--max_p_num 5 \
--max_p_len 500 \
--max_q_len 60 \
--max_a_len 200 \
--enable_ce \
--train
}
cudaid=${single:=0} # use 0-th card as default
export CUDA_VISIBLE_DEVICES=$cudaid
train | python _ce.py
cudaid=${multi:=0,1,2,3} # use 0,1,2,3 card as default
export CUDA_VISIBLE_DEVICES=$cudaid
train | python _ce.py
# The notes on the updates of PaddlePaddle baseline
## Updates
We implement a BiDAF model with PaddlePaddle. Note that we updated the PaddlePaddle baseline on Feb 25, 2019. This document gives the details of the major updates:
### 1 Paragraph Extraction
The first update is that we incorporate a paragraph extraction strategy to improve model performance (see the file `paddle/para_extraction.py`). A similar strategy was used in the top-1 system (Liu et al. 2018) at the [2018 Machine Reading Challenge](http://mrc2018.cipsc.org.cn/).
The original baseline of DuReader (He et al. 2018) employed a simple strategy to select paragraphs for model training and testing; however, the paragraphs that include the true answers may not be selected. Hence, we want to retain as much answer-relevant information as possible for answer extraction.
The new paragraph extraction strategy is applied to each document as follows (a minimal sketch is given right after this list). For each document,
- We remove the duplicated paragraphs in the document.
- If the concatenation of the title and all paragraphs (joined with a pre-defined splitter) is shorter than a predefined maximum length, we use it directly. Otherwise,
- we compute the F1 score of each paragraph relative to the question;
- we concatenate the title and the top-K paragraphs (ranked by F1 score) with the splitter to form an extracted paragraph that stays within the maximum length.
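A minimal sketch of this extraction procedure on tokenized inputs. The function and parameter names, the splitter token, and the default values are illustrative assumptions; see `paddle/para_extraction.py` for the actual implementation:

```
from collections import Counter

def overlap_f1(para, question):
    """Word-overlap F1 of a tokenized paragraph against the question."""
    common = Counter(para) & Counter(question)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(para)
    recall = num_same / len(question)
    return 2 * precision * recall / (precision + recall)

def extract_paragraph(title, paragraphs, question,
                      max_len=500, top_k=3, splitter="<splitter>"):
    """Return one concatenated, length-bounded token list per document."""
    # 1. drop exact-duplicate paragraphs, preserving order
    seen, uniq = set(), []
    for para in paragraphs:
        key = tuple(para)
        if key not in seen:
            seen.add(key)
            uniq.append(para)

    def join(paras):
        out = list(title)
        for para in paras:
            out += [splitter] + para
        return out

    # 2. if the whole document fits, keep the title and all paragraphs
    full = join(uniq)
    if len(full) <= max_len:
        return full
    # 3. otherwise keep only the top-K paragraphs by F1 w.r.t. the question
    ranked = sorted(uniq, key=lambda p: overlap_f1(p, question), reverse=True)
    return join(ranked[:top_k])[:max_len]
```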
### 2 The Prior of Document Ranking
We also introduce a prior over document ranking from the search engine (see line #176 in `paddle/run.py`). The documents in DuReader are collected from search results, so the prior scores of document ranking are an important feature. We compute the prior scores from the training data and apply them in the testing stage.
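One plausible way to realize such a prior (an assumption for illustration; see `paddle/run.py` for the actual code) is to estimate, on the training set, the probability that the document at each search-rank position contains a gold answer, and combine that score with the model score at test time:

```
from collections import defaultdict

def estimate_rank_prior(train_samples):
    """Empirical, smoothed P(document at search rank r contains a gold answer).

    `train_samples` is assumed to be an iterable of lists of
    (rank, contains_answer) pairs, one pair per candidate document.
    """
    hits = defaultdict(int)
    totals = defaultdict(int)
    for docs in train_samples:
        for rank, contains_answer in docs:
            totals[rank] += 1
            hits[rank] += int(contains_answer)
    # add-one style smoothing so rarely seen ranks do not get probability 0 or 1
    return {rank: (hits[rank] + 1.0) / (totals[rank] + 2.0) for rank in totals}
```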
## Reference
- Liu, J., Wei, W., Sun, M., Chen, H., Du, Y. and Lin, D., 2018. A Multi-answer Multi-task Framework for Real-world Machine Reading Comprehension. In Proceedings of the 2018 Conference on Empirical Methods in Natural Language Processing (pp. 2109-2118).
- He, W., Liu, K., Liu, J., Lyu, Y., Zhao, S., Xiao, X., Liu, Y., Wang, Y., Wu, H., She, Q. and Liu, X., 2017. DuReader: A Chinese Machine Reading Comprehension Dataset from Real-world Applications. arXiv preprint arXiv:1711.05073.
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import json
import pandas as pd
if __name__ == '__main__':
if len(sys.argv) != 3:
print('Usage: tojson.py <input_path> <output_path>')
exit()
infile = sys.argv[1]
outfile = sys.argv[2]
df = pd.read_json(infile)
with open(outfile, 'w') as f:
for row in df.iterrows():
f.write(row[1].to_json() + '\n')
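As the usage message suggests, the script is run as `python tojson.py <input_path> <output_path>`: it loads the input file with pandas and re-emits each record as a single JSON object per line (JSON Lines).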