Commit fa20bc49 authored by wenquan wu, committed by Yibing Liu

wwqydy patch 1 (#2411)

* Update README.md

* add ACL2019-DuConv
Parent 38388a5e
knowledge-driven-dialogue
Proactive Conversation
=============================
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
# about the competition
Human-machine conversation is one of the most important topics in artificial intelligence (AI) and has received much attention across academia and industry in recent years. Current dialogue systems are still in their infancy: unlike humans in human-human conversation, they usually converse passively and utter their words more as a matter of response than on their own initiative. Therefore, we set up this competition on a new conversation task, named knowledge-driven dialogue, where machines converse with humans based on a built knowledge graph. It aims at testing machines' ability to conduct human-like conversations.<br>
Please refer to [competition website](http://lic2019.ccf.org.cn/talk) for details of the competition.
# about the task
Given a dialogue goal g and a set of topic-related background knowledge M = f<sub>1</sub>, f<sub>2</sub>, ..., f<sub>n</sub>, a participating system is expected to output an utterance u<sub>t</sub> for the current conversation H = u<sub>1</sub>, u<sub>2</sub>, ..., u<sub>t-1</sub>, which keeps the conversation coherent and informative under the guidance of the given goal. During the dialogue, a participating system is required to proactively lead the conversation from one topic to another. The dialogue goal g is given in the form "Start->Topic_A->Topic_B", which means the machine should lead the conversation from any start state to topic A and then to topic B. The given background knowledge includes knowledge related to topic A and topic B, and the relations between these two topics.<br>
Please refer to [task description](https://github.com/baidu/knowledge-driven-dialogue/blob/master/task_description.pdf) for details of the task.
# about the baseline
We provide retrieval-based and generation-based baseline systems. Both systems were implemented with [PaddlePaddle](http://paddlepaddle.org/) (Baidu's deep learning framework) and [PyTorch](https://pytorch.org/) (Facebook's deep learning framework). The performance of the two systems is as follows:
# Motivation
Human-machine conversation is one of the most important topics in artificial intelligence (AI) and has received much attention across academia and industry in recent years. Current dialogue systems are still in their infancy: unlike humans in human-human conversation, they usually converse passively and utter their words more as a matter of response than on their own initiative. We believe that the ability to converse proactively is a breakthrough on the way to human-like conversation.
# What we do
* We set up a new conversation task, named ___Proactive Conversation___, in which the machine proactively leads the conversation following a given goal.
* We also created a new conversation dataset named [DuConv](https://ai.baidu.com/broad/subordinate?dataset=duconv), and made it publicly available to facilitate the development of proactive conversation systems.
* We established retrieval-based and generation-based ___baseline systems___ for DuConv, which are available in this repo.
* In addition, we hold ___competitions___ to encourage more researchers to work in this direction.
# Paper
* [Proactive Human-Machine Conversation with Explicit Conversation Goals](https://arxiv.org/abs/1906.05572), accepted by ACL 2019
# Task Description
Given a dialogue goal g and a set of topic-related background knowledge M = f<sub>1</sub>, f<sub>2</sub>, ..., f<sub>n</sub>, the system is expected to output an utterance u<sub>t</sub> for the current conversation H = u<sub>1</sub>, u<sub>2</sub>, ..., u<sub>t-1</sub>, which keeps the conversation coherent and informative under the guidance of the given goal. During the dialogue, the system is required to proactively lead the conversation from one topic to another. The dialogue goal g is given in the form "Start->Topic_A->Topic_B", which means the machine should lead the conversation from any start state to topic A and then to topic B. The given background knowledge includes knowledge related to topic A and topic B, and the relations between these two topics.<br>
![image](https://github.com/PaddlePaddle/models/blob/wwqydy-patch-1/PaddleNLP/Research/ACL2019-DuConv/images/proactive_conversation_case.png)
*Figure 1. A proactive conversation case. Each "BOT" utterance is predicted by the system: utterances in black represent the history H, and the utterance in green represents the response u<sub>t</sub> predicted by the system.*
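For illustration, each sample consumed by the baseline code is one JSON object per line, with roughly the schema below (the field names follow the baseline preprocessing scripts; the placeholder values stand in for Chinese corpus entries, and "response" may be absent in test data). Note that train.txt and dev.txt store whole sessions, with a "conversation" list in place of "history"/"response"; the baselines' convert_session_to_sample.py expands each session into samples.
```
{
  "goal":      [["START", "<topic_a>", "<topic_b>"], ...],
  "knowledge": [["<subject>", "<predicate>", "<object>"], ...],
  "history":   ["<utterance_1>", ..., "<utterance_t-1>"],
  "response":  "<utterance_t>"
}
```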
# DuConv
We collected a dataset named DuConv with around 30k conversations containing 270k utterances. Each conversation was created by two randomly selected crowdsourced workers. One worker was given the dialogue goal and the associated knowledge and played the role of the leader, who proactively leads the conversation by sequentially changing the discussion topics following the given goal, while keeping the conversation as natural and engaging as possible. The other worker was given nothing but the conversation history and only had to respond to the leader. <br>
We divide the collected conversations into training, development, test1 and test2 splits. The test1 split, with reference responses, is used for local testing such as the automatic evaluation in our paper. The test2 split, without reference responses, is used for online testing such as the [competition](http://lic2019.ccf.org.cn/talk) we held and the ___Leader Board___ that remains open at https://ai.baidu.com/broad/leaderboard?dataset=duconv. The dataset is available at https://ai.baidu.com/broad/subordinate?dataset=duconv.
# Baseline Performance
We provide retrieval-based and generation-based baseline systems. Both systems were implemented with [PaddlePaddle](http://paddlepaddle.org/) (Baidu's deep learning framework). The performance of the two systems is as follows:
| baseline system | F1/BLEU1/BLEU2 | DISTINCT1/DISTINCT2 |
| ------------- | ------------ | ------------ |
| retrieval-based | 31.72/0.291/0.156 | 0.118/0.373 |
| generation-based | 32.65/0.300/0.168 | 0.062/0.128 |
# Competitions
* [Knowledge-driven Dialogue task](http://lic2019.ccf.org.cn/talk) in the [2019 Language and Intelligence Challenge](http://lic2019.ccf.org.cn/), now closed.
* Number of registered teams: 1,536
* Number of teams that submitted results: 178
* The top 3 results:
| Rank | F1/BLEU1/BLEU2 | DISTINCT1/DISTINCT2 |
| ------------- | ------------ | ------------ |
| 1 | 49.22/0.449/0.318 | 0.118/0.299 |
| 2 | 47.76/0.430/0.296 | 0.110/0.275 |
| 3 | 46.40/0.422/0.289 | 0.118/0.303 |
* [Leader Board](https://ai.baidu.com/broad/leaderboard?dataset=duconv), permanently open <br>
We maintain a leaderboard that provides the official automatic evaluation. You can submit your results at https://ai.baidu.com/broad/submission?dataset=duconv to get the official scores. Please make sure you submit results on the test2 split.
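If you have reference responses, you can also reproduce the automatic metrics locally with the evaluation script shipped with the baseline systems. A minimal sketch, assuming the baselines' directory layout:
```bash
# eval file: one "pred_response<TAB>gold_response" pair per line
python tools/eval.py ./output/test.result.eval
```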
Knowledge-driven Dialogue
=============================
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
This is a PaddlePaddle implementation of the generation-based model for knowledge-driven dialogue.
## Requirements
* cuda=9.0
* cudnn=7.0
* python=2.7
* numpy
* paddlepaddle>=1.3.2
## Quickstart
### Step 1: Preprocess the data
Put the [DuConv](https://ai.baidu.com/broad/subordinate?dataset=duconv) data under the data folder and rename the files train.txt/dev.txt/test.txt:
```
./data/resource/train.txt
./data/resource/dev.txt
./data/resource/test.txt
```
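The run scripts below perform all of the preprocessing automatically. If you prefer to run the steps by hand, here is a sketch mirroring run_train.sh (assuming its default file prefix `demo` and topic generalization enabled):
```bash
python ./tools/convert_session_to_sample.py ./data/resource/train.txt ./data/resource/sample.train.txt
python ./tools/convert_conversation_corpus_to_model_text.py ./data/resource/sample.train.txt \
    ./data/demo.train ./data/demo.train.topic 1
python ./tools/build_vocabulary.py ./data/demo.train ./data/vocab.txt
```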
### Step 2: Train the model
Train the model with the following command.
```bash
sh run_train.sh
```
### Step 3: Test the model
Test the model with the following command.
```bash
sh run_test.sh
```
### Note !!!
* The scripts run_train.sh/run_test.sh contain the whole pipeline, including data processing and model training/testing. Be sure to read them carefully and follow them.
* The files in ./data and ./model are just empty placeholders that show the expected directory structure.
#!/bin/bash
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
# set gpu id to use
export CUDA_VISIBLE_DEVICES=0
# generalize target_a/target_b of the goal in all outputs, replacing them with slot marks
TOPIC_GENERALIZATION=1
# set python path according to your actual environment
pythonpath='python'
# the prefix of the file name used by the model, must be consistent with the configuration in network.py
prefix=demo
# put all data sets used and generated for testing under this folder: datapath
# for more details, please refer to the following data processing instructions
datapath=./data
# in test stage, you can eval dev.txt or test.txt
# the "dev.txt" and "test.txt" are the original data of DuConv and
# need to be placed in this folder: INPUT_PATH/resource/
# the following preprocessing will generate the actual data needed for model testing
# after testing, you can run eval.py to get the final eval score if the original data has answers
# DATA_TYPE = "dev" or "test"
datapart=dev
# ensure that each file is in the correct path
# 1. put the data of DuConv under this folder: datapath/resource/
# - the data provided consists of three parts: train.txt dev.txt test.txt
# - the train.txt and dev.txt are session data, the test.txt is sample data
# - in test stage, we just use the dev.txt or test.txt
# 2. the sample data extracted from session data is in this folder: datapath/resource/
# 3. the text file required by the model is in this folder: datapath
# 4. the topic file used to generalize data is in this directory: datapath
corpus_file=${datapath}/resource/${datapart}.txt
sample_file=${datapath}/resource/sample.${datapart}.txt
text_file=${datapath}/${prefix}.test
topic_file=${datapath}/${prefix}.test.topic
# step 1: if evaluating dev.txt, we first have to convert session data to sample data
# if evaluating test.txt, we can use the original test.txt of DuConv directly.
if [ "${datapart}"x = "test"x ]; then
sample_file=${corpus_file}
else
${pythonpath} ./tools/convert_session_to_sample.py ${corpus_file} ${sample_file}
fi
# step 2: convert sample data to text data required by the model
${pythonpath} ./tools/convert_conversation_corpus_to_model_text.py ${sample_file} ${text_file} ${topic_file} ${TOPIC_GENERALIZATION}
# step 3: predict by model
${pythonpath} -u network.py --run_type test \
--use_gpu True \
--batch_size 12 \
--use_posterior False \
--model_path ./models/best_model \
--output ./output/test.result
# step 4: replace slot mark generated during topic generalization with real text
${pythonpath} ./tools/topic_materialization.py ./output/test.result ./output/test.result.final ${topic_file}
# step 5: if the original file has answers, you can run the following commands to get the eval result
# if the original file has no answers, you can upload the ./output/test.result.final
# to the website(https://ai.baidu.com/broad/submission?dataset=duconv) to get the official automatic evaluation
${pythonpath} ./tools/convert_result_for_eval.py ${sample_file} ./output/test.result.final ./output/test.result.eval
${pythonpath} ./tools/eval.py ./output/test.result.eval
#!/bin/bash
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
# set gpu id to use
export CUDA_VISIBLE_DEVICES=0
# generalize target_a/target_b of the goal in all outputs, replacing them with slot marks
TOPIC_GENERALIZATION=1
# set python path according to your actual environment
pythonpath='python'
# the prefix of the file name used by the model, must be consistent with the configuration in network.py
prefix=demo
# put all data sets used and generated for training under this folder: datapath
# for more details, please refer to the following data processing instructions
datapath=./data
vocabpath=${datapath}/vocab.txt
# in train stage, use "train.txt" to train model, and use "dev.txt" to eval model
# the "train.txt" and "dev.txt" are the original data of DuConv and
# need to be placed in this folder: datapath/resource/
# the following preprocessing will generate the actual data needed for model training
# datatype = "train" or "dev"
datatype=(train dev)
# data preprocessing
for ((i=0; i<${#datatype[*]}; i++))
do
# ensure that each file is in the correct path
# 1. put the data of DuConv under this folder: datapath/resource/
# - the data provided consists of three parts: train.txt dev.txt test.txt
# - the train.txt and dev.txt are session data, the test.txt is sample data
# - in train stage, we just use the train.txt and dev.txt
# 2. the sample data extracted from session data is in this folder: datapath/resource/
# 3. the text file required by the model is in this folder: datapath
# 4. the topic file used to generalize data is in this directory: datapath
corpus_file=${datapath}/resource/${datatype[$i]}.txt
sample_file=${datapath}/resource/sample.${datatype[$i]}.txt
text_file=${datapath}/${prefix}.${datatype[$i]}
topic_file=${datapath}/${prefix}.${datatype[$i]}.topic
# step 1: first convert session data to sample data
${pythonpath} ./tools/convert_session_to_sample.py ${corpus_file} ${sample_file}
# step 2: convert sample data to text data required by the model
${pythonpath} ./tools/convert_conversation_corpus_to_model_text.py ${sample_file} ${text_file} ${topic_file} ${TOPIC_GENERALIZATION}
# step 3: build vocabulary from the training data
if [ "${datatype[$i]}"x = "train"x ]; then
${pythonpath} ./tools/build_vocabulary.py ${text_file} ${vocabpath}
fi
done
# step 4: in train stage, we just use train.txt and dev.txt, so we copy dev.txt to test.txt for model training
cp ${datapath}/${prefix}.dev ${datapath}/${prefix}.test
# step 5: train the model in two stages; you can find the model files in ./models/ after training
# step 5.1: stage 0, you can get model_stage_0.npz and opt_state_stage_0.npz in save_dir after stage 0
${pythonpath} -u network.py --run_type train \
--stage 0 \
--use_gpu True \
--pretrain_epoch 5 \
--batch_size 32 \
--use_posterior True \
--save_dir ./models \
--vocab_path ${vocabpath} \
--embed_file ./data/sgns.weibo.300d.txt
# step 5.2: stage 1, init the model and opt state using the result of stage 0 and train the model
${pythonpath} -u network.py --run_type train \
--stage 1 \
--use_gpu True \
--init_model ./models/model_stage_0.npz \
--init_opt_state ./models/opt_state_stage_0.npz \
--num_epochs 12 \
--batch_size 24 \
--use_posterior True \
--save_dir ./models \
--vocab_path ${vocabpath}
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
File: __init__.py
"""
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
File: __init__.py
"""
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
File: source/inputters/corpus.py
"""
import re
import os
import random
import numpy as np
class KnowledgeCorpus(object):
""" Corpus """
def __init__(self,
data_dir,
data_prefix,
vocab_path,
min_len,
max_len):
self.data_dir = data_dir
self.data_prefix = data_prefix
self.vocab_path = vocab_path
self.min_len = min_len
self.max_len = max_len
self.current_train_example = -1
self.num_examples = {'train': -1, 'dev': -1, 'test': -1}
self.load_voc()
def filter_pred(ids):
"""
src_filter_pred
"""
return self.min_len <= len(ids) <= max_len
self.filter_pred = lambda ex: filter_pred(ex['src']) and filter_pred(ex['tgt'])
def load_voc(self):
""" load vocabulary """
idx = 0
self.vocab_dict = dict()
with open(self.vocab_path, 'r') as fr:
for line in fr:
line = line.strip()
self.vocab_dict[line] = idx
idx += 1
def read_data(self, data_file):
""" read_data """
data = []
with open(data_file, "r") as f:
for line in f:
if len(line.rstrip('\n').split('\t')) < 3:
continue
src, tgt, knowledge = line.rstrip('\n').split('\t')[:3]
filter_knowledge = []
for sent in knowledge.split('\1'):
filter_knowledge.append(' '.join(sent.split()[: self.max_len]))
data.append({'src': src, 'tgt': tgt, 'cue':filter_knowledge})
return data
def tokenize(self, tokens):
""" map tokens to ids """
if isinstance(tokens, str):
tokens = re.sub('\d+', '<num>', tokens).lower()
toks = tokens.split(' ')
toks_ids = [self.vocab_dict.get('<bos>')] + \
[self.vocab_dict.get(tok, self.vocab_dict.get('<unk>'))
for tok in toks] + \
[self.vocab_dict.get('<eos>')]
return toks_ids
elif isinstance(tokens, list):
tokens_list = [self.tokenize(t) for t in tokens]
return tokens_list
def build_examples(self, data):
""" build examples, data: ``List[Dict]`` """
examples = []
for raw_data in data:
example = {}
for name, strings in raw_data.items():
example[name] = self.tokenize(strings)
if not self.filter_pred(example):
continue
examples.append((example['src'], example['tgt'], example['cue']))
return examples
def preprocessing_for_lines(self, lines, batch_size):
""" preprocessing for lines """
raw_data = []
for line in lines:
src, tgt, knowledge = line.rstrip('\n').split('\t')[:3]
filter_knowledge = []
for sent in knowledge.split('\1'):
filter_knowledge.append(' '.join(sent.split()[: self.max_len]))
raw_data.append({'src': src, 'tgt': tgt, 'cue': filter_knowledge})
examples = self.build_examples(raw_data)
def instance_reader():
""" instance reader """
for (index, example) in enumerate(examples):
instance = [example[0], example[1], example[2]]
yield instance
def batch_reader(reader, batch_size):
""" batch reader """
batch = []
for instance in reader():
if len(batch) < batch_size:
batch.append(instance)
else:
yield batch
batch = [instance]
if len(batch) > 0:
yield batch
def wrapper():
""" wrapper """
for batch in batch_reader(instance_reader, batch_size):
batch_data = self.prepare_batch_data(batch)
yield batch_data
return wrapper
def data_generator(self, batch_size, phase, shuffle=False):
""" Generate data for train, dev or test. """
if phase == 'train':
train_file = os.path.join(self.data_dir, self.data_prefix + ".train")
train_raw = self.read_data(train_file)
examples = self.build_examples(train_raw)
self.num_examples['train'] = len(examples)
elif phase == 'dev':
valid_file = os.path.join(self.data_dir, self.data_prefix + ".dev")
valid_raw = self.read_data(valid_file)
examples = self.build_examples(valid_raw)
self.num_examples['dev'] = len(examples)
elif phase == 'test':
test_file = os.path.join(self.data_dir, self.data_prefix + ".test")
test_raw = self.read_data(test_file)
examples = self.build_examples(test_raw)
self.num_examples['test'] = len(examples)
else:
raise ValueError(
"Unknown phase, which should be in ['train', 'dev', 'test'].")
def instance_reader():
""" instance reader """
if shuffle:
random.shuffle(examples)
for (index, example) in enumerate(examples):
if phase == 'train':
self.current_train_example = index + 1
instance = [example[0], example[1], example[2]]
yield instance
def batch_reader(reader, batch_size):
""" batch reader """
batch = []
for instance in reader():
if len(batch) < batch_size:
batch.append(instance)
else:
yield batch
batch = [instance]
if len(batch) > 0:
yield batch
def wrapper():
""" wrapper """
for batch in batch_reader(instance_reader, batch_size):
batch_data = self.prepare_batch_data(batch)
yield batch_data
return wrapper
def prepare_batch_data(self, batch):
""" generate input tensor data """
batch_source_ids = [inst[0] for inst in batch]
batch_target_ids = [inst[1] for inst in batch]
batch_knowledge_ids = [inst[2] for inst in batch]
pad_source = max([self.cal_max_len(s_inst) for s_inst in batch_source_ids])
pad_target = max([self.cal_max_len(t_inst) for t_inst in batch_target_ids])
pad_kn = max([self.cal_max_len(k_inst) for k_inst in batch_knowledge_ids])
pad_kn_num = max([len(k_inst) for k_inst in batch_knowledge_ids])
source_pad_ids = [self.pad_data(s_inst, pad_source) for s_inst in batch_source_ids]
target_pad_ids = [self.pad_data(t_inst, pad_target) for t_inst in batch_target_ids]
knowledge_pad_ids = [self.pad_data(k_inst, pad_kn, pad_kn_num)
for k_inst in batch_knowledge_ids]
source_len = [len(inst) for inst in batch_source_ids]
target_len = [len(inst) for inst in batch_target_ids]
kn_len = [[len(term) for term in inst] for inst in batch_knowledge_ids]
kn_len_pad = []
for elem in kn_len:
if len(elem) < pad_kn_num:
elem += [self.vocab_dict['<pad>']] * (pad_kn_num - len(elem))
kn_len_pad.extend(elem)
return_array = [np.array(source_pad_ids).reshape(-1, pad_source), np.array(source_len),
np.array(target_pad_ids).reshape(-1, pad_target), np.array(target_len),
np.array(knowledge_pad_ids).astype("int64").reshape(-1, pad_kn_num, pad_kn),
np.array(kn_len_pad).astype("int64").reshape(-1, pad_kn_num)]
return return_array
def pad_data(self, insts, pad_len, pad_num=-1):
""" padding ids """
insts_pad = []
if isinstance(insts[0], list):
for inst in insts:
inst_pad = inst + [self.vocab_dict['<pad>']] * (pad_len - len(inst))
insts_pad.append(inst_pad)
if len(insts_pad) < pad_num:
insts_pad += [[self.vocab_dict['<pad>']] * pad_len] * (pad_num - len(insts_pad))
else:
insts_pad = insts + [self.vocab_dict['<pad>']] * (pad_len - len(insts))
return insts_pad
def cal_max_len(self, ids):
""" calculate max sequence length """
if isinstance(ids[0], list):
pad_len = max([self.cal_max_len(k) for k in ids])
else:
pad_len = len(ids)
return pad_len
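# Usage sketch (illustrative; the paths and sizes here are assumptions, not fixed by this file):
#   corpus = KnowledgeCorpus(data_dir='./data', data_prefix='demo',
#                            vocab_path='./data/vocab.txt', min_len=1, max_len=500)
#   train_reader = corpus.data_generator(batch_size=32, phase='train', shuffle=True)
#   for batch in train_reader():
#       src_ids, src_len, tgt_ids, tgt_len, kn_ids, kn_len = batch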
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
File: __init__.py
"""
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
File: __init__.py
"""
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
File: source/utils/utils.py
"""
import argparse
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
def str2bool(v):
""" str2bool """
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Unsupported value encountered.')
def load_id2str_dict(vocab_file):
""" load id2str dict """
id_dict_array = []
with open(vocab_file, 'r') as fr:
for line in fr:
line = line.strip()
id_dict_array.append(line)
return id_dict_array
def load_str2id_dict(vocab_file):
""" load str2id dict """
words_dict = {}
with open(vocab_file, 'r') as fr:
for line in fr:
word = line.strip()
words_dict[word] = len(words_dict)
return words_dict
def log_softmax(x):
""" log softmax """
t1 = layers.exp(x)
t1 = layers.reduce_sum(t1, dim=-1)
t1 = layers.log(t1)
return layers.elementwise_sub(x, t1, axis=0)
def id_to_text(ids, id_dict_array):
""" convert id seq to str seq """
res = []
for i in ids:
res.append(id_dict_array[i])
return ' '.join(res)
def pad_to_batch_size(src_ids, src_len, trg_ids, trg_len, kn_ids, kn_len, batch_size):
""" pad the last short batch up to batch_size (for the knowledge corpus) """
real_len = src_ids.shape[0]
def pad(old):
""" pad """
old_shape = list(old.shape)
old_shape[0] = batch_size
new_val = np.zeros(old_shape, dtype=old.dtype)
new_val[:real_len] = old
for i in range(real_len, batch_size):
new_val[i] = old[-1]
return new_val
new_src_ids = pad(src_ids)
new_src_len = pad(src_len)
new_trg_ids = pad(trg_ids)
new_trg_len = pad(trg_len)
new_kn_ids = pad(kn_ids)
new_kn_len = pad(kn_len)
return [new_src_ids, new_src_len, new_trg_ids, new_trg_len, new_kn_ids, new_kn_len]
def to_lodtensor(data, seq_lens, place):
""" convert to LoDTensor """
cur_len = 0
lod = [cur_len]
data_array = []
for idx, seq in enumerate(seq_lens):
if seq > 0:
data_array.append(data[idx, :seq])
cur_len += seq
lod.append(cur_len)
else:
data_array.append(np.zeros([1, 1], dtype='int64'))
cur_len += 1
lod.append(cur_len)
flattened_data = np.concatenate(data_array, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
res = fluid.LoDTensor()
res.set(flattened_data, place)
res.set_lod([lod])
return res
def len_to_mask(len_seq, max_len=None):
""" len to mask """
if max_len is None:
max_len = np.max(len_seq)
mask = np.zeros((len_seq.shape[0], max_len), dtype='float32')
for i, l in enumerate(len_seq):
mask[i, :l] = 1.0
return mask
def build_data_feed(data, place,
batch_size=128,
is_training=False,
bow_max_len=30,
pretrain_epoch=False):
""" build data feed """
src_ids, src_len, trg_ids, trg_len, kn_ids, kn_len = data
real_size = src_ids.shape[0]
if src_ids.shape[0] < batch_size:
if not is_training:
src_ids, src_len, trg_ids, trg_len, kn_ids, kn_len = \
pad_to_batch_size(src_ids, src_len, trg_ids, trg_len, kn_ids, kn_len, batch_size)
else:
return None
enc_input = np.expand_dims(src_ids[:, 1: -1], axis=2)
enc_mask = len_to_mask(src_len - 2)
tar_input = np.expand_dims(trg_ids[:, 1: -1], axis=2)
tar_mask = len_to_mask(trg_len - 2)
cue_input = np.expand_dims(kn_ids.reshape((-1, kn_ids.shape[-1]))[:, 1:-1], axis=2)
cue_mask = len_to_mask(kn_len.reshape(-1) - 2)
memory_mask = np.equal(kn_len, 0).astype('float32')
enc_memory_mask = 1.0 - enc_mask
if not is_training:
return {'enc_input': to_lodtensor(enc_input, src_len - 2, place),
'enc_mask': enc_mask,
'cue_input': to_lodtensor(cue_input, kn_len.reshape(-1) - 2, place),
'cue_last_mask': np.not_equal(kn_len.reshape(-1), 0).astype('float32'),
'memory_mask': memory_mask,
'enc_memory_mask': enc_memory_mask,
}, real_size
dec_input = np.expand_dims(trg_ids[:, :-1], axis=2)
dec_mask = len_to_mask(trg_len - 1)
target_label = trg_ids[:, 1:]
target_mask = len_to_mask(trg_len - 1)
bow_label = target_label[:, :-1]
bow_label = np.pad(bow_label, ((0, 0), (0, bow_max_len - bow_label.shape[1])), 'constant', constant_values=(0))
bow_mask = np.pad(np.not_equal(bow_label, 0).astype('float32'), ((0, 0), (0, bow_max_len - bow_label.shape[1])),
'constant', constant_values=(0.0))
if not pretrain_epoch:
kl_and_nll_factor = np.ones([1], dtype='float32')
else:
kl_and_nll_factor = np.zeros([1], dtype='float32')
return {'enc_input': to_lodtensor(enc_input, src_len - 2, place),
'enc_mask': enc_mask,
'cue_input': to_lodtensor(cue_input, kn_len.reshape(-1) - 2, place),
'cue_last_mask': np.not_equal(kn_len.reshape(-1), 0).astype('float32'),
'memory_mask': memory_mask,
'enc_memory_mask': enc_memory_mask,
'tar_input': to_lodtensor(tar_input, trg_len - 2, place),
'bow_label': bow_label,
'bow_mask': bow_mask,
'target_label': target_label,
'target_mask': target_mask,
'dec_input': dec_input,
'dec_mask': dec_mask,
'kl_and_nll_factor': kl_and_nll_factor}
def load_embedding(embedding_file, vocab_file):
""" load pretrain embedding from file """
words_dict = load_str2id_dict(vocab_file)
coverage = 0
print("Building word embeddings from '{}' ...".format(embedding_file))
with open(embedding_file, "r") as f:
num, dim = map(int, f.readline().strip().split())
embeds = [[0] * dim] * len(words_dict)
for line in f:
w, vs = line.rstrip().split(" ", 1)
if w in words_dict:
try:
vs = [float(x) for x in vs.split(" ")]
except Exception:
vs = []
if len(vs) == dim:
embeds[words_dict[w]] = vs
coverage += 1
rate = coverage * 1.0 / len(embeds)
print("{} words have pretrained {}-D word embeddings (coverage: {:.3f})".format( \
coverage, dim, rate))
return np.array(embeds).astype('float32')
def init_embedding(embedding_file, vocab_file, init_scale, shape):
""" init embedding by pretrain file or random """
if embedding_file != "":
try:
emb_np = load_embedding(embedding_file, vocab_file)
except:
print("load init emb file failed", embedding_file)
raise Exception("load embedding file failed")
if emb_np.shape != shape:
print("shape not match", emb_np.shape, shape)
raise Exception("shape not match")
zero_count = 0
for i in range(emb_np.shape[0]):
if np.sum(emb_np[i]) == 0:
zero_count += 1
emb_np[i] = np.random.uniform(-init_scale, init_scale, emb_np.shape[1:]).astype('float32')
else:
print("random init embeding")
emb_np = np.random.uniform(-init_scale, init_scale, shape).astype('float32')
return emb_np
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
File: __init__.py
"""
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
File: build_vocabulary.py
"""
from __future__ import print_function
import sys
import re
from collections import Counter
reload(sys)
sys.setdefaultencoding('utf8')
def tokenize(s):
"""
tokenize
"""
s = re.sub('\d+', '<num>', s).lower()
tokens = s.split(' ')
return tokens
def build_vocabulary(corpus_file, vocab_file,
vocab_size=30004, min_frequency=0,
min_len=1, max_len=500):
"""
build words dict
"""
specials = ["<pad>", "<unk>", "<bos>", "<eos>"]
counter = Counter()
for line in open(corpus_file, 'r'):
src, tgt, knowledge = line.rstrip('\n').split('\t')[:3]
filter_knowledge = []
for sent in knowledge.split('\1'):
filter_knowledge.append(' '.join(sent.split()[:max_len]))
knowledge = ' '.join(filter_knowledge)
src = tokenize(src)
tgt = tokenize(tgt)
knowledge = tokenize(knowledge)
if len(src) < min_len or len(src) > max_len or \
len(tgt) < min_len or len(tgt) > max_len:
continue
counter.update(src + tgt + knowledge)
for tok in specials:
del counter[tok]
words_and_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
words_and_frequencies.sort(key=lambda tup: tup[1], reverse=True)
words_and_frequencies = [[tok, sys.maxint] for tok in specials] + words_and_frequencies
words_and_frequencies = words_and_frequencies[:vocab_size]
fout = open(vocab_file, 'w')
for word, frequency in words_and_frequencies:
if frequency < min_frequency:
break
fout.write(word + '\n')
fout.close()
def main():
"""
main
"""
if len(sys.argv) < 3:
print("Usage: " + sys.argv[0] + " corpus_file vocab_file")
exit()
build_vocabulary(sys.argv[1], sys.argv[2])
if __name__ == '__main__':
try:
main()
except KeyboardInterrupt:
print("\nExited from the program ealier!")
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
File: conversation_client.py
"""
from __future__ import print_function
import sys
import socket
reload(sys)
sys.setdefaultencoding('utf8')
SERVER_IP = "127.0.0.1"
SERVER_PORT = 8601
def conversation_client(text):
"""
conversation_client
"""
mysocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
mysocket.connect((SERVER_IP, SERVER_PORT))
mysocket.sendall(text.encode())
result = mysocket.recv(4096).decode()
mysocket.close()
return result
def main():
"""
main
"""
if len(sys.argv) < 2:
print("Usage: " + sys.argv[0] + " eval_file")
exit()
for line in open(sys.argv[1]):
response = conversation_client(line.strip())
print(response)
if __name__ == '__main__':
try:
main()
except KeyboardInterrupt:
print("\nExited from the program ealier!")
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
File: conversation_server.py
"""
from __future__ import print_function
import sys
sys.path.append("../")
import socket
from thread import start_new_thread
from tools.conversation_strategy import load
from tools.conversation_strategy import predict
reload(sys)
sys.setdefaultencoding('utf8')
SERVER_IP = "127.0.0.1"
SERVER_PORT = 8601
print("starting conversation server ...")
print("binding socket ...")
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
#Bind socket to local host and port
try:
s.bind((SERVER_IP, SERVER_PORT))
except socket.error as msg:
print("Bind failed. Error Code : " + str(msg[0]) + " Message " + msg[1])
exit()
#Start listening on socket
s.listen(10)
print("bind socket success !")
print("loading model...")
model = load()
print("load model success !")
print("start conversation server success !")
def clientthread(conn, addr):
"""
client thread
"""
logstr = "addr:" + addr[0]+ "_" + str(addr[1])
try:
#Receiving from client
param = conn.recv(4096).decode()
logstr += "\tparam:" + param
if param is not None:
response = predict(model, param.strip())
logstr += "\tresponse:" + response
conn.sendall(response.encode())
conn.close()
print(logstr + "\n")
except Exception as e:
print(logstr + "\n", e)
while True:
conn, addr = s.accept()
start_new_thread(clientthread, (conn, addr))
s.close()
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
File: conversation_strategy.py
"""
from __future__ import print_function
import sys
sys.path.append("../")
import network
from tools.convert_conversation_corpus_to_model_text import preprocessing_for_one_conversation
reload(sys)
sys.setdefaultencoding('utf8')
def load():
"""
load model
"""
return network.load()
def predict(model, text):
"""
predict
"""
model_text, topic_dict = \
preprocessing_for_one_conversation(text.strip(), topic_generalization=True)
if isinstance(model_text, unicode):
model_text = model_text.encode('utf-8')
response = network.predict(model, model_text)
topic_list = sorted(topic_dict.items(), key=lambda item: len(item[1]), reverse=True)
for key, value in topic_list:
response = response.replace(key, value)
return response
def main():
"""
main
"""
generator = load()
for line in sys.stdin:
response = predict(generator, line.strip())
print(response)
if __name__ == '__main__':
try:
main()
except KeyboardInterrupt:
print("\nExited from the program ealier!")
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
File: convert_conversation_corpus_to_model_text.py
"""
from __future__ import print_function
import sys
import json
import collections
reload(sys)
sys.setdefaultencoding('utf8')
def preprocessing_for_one_conversation(text,
topic_generalization=False):
"""
preprocessing_for_one_conversation
"""
conversation = json.loads(text.strip(), encoding="utf-8", \
object_pairs_hook=collections.OrderedDict)
goal = conversation["goal"]
knowledge = conversation["knowledge"]
history = conversation["history"]
response = conversation["response"] if "response" in conversation else "null"
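# goal[0] is expected to be the start triple ["START", topic_a, topic_b]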
topic_a = goal[0][1]
topic_b = goal[0][2]
for i, [s, p, o] in enumerate(knowledge):
if u"领域" == p:
if topic_a == s:
domain_a = o
elif topic_b == s:
domain_b = o
topic_dict = {}
if u"电影" == domain_a:
topic_dict["video_topic_a"] = topic_a
else:
topic_dict["person_topic_a"] = topic_a
if u"电影" == domain_b:
topic_dict["video_topic_b"] = topic_b
else:
topic_dict["person_topic_b"] = topic_b
chat_path_str = ' '.join([' '.join(spo) for spo in goal])
knowledge_str1 = ' '.join([' '.join(spo) for spo in knowledge])
knowledge_str2 = '\1'.join([' '.join(spo) for spo in knowledge])
history_str = ' '.join(history)
src = chat_path_str + " " + knowledge_str1 + " : " + history_str
model_text = '\t'.join([src, response, knowledge_str2])
if topic_generalization:
topic_list = sorted(topic_dict.items(), key=lambda item: len(item[1]), reverse=True)
for key, value in topic_list:
model_text = model_text.replace(value, key)
return model_text, topic_dict
def convert_conversation_corpus_to_model_text(corpus_file, text_file, topic_file, \
topic_generalization=False):
"""
convert_conversation_corpus_to_model_text
"""
fout_text = open(text_file, 'w')
fout_topic = open(topic_file, 'w')
with open(corpus_file, 'r') as f:
for i, line in enumerate(f):
model_text, topic_dict = preprocessing_for_one_conversation(
line.strip(), topic_generalization=topic_generalization)
topic_dict = json.dumps(topic_dict, ensure_ascii=False, encoding="utf-8")
fout_text.write(model_text + "\n")
fout_topic.write(topic_dict + "\n")
fout_text.close()
fout_topic.close()
def main():
"""
main
"""
convert_conversation_corpus_to_model_text(sys.argv[1],
sys.argv[2],
sys.argv[3],
int(sys.argv[4]) > 0)
if __name__ == '__main__':
try:
main()
except KeyboardInterrupt:
print("\nExited from the program ealier!")
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
File: convert_result_for_eval.py
"""
from __future__ import print_function
import sys
import json
import collections
reload(sys)
sys.setdefaultencoding('utf8')
def convert_result_for_eval(sample_file, result_file, output_file):
"""
convert_result_for_eval
"""
sample_list = [line.strip() for line in open(sample_file, 'r')]
result_list = [line.strip() for line in open(result_file, 'r')]
assert len(sample_list) == len(result_list)
fout = open(output_file, 'w')
for i, sample in enumerate(sample_list):
sample = json.loads(sample, encoding="utf-8", \
object_pairs_hook=collections.OrderedDict)
response = sample["response"]
fout.write(result_list[i] + "\t" + response + "\n")
fout.close()
def main():
"""
main
"""
convert_result_for_eval(sys.argv[1],
sys.argv[2],
sys.argv[3])
if __name__ == '__main__':
try:
main()
except KeyboardInterrupt:
print("\nExited from the program ealier!")
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
File: convert_session_to_sample.py
"""
from __future__ import print_function
import sys
import json
import collections
reload(sys)
sys.setdefaultencoding('utf8')
def convert_session_to_sample(session_file, sample_file):
"""
convert_session_to_sample
"""
fout = open(sample_file, 'w')
with open(session_file, 'r') as f:
for i, line in enumerate(f):
session = json.loads(line.strip(), encoding="utf-8", \
object_pairs_hook=collections.OrderedDict)
conversation = session["conversation"]
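# the bot speaks at even turn indices, so each even-indexed turn becomes one
# sample whose history is all preceding turns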
for j in range(0, len(conversation), 2):
sample = collections.OrderedDict()
sample["goal"] = session["goal"]
sample["knowledge"] = session["knowledge"]
sample["history"] = conversation[:j]
sample["response"] = conversation[j]
sample = json.dumps(sample, ensure_ascii=False, encoding="utf-8")
fout.write(sample + "\n")
fout.close()
def main():
"""
main
"""
convert_session_to_sample(sys.argv[1], sys.argv[2])
if __name__ == '__main__':
try:
main()
except KeyboardInterrupt:
print("\nExited from the program ealier!")
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
File: eval.py
"""
from __future__ import print_function
import sys
import math
from collections import Counter
reload(sys)
sys.setdefaultencoding('utf8')
if len(sys.argv) < 2:
print("Usage: " + sys.argv[0] + " eval_file")
print("eval file format: pred_response \t gold_response")
exit()
def get_dict(tokens, ngram, gdict=None):
"""
get_dict
"""
token_dict = {}
if gdict is not None:
token_dict = gdict
tlen = len(tokens)
for i in range(0, tlen - ngram + 1):
ngram_token = "".join(tokens[i:(i + ngram)])
if token_dict.get(ngram_token) is not None:
token_dict[ngram_token] += 1
else:
token_dict[ngram_token] = 1
return token_dict
def count(pred_tokens, gold_tokens, ngram, result):
"""
count
"""
cover_count, total_count = result
pred_dict = get_dict(pred_tokens, ngram)
gold_dict = get_dict(gold_tokens, ngram)
cur_cover_count = 0
cur_total_count = 0
for token, freq in pred_dict.items():
if gold_dict.get(token) is not None:
gold_freq = gold_dict[token]
cur_cover_count += min(freq, gold_freq)
cur_total_count += freq
result[0] += cur_cover_count
result[1] += cur_total_count
def calc_bp(pair_list):
"""
calc_bp
"""
c_count = 0.0
r_count = 0.0
for pair in pair_list:
pred_tokens, gold_tokens = pair
c_count += len(pred_tokens)
r_count += len(gold_tokens)
bp = 1
if c_count < r_count:
bp = math.exp(1 - r_count / c_count)
return bp
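# corpus-level BLEU brevity penalty: exp(1 - r/c) when the total predicted
# length c is shorter than the total reference length r, else 1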
def calc_cover_rate(pair_list, ngram):
"""
calc_cover_rate
"""
result = [0.0, 0.0] # [cover_count, total_count]
for pair in pair_list:
pred_tokens, gold_tokens = pair
count(pred_tokens, gold_tokens, ngram, result)
cover_rate = result[0] / result[1]
return cover_rate
def calc_bleu(pair_list):
"""
calc_bleu
"""
bp = calc_bp(pair_list)
cover_rate1 = calc_cover_rate(pair_list, 1)
cover_rate2 = calc_cover_rate(pair_list, 2)
cover_rate3 = calc_cover_rate(pair_list, 3)
bleu1 = 0
bleu2 = 0
bleu3 = 0
if cover_rate1 > 0:
bleu1 = bp * math.exp(math.log(cover_rate1))
if cover_rate2 > 0:
bleu2 = bp * math.exp((math.log(cover_rate1) + math.log(cover_rate2)) / 2)
if cover_rate3 > 0:
bleu3 = bp * math.exp((math.log(cover_rate1) + math.log(cover_rate2) + math.log(cover_rate3)) / 3)
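# bleu3 is computed above but not returned; the reported metrics use BLEU1/BLEU2 only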
return [bleu1, bleu2]
def calc_distinct_ngram(pair_list, ngram):
"""
calc_distinct_ngram
"""
ngram_total = 0.0
ngram_distinct_count = 0.0
pred_dict = {}
for predict_tokens, _ in pair_list:
get_dict(predict_tokens, ngram, pred_dict)
for key, freq in pred_dict.items():
ngram_total += freq
ngram_distinct_count += 1
#if freq == 1:
# ngram_distinct_count += freq
return ngram_distinct_count / ngram_total
def calc_distinct(pair_list):
"""
calc_distinct
"""
distinct1 = calc_distinct_ngram(pair_list, 1)
distinct2 = calc_distinct_ngram(pair_list, 2)
return [distinct1, distinct2]
def calc_f1(data):
"""
calc_f1
"""
golden_char_total = 0.0
pred_char_total = 0.0
hit_char_total = 0.0
for response, golden_response in data:
golden_response = "".join(golden_response).decode("utf8")
response = "".join(response).decode("utf8")
#golden_response = "".join(golden_response)
#response = "".join(response)
common = Counter(response) & Counter(golden_response)
hit_char_total += sum(common.values())
golden_char_total += len(golden_response)
pred_char_total += len(response)
p = hit_char_total / pred_char_total
r = hit_char_total / golden_char_total
f1 = 2 * p * r / (p + r)
return f1
eval_file = sys.argv[1]
sents = []
for line in open(eval_file):
tk = line.strip().split("\t")
if len(tk) < 2:
continue
pred_tokens = tk[0].strip().split(" ")
gold_tokens = tk[1].strip().split(" ")
sents.append([pred_tokens, gold_tokens])
# calc f1
f1 = calc_f1(sents)
# calc bleu
bleu1, bleu2 = calc_bleu(sents)
# calc distinct
distinct1, distinct2 = calc_distinct(sents)
output_str = "F1: %.2f%%\n" % (f1 * 100)
output_str += "BLEU1: %.3f%%\n" % bleu1
output_str += "BLEU2: %.3f%%\n" % bleu2
output_str += "DISTINCT1: %.3f%%\n" % distinct1
output_str += "DISTINCT2: %.3f%%\n" % distinct2
sys.stdout.write(output_str)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
#
# Copyright (c) 2019 Baidu.com, Inc. All Rights Reserved
#
################################################################################
"""
File: topic_materialization.py
"""
from __future__ import print_function
import sys
import json
reload(sys)
sys.setdefaultencoding('utf8')
def topic_materialization(input_file, output_file, topic_file):
"""
topic_materialization
"""
inputs = [line.strip() for line in open(input_file, 'r')]
topics = [line.strip() for line in open(topic_file, 'r')]
assert len(inputs) == len(topics)
fout = open(output_file, 'w')
for i, text in enumerate(inputs):
topic_dict = json.loads(topics[i], encoding="utf-8")
topic_list = sorted(topic_dict.items(), key=lambda item: len(item[1]), reverse=True)
for key, value in topic_list:
text = text.replace(key, value)
fout.write(text + "\n")
fout.close()
def main():
"""
main
"""
topic_materialization(sys.argv[1],
sys.argv[2],
sys.argv[3])
if __name__ == '__main__':
try:
main()
except KeyboardInterrupt:
print("\nExited from the program ealier!")
Knowledge-driven Dialogue
=============================
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
This is a PaddlePaddle implementation of the retrieval-based model for knowledge-driven dialogue.
## Requirements
* cuda=9.0
* cudnn=7.0
* python=2.7
* numpy
* paddlepaddle>=1.3
## Quickstart
### Step 1: Preprocess the data
Put the [DuConv](https://ai.baidu.com/broad/subordinate?dataset=duconv) data under the data folder and rename the files train.txt/dev.txt/test.txt:
```
./data/resource/train.txt
./data/resource/dev.txt
./data/resource/test.txt
```
### Step 2: Train the model
Train the model with the following command.
```bash
sh run_train.sh model_name
```
Three models are supported:
- match: input is history and response
- match_kn: input is history, response, chat_path and knowledge
- match_kn_gene: same as match_kn, but additionally generalizes target_a/target_b of the goal in all inputs, replacing them with slot marks (see the example command below)
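For example, to train the knowledge-aware model with topic generalization:
```bash
sh run_train.sh match_kn_gene
```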
### Step 3: Test the model
Test the model with the following command.
```bash
sh run_test.sh model_name
```
## Note !!!
* The scripts run_train.sh/run_test.sh contain the whole pipeline, including data processing and model training/testing. Be sure to read them carefully and follow them.
* The files in ./data and ./model are just empty placeholders that show the expected directory structure.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
######################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################
"""
File: args.py
"""
from __future__ import print_function
import six
import argparse
# define argument parser & add common arguments
def base_parser():
parser = argparse.ArgumentParser(description="Arguments for running classifier.")
parser.add_argument(
'--epoch',
type=int,
default=100,
help='Number of epochs for training. (default: %(default)d)')
parser.add_argument(
'--task_name',
type=str,
default='match',
help='task name for training')
parser.add_argument(
'--max_seq_len',
type=int,
default=512,
help='Number of words in the longest sequence. (default: %(default)d)')
parser.add_argument(
'--batch_size',
type=int,
default=8096,
help='Total token number in batch for training. (default: %(default)d)')
parser.add_argument(
'--voc_size',
type=int,
default=14373,
help='Vocabulary size. (default: %(default)d)')
parser.add_argument(
'--init_checkpoint',
type=str,
default=None,
help='init checkpoint to resume training from. (default: %(default)s)')
parser.add_argument(
'--save_inference_model_path',
type=str,
default="inference_model",
help='save inference model. (default: %(default)s)')
parser.add_argument(
'--output',
type=str,
default="./output/pred.txt",
help='Path to write the prediction output. (default: %(default)s)')
parser.add_argument(
'--learning_rate',
type=float,
default=1e-2,
help='Learning rate used to train with warmup. (default: %(default)f)')
parser.add_argument(
'--weight_decay',
type=float,
default=0.01,
help='Weight decay rate for L2 regularizer. (default: %(default)f)')
parser.add_argument(
'--checkpoints',
type=str,
default="checkpoints",
help='Path to save checkpoints. (default: %(default)s)')
parser.add_argument(
'--vocab_path',
type=str,
default=None,
help='Vocabulary path. (default: %(default)s)')
parser.add_argument(
'--data_dir',
type=str,
default="./real_data",
help='Path of training data. (default: %(default)s)')
parser.add_argument(
'--skip_steps',
type=int,
default=10,
help='The steps interval to print loss. (default: %(default)d)')
parser.add_argument(
'--save_steps',
type=int,
default=10000,
help='The steps interval to save checkpoints. (default: %(default)d)')
parser.add_argument(
'--validation_steps',
type=int,
default=1000,
help='The steps interval to evaluate model performance on validation '
'set. (default: %(default)d)')
parser.add_argument(
'--use_cuda', action='store_true', help='If set, use GPU for training.')
parser.add_argument(
'--use_fast_executor',
action='store_true',
help='If set, use fast parallel executor (in experiment).')
parser.add_argument(
'--do_lower_case',
type=bool,
default=True,
choices=[True, False],
help="Whether to lower case the input text. Should be True for uncased "
"models and False for cased models.")
parser.add_argument(
'--warmup_proportion',
type=float,
default=0.1,
help='Proportion of training steps used for learning-rate warmup. (default: %(default)f)')
args = parser.parse_args()
return args
def print_arguments(args):
print('----------- Configuration Arguments -----------')
for arg, value in sorted(six.iteritems(vars(args))):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
#!/usr/bin/env python
# -*- coding: utf-8 -*-
######################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################
"""
File: interact.py
"""
import paddle.fluid as fluid
import paddle.fluid.framework as framework
from source.inputters.data_provider import load_dict
from source.inputters.data_provider import MatchProcessor
from source.inputters.data_provider import preprocessing_for_one_line
import numpy as np
load_dict("./dict/gene.dict")
def load_model():
"""
load model function
"""
main_program = fluid.default_main_program()
#place = fluid.CPUPlace()
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(framework.default_startup_program())
path = "./models/inference_model"
[inference_program, feed_dict, fetch_targets] = \
fluid.io.load_inference_model(dirname=path, executor=exe)
model_handle = [exe, inference_program, feed_dict, fetch_targets, place]
return model_handle
def predict(model_handle, text, task_name):
"""
predict score function
"""
exe = model_handle[0]
inference_program = model_handle[1]
feed_dict = model_handle[2]
fetch_targets = model_handle[3]
place = model_handle[4]
data = preprocessing_for_one_line(text, MatchProcessor.get_labels(), \
task_name, max_seq_len=256)
context_ids = [elem[0] for elem in data]
context_pos_ids = [elem[1] for elem in data]
context_segment_ids = [elem[2] for elem in data]
context_attn_mask = [elem[3] for elem in data]
labels_ids = [[1]]
if 'kn' in task_name:
kn_ids = [elem[4] for elem in data]
kn_ids = fluid.create_lod_tensor(kn_ids, [[len(kn_ids[0])]], place)
context_next_sent_index = [elem[5] for elem in data]
results = exe.run(inference_program,
feed={feed_dict[0]: np.array(context_ids),
feed_dict[1]: np.array(context_pos_ids),
feed_dict[2]: np.array(context_segment_ids),
feed_dict[3]: np.array(context_attn_mask),
feed_dict[4]: kn_ids,
feed_dict[5]: np.array(labels_ids),
feed_dict[6]: np.array(context_next_sent_index)},
fetch_list=fetch_targets)
else:
context_next_sent_index = [elem[4] for elem in data]
results = exe.run(inference_program,
feed={feed_dict[0]: np.array(context_ids),
feed_dict[1]: np.array(context_pos_ids),
feed_dict[2]: np.array(context_segment_ids),
feed_dict[3]: np.array(context_attn_mask),
feed_dict[4]: np.array(labels_ids),
feed_dict[5]: np.array(context_next_sent_index)},
fetch_list=fetch_targets)
score = results[0][0][1]
return score
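# Usage sketch (illustrative; the exact input line format is whatever
# preprocessing_for_one_line expects):
#   handle = load_model()
#   score = predict(handle, line, task_name='match_kn')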
#!/usr/bin/env python
# -*- coding: utf-8 -*-
######################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################
"""
File: predict.py
Load checkpoint of running classifier to do prediction and save inference model.
"""
import os
import time
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.framework as framework
import source.inputters.data_provider as reader
import multiprocessing
from train import create_model
from args import base_parser
from args import print_arguments
from source.utils.utils import init_pretraining_params
def main(args):
task_name = args.task_name.lower()
processor = reader.MatchProcessor(data_dir=args.data_dir,
task_name=task_name,
vocab_path=args.vocab_path,
max_seq_len=args.max_seq_len,
do_lower_case=args.do_lower_case)
num_labels = len(processor.get_labels())
infer_data_generator = processor.data_generator(
batch_size=args.batch_size,
phase='test',
epoch=1,
shuffle=False)
num_test_examples = processor.get_num_examples(phase='test')
main_program = fluid.default_main_program()
feed_order, loss, probs, accuracy, num_seqs = create_model(
args,
num_labels=num_labels,
is_prediction=True)
if args.use_cuda:
place = fluid.CUDAPlace(0)
dev_count = fluid.core.get_cuda_device_count()
else:
place = fluid.CPUPlace()
dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
exe = fluid.Executor(place)
exe.run(framework.default_startup_program())
if args.init_checkpoint:
init_pretraining_params(exe, args.init_checkpoint, main_program)
feed_list = [
main_program.global_block().var(var_name) for var_name in feed_order
]
feeder = fluid.DataFeeder(feed_list, place)
out_scores = open(args.output, 'w')
for batch_id, data in enumerate(infer_data_generator()):
results = exe.run(
fetch_list=[probs],
feed=feeder.feed(data),
return_numpy=True)
for elem in results[0]:
out_scores.write(str(elem[1]) + '\n')
out_scores.close()
if args.save_inference_model_path:
model_path = args.save_inference_model_path
fluid.io.save_inference_model(
model_path,
feed_order, probs,
exe,
main_program=main_program)
if __name__ == '__main__':
args = base_parser()
print_arguments(args)
main(args)
#!/bin/bash
# set gpu id to use
export CUDA_VISIBLE_DEVICES=1
# task_name can be selected from ["match", "match_kn", "match_kn_gene"]
# match task: do not use knowledge info (goal and knowledge) for the retrieval model
# match_kn task: use knowledge info (goal and knowledge) for the retrieval model
# match_kn_gene task: 1) use knowledge info (goal and knowledge) for the retrieval model;
#                     2) generalize target_a/target_b of the goal by replacing them with slot marks
# for more information about the generalization in match_kn_gene,
# refer to ./tools/convert_conversation_corpus_to_model_text.py
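# example invocation (the script file name here is an assumption):
#     sh run_predict.sh match_kn_gene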
TASK_NAME=$1
if [ "$TASK_NAME" = "match" ]
then
DICT_NAME="./dict/char.dict"
USE_KNOWLEDGE=0
TOPIC_GENERALIZATION=0
elif [ "$TASK_NAME" = "match_kn" ]
then
DICT_NAME="./dict/char.dict"
USE_KNOWLEDGE=1
TOPIC_GENERALIZATION=0
elif [ "$TASK_NAME" = "match_kn_gene" ]
then
DICT_NAME="./dict/gene.dict"
USE_KNOWLEDGE=1
TOPIC_GENERALIZATION=1
else
echo "task name error, should be match|match_kn|match_kn_gene"
fi
# in predict stage, FOR_PREDICT=1
FOR_PREDICT=1
# put all data sets used and generated for testing under this folder: INPUT_PATH
# for more details, please refer to the following data processing instructions
INPUT_PATH="./data"
# put the model files needed for testing under this folder: OUTPUT_PATH
OUTPUT_PATH="./models"
# set python path according to your actual environment
PYTHON_PATH="python"
# in the test stage, you can evaluate dev.txt or test.txt
# "dev.txt" and "test.txt" are the original DuConv data and
# need to be placed in this folder: INPUT_PATH/resource/
# the following preprocessing will generate the actual data needed for model testing
# after testing, you can run eval.py to get the final eval score if the original data has answers
# DATA_TYPE = "dev" or "test"
DATA_TYPE="dev"
# candidate set, constructed in the train stage
candidate_set_file=${INPUT_PATH}/candidate_set.txt
# ensure that each file is in the correct path
# 1. put the data of DuConv under this folder: INPUT_PATH/resource/
# - the data provided consists of three parts: train.txt dev.txt test.txt
# - the train.txt and dev.txt are session data, the test.txt is sample data
# - in test stage, we just use the dev.txt or test.txt
# 2. the sample data extracted from session data is in this folder: INPUT_PATH/resource/
# 3. the candidate data constructed from sample data is in this folder: INPUT_PATH/resource/
# 4. the text file required by the model is in this folder: INPUT_PATH
corpus_file=${INPUT_PATH}/resource/${DATA_TYPE}.txt
sample_file=${INPUT_PATH}/resource/sample.${DATA_TYPE}.txt
candidate_file=${INPUT_PATH}/resource/candidate.${DATA_TYPE}.txt
text_file=${INPUT_PATH}/test.txt
score_file=./output/score.txt
predict_file=./output/predict.txt
# step 1: if evaluating dev.txt, first convert the session data to sample data;
#         if evaluating test.txt, the original test.txt of DuConv can be used directly
if [ "${DATA_TYPE}"x = "test"x ]; then
sample_file=${corpus_file}
else
${PYTHON_PATH} ./tools/convert_session_to_sample.py ${corpus_file} ${sample_file}
fi
# step 2: construct candidate for sample data
${PYTHON_PATH} ./tools/construct_candidate.py ${sample_file} ${candidate_set_file} ${candidate_file} 10
# step 3: convert sample data with candidates to text data required by the model
${PYTHON_PATH} ./tools/convert_conversation_corpus_to_model_text.py ${candidate_file} ${text_file} ${USE_KNOWLEDGE} ${TOPIC_GENERALIZATION} ${FOR_PREDICT}
# the saved inference_model can be used by interact.py
inference_model="./models/inference_model"
# step 4: predict score by model
$PYTHON_PATH -u predict.py --task_name ${TASK_NAME} \
--use_cuda \
--batch_size 10 \
--init_checkpoint ${OUTPUT_PATH}/50 \
--data_dir ${INPUT_PATH} \
--vocab_path ${DICT_NAME} \
--save_inference_model_path ${inference_model} \
--max_seq_len 128 \
--output ${score_file}
# step 5: extract predict utterance by candidate_file and score_file
# if the original file has answers, the predict_file format is "predict \t gold \n predict \t gold \n ......"
# if the original file has no answers, the predict_file format is "predict \n predict \n predict \n predict \n ......"
${PYTHON_PATH} ./tools/extract_predict_utterance.py ${candidate_file} ${score_file} ${predict_file}
# step 6: if the original file has answers, you can run the following command to get the result
#         if the original file has no answers, you can upload the ./output/test.result.final
#         to the website (https://ai.baidu.com/broad/submission?dataset=duconv) to get the official automatic evaluation
${PYTHON_PATH} ./tools/eval.py ${predict_file}
#!/bin/bash
# set gpu id to use
export CUDA_VISIBLE_DEVICES=0
# task_name can be selected from ["match", "match_kn", "match_kn_gene"]
# match task: do not use knowledge info (goal and knowledge) for the retrieval model
# match_kn task: use knowledge info (goal and knowledge) for the retrieval model
# match_kn_gene task: 1) use knowledge info (goal and knowledge) for the retrieval model;
#                     2) generalize target_a/target_b of the goal by replacing them with slot marks
# for more information about the generalization in match_kn_gene,
# refer to ./tools/convert_conversation_corpus_to_model_text.py
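# example invocation (the script file name here is an assumption):
#     sh run_train.sh match_kn_gene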
TASK_NAME=$1
if [ "$TASK_NAME" = "match" ]
then
DICT_NAME="./dict/char.dict"
USE_KNOWLEDGE=0
TOPIC_GENERALIZATION=0
elif [ "$TASK_NAME" = "match_kn" ]
then
DICT_NAME="./dict/char.dict"
USE_KNOWLEDGE=1
TOPIC_GENERALIZATION=0
elif [ "$TASK_NAME" = "match_kn_gene" ]
then
DICT_NAME="./dict/gene.dict"
USE_KNOWLEDGE=1
TOPIC_GENERALIZATION=1
else
echo "task name error, should be match|match_kn|match_kn_gene"
fi
# in train stage, FOR_PREDICT=0
FOR_PREDICT=0
# put all data sets used and generated for training under this folder: INPUT_PATH
# for more details, please refer to the following data processing instructions
INPUT_PATH="./data"
# put the model files saved at each stage under this folder: OUTPUT_PATH
OUTPUT_PATH="./models"
# set python path according to your actual environment
PYTHON_PATH="python"
# in the train stage, use "train.txt" to train the model and "dev.txt" to evaluate it
# the "train.txt" and "dev.txt" are the original data of DuConv and
# need to be placed in this folder: INPUT_PATH/resource/
# the following preprocessing will generate the actual data needed for model training
# DATA_TYPE = "train" or "dev"
DATA_TYPE=("train" "dev")
# candidate set
candidate_set_file=${INPUT_PATH}/candidate_set.txt
# data preprocessing
for ((i=0; i<${#DATA_TYPE[*]}; i++))
do
# ensure that each file is in the correct path
# 1. put the data of DuConv under this folder: INPUT_PATH/resource/
# - the data provided consists of three parts: train.txt dev.txt test.txt
# - the train.txt and dev.txt are session data, the test.txt is sample data
# - in train stage, we just use the train.txt and dev.txt
# 2. the sample data extracted from session data is in this folder: INPUT_PATH/resource/
# 3. the candidate data constructed from sample data is in this folder: INPUT_PATH/resource/
# 4. the text file required by the model is in this folder: INPUT_PATH
corpus_file=${INPUT_PATH}/resource/${DATA_TYPE[$i]}.txt
sample_file=${INPUT_PATH}/resource/sample.${DATA_TYPE[$i]}.txt
candidate_file=${INPUT_PATH}/resource/candidate.${DATA_TYPE[$i]}.txt
text_file=${INPUT_PATH}/${DATA_TYPE[$i]}.txt
# step 1: build candidate set from session data for negative training cases and predicting candidates
if [ "${DATA_TYPE[$i]}"x = "train"x ]; then
${PYTHON_PATH} ./tools/build_candidate_set_from_corpus.py ${corpus_file} ${candidate_set_file}
fi
# step 2: convert the session data to sample data
${PYTHON_PATH} ./tools/convert_session_to_sample.py ${corpus_file} ${sample_file}
# step 3: construct candidate for sample data
${PYTHON_PATH} ./tools/construct_candidate.py ${sample_file} ${candidate_set_file} ${candidate_file} 9
# step 4: convert sample data with candidates to text data required by the model
${PYTHON_PATH} ./tools/convert_conversation_corpus_to_model_text.py ${candidate_file} ${text_file} ${USE_KNOWLEDGE} ${TOPIC_GENERALIZATION} ${FOR_PREDICT}
# step 5: build dict from the training data, here we build character dict for model
if [ "${DATA_TYPE[$i]}"x = "train"x ]; then
${PYTHON_PATH} ./tools/build_dict.py ${text_file} ${DICT_NAME}
fi
done
# step 6: train the model; you can find the model files in OUTPUT_PATH after training
$PYTHON_PATH -u train.py --task_name ${TASK_NAME} \
--use_cuda \
--batch_size 128 \
--data_dir ${INPUT_PATH} \
--vocab_path ${DICT_NAME} \
--checkpoints ${OUTPUT_PATH} \
--save_steps 1000 \
--weight_decay 0.01 \
--warmup_proportion 0.1 \
--validation_steps 1000000 \
--skip_steps 100 \
--learning_rate 0.1 \
--epoch 30 \
--max_seq_len 256
#!/usr/bin/env python
# -*- coding: utf-8 -*-
######################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################
"""
File: __init__.py
"""
#!/usr/bin/env python
# -*- coding: utf-8 -*-
######################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################
"""
File: __init__.py
"""
#!/usr/bin/env python
# -*- coding: utf-8 -*-
######################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################
"""
File: transformer.py
"""
from functools import partial
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
def multi_head_attention(queries,
keys,
values,
attn_bias,
d_key,
d_value,
d_model,
n_head=1,
dropout_rate=0.,
cache=None,
name='multi_head_att'):
"""
    Multi-Head Attention. Note that attn_bias is added to the logits before
    computing the softmax activation, masking selected positions so that
    they are not considered in the attention weights.
"""
keys = queries if keys is None else keys
values = keys if values is None else values
if not (len(queries.shape) == len(keys.shape) == len(values.shape) == 3):
raise ValueError(
"Inputs: quries, keys and values should all be 3-D tensors.")
def __compute_qkv(queries, keys, values, n_head, d_key, d_value):
"""
Add linear projection to queries, keys, and values.
"""
q = layers.fc(input=queries,
size=d_key * n_head,
num_flatten_dims=2,
param_attr=name + '_query_fc.w_0',
bias_attr=name + '_query_fc.b_0')
k = layers.fc(input=keys,
size=d_key * n_head,
num_flatten_dims=2,
param_attr=name + '_key_fc.w_0',
bias_attr=name + '_key_fc.b_0')
v = layers.fc(input=values,
size=d_value * n_head,
num_flatten_dims=2,
param_attr=name + '_value_fc.w_0',
bias_attr=name + '_value_fc.b_0')
return q, k, v
def __split_heads(x, n_head):
"""
        Reshape the last dimension of the input tensor x so that it becomes two
        dimensions and then transpose. Specifically, the input is a tensor with
        shape [bs, max_sequence_length, n_head * hidden_dim] and the output is
        a tensor with shape [bs, n_head, max_sequence_length, hidden_dim].
"""
hidden_size = x.shape[-1]
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
reshaped = layers.reshape(
x=x, shape=[0, 0, n_head, hidden_size // n_head], inplace=True)
        # permute the dimensions into:
# [batch_size, n_head, max_sequence_len, hidden_size_per_head]
return layers.transpose(x=reshaped, perm=[0, 2, 1, 3])
def __combine_heads(x):
"""
        Transpose and then reshape the last two dimensions of the input tensor x
        so that it becomes one dimension, which is the reverse of __split_heads.
"""
if len(x.shape) == 3: return x
if len(x.shape) != 4:
raise ValueError("Input(x) should be a 4-D Tensor.")
trans_x = layers.transpose(x, perm=[0, 2, 1, 3])
# The value 0 in shape attr means copying the corresponding dimension
# size of the input as the output dimension size.
return layers.reshape(
x=trans_x,
shape=[0, 0, trans_x.shape[2] * trans_x.shape[3]],
inplace=True)
def scaled_dot_product_attention(q, k, v, attn_bias, d_key, dropout_rate):
"""
Scaled Dot-Product Attention
"""
scaled_q = layers.scale(x=q, scale=d_key ** -0.5)
product = layers.matmul(x=scaled_q, y=k, transpose_y=True)
if attn_bias:
product += attn_bias
weights = layers.softmax(product)
if dropout_rate:
weights = layers.dropout(
weights,
dropout_prob=dropout_rate,
dropout_implementation="upscale_in_train",
is_test=False)
out = layers.matmul(weights, v)
return out
q, k, v = __compute_qkv(queries, keys, values, n_head, d_key, d_value)
if cache is not None: # use cache and concat time steps
# Since the inplace reshape in __split_heads changes the shape of k and
# v, which is the cache input for next time step, reshape the cache
# input from the previous time step first.
k = cache["k"] = layers.concat(
[layers.reshape(
cache["k"], shape=[0, 0, d_model]), k], axis=1)
v = cache["v"] = layers.concat(
[layers.reshape(
cache["v"], shape=[0, 0, d_model]), v], axis=1)
q = __split_heads(q, n_head)
k = __split_heads(k, n_head)
v = __split_heads(v, n_head)
ctx_multiheads = scaled_dot_product_attention(q, k, v, attn_bias, d_key,
dropout_rate)
out = __combine_heads(ctx_multiheads)
# Project back to the model size.
proj_out = layers.fc(input=out,
size=d_model,
num_flatten_dims=2,
param_attr=name + '_output_fc.w_0',
bias_attr=name + '_output_fc.b_0')
return proj_out
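# A minimal numpy sketch (illustration only, not part of the model) of how the
# additive attention bias masks positions: adding a large negative bias to the
# logits before softmax drives the corresponding weights to ~0, so masked or
# padded positions receive no attention.
#
#     import numpy as np
#     logits = np.array([1.0, 2.0, 3.0])
#     bias = np.array([0.0, 0.0, -1e9])      # third position is masked
#     w = np.exp(logits + bias)
#     w /= w.sum()                           # -> approx [0.27, 0.73, 0.00]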
def positionwise_feed_forward(x,
d_inner_hid,
d_hid,
dropout_rate,
hidden_act,
name='ffn'):
"""
Position-wise Feed-Forward Networks.
This module consists of two linear transformations with a ReLU activation
in between, which is applied to each position separately and identically.
"""
hidden = layers.fc(input=x,
size=d_inner_hid,
num_flatten_dims=2,
act=hidden_act,
param_attr=name + '_fc_0.w_0',
bias_attr=name + '_fc_0.b_0')
if dropout_rate:
hidden = layers.dropout(
hidden,
dropout_prob=dropout_rate,
dropout_implementation="upscale_in_train",
is_test=False)
out = layers.fc(input=hidden,
size=d_hid,
num_flatten_dims=2,
param_attr=name + '_fc_1.w_0',
bias_attr=name + '_fc_1.b_0')
return out
def pre_post_process_layer(prev_out, out, process_cmd, dropout_rate=0.,
name=''):
"""
    Add residual connection, layer normalization and dropout to the out tensor
optionally according to the value of process_cmd.
This will be used before or after multi-head attention and position-wise
feed-forward networks.
"""
for cmd in process_cmd:
if cmd == "a": # add residual connection
out = out + prev_out if prev_out else out
elif cmd == "n": # add layer normalization
out = layers.layer_norm(
out,
begin_norm_axis=len(out.shape) - 1,
param_attr=fluid.ParamAttr(
name=name + '_layer_norm_scale',
initializer=fluid.initializer.Constant(1.)),
bias_attr=fluid.ParamAttr(
name=name + '_layer_norm_bias',
initializer=fluid.initializer.Constant(0.)))
elif cmd == "d": # add dropout
if dropout_rate:
out = layers.dropout(
out,
dropout_prob=dropout_rate,
dropout_implementation="upscale_in_train",
is_test=False)
return out
pre_process_layer = partial(pre_post_process_layer, None)
post_process_layer = pre_post_process_layer
def encoder_layer(enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
hidden_act,
preprocess_cmd="n",
postprocess_cmd="da",
name=''):
"""The encoder layers that can be stacked to form a deep encoder.
    This module consists of a multi-head (self) attention sublayer followed by
    position-wise feed-forward networks; both components are wrapped with
    post_process_layer to add residual connection, layer normalization
    and dropout.
"""
attn_output = multi_head_attention(
pre_process_layer(
enc_input,
preprocess_cmd,
prepostprocess_dropout,
name=name + '_pre_att'),
None,
None,
attn_bias,
d_key,
d_value,
d_model,
n_head,
attention_dropout,
name=name + '_multi_head_att')
attn_output = post_process_layer(
enc_input,
attn_output,
postprocess_cmd,
prepostprocess_dropout,
name=name + '_post_att')
ffd_output = positionwise_feed_forward(
pre_process_layer(
attn_output,
preprocess_cmd,
prepostprocess_dropout,
name=name + '_pre_ffn'),
d_inner_hid,
d_model,
relu_dropout,
hidden_act,
name=name + '_ffn')
return post_process_layer(
attn_output,
ffd_output,
postprocess_cmd,
prepostprocess_dropout,
name=name + '_post_ffn')
def encoder(enc_input,
attn_bias,
n_layer,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
hidden_act,
preprocess_cmd="n",
postprocess_cmd="da",
name=''):
"""
The encoder is composed of a stack of identical layers returned by calling
encoder_layer.
"""
for i in range(n_layer):
enc_output = encoder_layer(
enc_input,
attn_bias,
n_head,
d_key,
d_value,
d_model,
d_inner_hid,
prepostprocess_dropout,
attention_dropout,
relu_dropout,
hidden_act,
preprocess_cmd,
postprocess_cmd,
name=name + '_layer_' + str(i))
enc_input = enc_output
enc_output = pre_process_layer(
enc_output, preprocess_cmd, prepostprocess_dropout, name="post_encoder")
return enc_output
#!/usr/bin/env python
# -*- coding: utf-8 -*-
######################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################
"""
File: __init__.py
"""
#!/usr/bin/env python
# -*- coding: utf-8 -*-
######################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################
"""
File: data_provider.py
"""
import re
import os
import types
import csv
import random
import numpy as np
VOC_DICT = {}
def load_dict(vocab_dict):
"""
load vocabulary dict
"""
idx = 0
for line in open(vocab_dict):
line = line.strip()
VOC_DICT[line] = idx
idx += 1
return VOC_DICT
def prepare_batch_data(insts,
task_name,
max_len=128,
return_attn_bias=True,
return_max_len=True,
return_num_token=False):
"""
    prepare model inputs for a batch, including the self attention mask [shape: batch_size * max_len * max_len]
"""
batch_context_ids = [inst[0] for inst in insts]
batch_context_pos_ids = [inst[1] for inst in insts]
batch_segment_ids = [inst[2] for inst in insts]
batch_label_ids = [[inst[3]] for inst in insts]
labels_list = batch_label_ids
context_id, next_sent_context_index, context_attn_bias = \
pad_batch_data(batch_context_ids, pad_idx=0, max_len=max_len, \
return_next_sent_pos=True, return_attn_bias=True)
context_pos_id = pad_batch_data(
batch_context_pos_ids, pad_idx=0, max_len=max_len, return_pos=False, return_attn_bias=False)
context_segment_id = pad_batch_data(
batch_segment_ids, pad_idx=0, max_len=max_len, return_pos=False, return_attn_bias=False)
if 'kn' in task_name:
batch_kn_ids = [inst[4] for inst in insts]
        kn_id = pad_batch_kn_data(batch_kn_ids, pad_idx=0, max_len=max_len)
out_list = []
for i in range(len(insts)):
if 'kn' in task_name:
out = [context_id[i], context_pos_id[i], context_segment_id[i], context_attn_bias[i], \
kn_id[i], labels_list[i], next_sent_context_index[i]]
else:
out = [context_id[i], context_pos_id[i], context_segment_id[i], \
context_attn_bias[i], labels_list[i], next_sent_context_index[i]]
out_list.append(out)
return out_list
def pad_batch_kn_data(insts,
                      pad_idx=0,
                      max_len=128):
    """Truncate each knowledge id sequence to max_len; the result is fed as a LoD tensor, so no padding is applied."""
kn_list = []
for inst in insts:
inst = inst[0: min(max_len, len(inst))]
kn_list.append(inst)
return kn_list
def pad_batch_data(insts,
pad_idx=0,
max_len=128,
return_pos=False,
return_next_sent_pos=False,
return_attn_bias=False,
return_max_len=False,
return_num_token=False):
"""
Pad the instances to the max sequence length in batch, and generate the
corresponding position data and attention bias.
"""
return_list = []
inst_data = np.array(
[inst + list([pad_idx] * (max_len - len(inst))) for inst in insts])
return_list += [inst_data.astype("int64").reshape([-1, max_len, 1])]
if return_next_sent_pos:
batch_size = inst_data.shape[0]
max_seq_len = inst_data.shape[1]
next_sent_index = np.array(
range(0, batch_size * max_seq_len, max_seq_len)).astype(
"int64").reshape(-1, 1)
return_list += [next_sent_index]
if return_pos:
inst_pos = np.array([
list(range(0, len(inst))) + [pad_idx] * (max_len - len(inst))
for inst in insts])
return_list += [inst_pos.astype("int64").reshape([-1, max_len, 1])]
if return_attn_bias:
slf_attn_bias_data = np.array([[0] * len(inst) + [-1e9] *
(max_len - len(inst)) for inst in insts])
slf_attn_bias_data = np.tile(
slf_attn_bias_data.reshape([-1, 1, max_len]), [1, max_len, 1])
return_list += [slf_attn_bias_data.astype("float32")]
if return_max_len:
return_list += [max_len]
if return_num_token:
num_token = 0
for inst in insts:
num_token += len(inst)
return_list += [num_token]
return return_list if len(return_list) > 1 else return_list[0]
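# Example (illustrative only): padding a batch of two id sequences to max_len=4.
#
#     ids = pad_batch_data([[5, 6], [7, 8, 9]], pad_idx=0, max_len=4)
#     # ids.shape == (2, 4, 1); the rows are [5, 6, 0, 0] and [7, 8, 9, 0]
#     # with return_attn_bias=True, a second array of shape (2, 4, 4) is also
#     # returned, holding 0 over real tokens and -1e9 over padded columns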
def preprocessing_for_one_line(line, labels, task_name, max_seq_len=256):
"""
    process one line of tab-separated text into model inputs
"""
line = line.rstrip('\n').split('\t')
label_text = line[0]
context_text = line[1]
response_text = line[2]
if 'kn' in task_name:
kn_text = "%s [SEP] %s" % (line[3], line[4])
else:
kn_text = None
example = InputExample(guid=0, \
context_text=context_text, \
response_text=response_text, \
kn_text=kn_text, \
label_text=label_text)
feature = convert_single_example(0, example, labels, max_seq_len)
instance = [feature.context_ids, feature.context_pos_ids, \
feature.segment_ids, feature.label_ids, feature.kn_ids]
batch_data = prepare_batch_data([instance],
task_name,
max_len=max_seq_len,
return_attn_bias=True,
return_max_len=False,
return_num_token=False)
return batch_data
class DataProcessor(object):
"""Base class for data converters for sequence classification data sets."""
def __init__(self, data_dir, task_name, vocab_path, max_seq_len, do_lower_case):
self.data_dir = data_dir
self.max_seq_len = max_seq_len
self.task_name = task_name
self.current_train_example = -1
self.num_examples = {'train': -1, 'dev': -1, 'test': -1}
self.current_train_epoch = -1
        load_dict(vocab_path)  # populates the module-level VOC_DICT
def get_train_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the train set."""
raise NotImplementedError()
def get_dev_examples(self, data_dir):
"""Gets a collection of `InputExample`s for the dev set."""
raise NotImplementedError()
def get_test_examples(self, data_dir):
"""Gets a collection of `InputExample`s for prediction."""
raise NotImplementedError()
@classmethod
    def get_labels(cls):
"""Gets the list of labels for this data set."""
raise NotImplementedError()
def convert_example(self, index, example, labels, max_seq_len):
"""Converts a single `InputExample` into a single `InputFeatures`."""
feature = convert_single_example(index, example, labels, max_seq_len)
return feature
def generate_batch_data(self,
batch_data,
voc_size=-1,
mask_id=-1,
return_attn_bias=True,
return_max_len=False,
return_num_token=False):
return prepare_batch_data(
batch_data,
self.task_name,
self.max_seq_len,
return_attn_bias=True,
return_max_len=False,
return_num_token=False)
@classmethod
def _read_data(cls, input_file):
"""Reads a tab separated value file."""
with open(input_file, "r") as f:
lines = []
for line in f:
line = line.rstrip('\n').split('\t')
lines.append(line)
return lines
def get_num_examples(self, phase):
"""Get number of examples for train, dev or test."""
if phase not in ['train', 'dev', 'test']:
raise ValueError("Unknown phase, which should be in ['train', 'dev', 'test'].")
return self.num_examples[phase]
def get_train_progress(self):
"""Gets progress for training phase."""
return self.current_train_example, self.current_train_epoch
def data_generator(self,
batch_size,
phase='train',
epoch=1,
shuffle=False):
"""
Generate data for train, dev or test.
"""
if phase == 'train':
examples = self.get_train_examples(self.data_dir)
self.num_examples['train'] = len(examples)
elif phase == 'dev':
examples = self.get_dev_examples(self.data_dir)
self.num_examples['dev'] = len(examples)
elif phase == 'test':
examples = self.get_test_examples(self.data_dir)
self.num_examples['test'] = len(examples)
else:
raise ValueError("Unknown phase, which should be in ['train', 'dev', 'test'].")
def instance_reader():
for epoch_index in range(epoch):
if shuffle:
random.shuffle(examples)
if phase == 'train':
self.current_train_epoch = epoch_index
for (index, example) in enumerate(examples):
if phase == 'train':
self.current_train_example = index + 1
feature = self.convert_example(
index, example, self.get_labels(), self.max_seq_len)
if 'kn' in self.task_name:
instance = [feature.context_ids, feature.context_pos_ids, \
feature.segment_ids, feature.label_ids, feature.kn_ids]
else:
instance = [feature.context_ids, feature.context_pos_ids, \
feature.segment_ids, feature.label_ids]
yield instance
def batch_reader(reader, batch_size):
batch = []
for instance in reader():
if len(batch) < batch_size:
batch.append(instance)
else:
yield batch
batch = [instance]
if len(batch) > 0:
yield batch
def wrapper():
for batch_data in batch_reader(instance_reader, batch_size):
batch_data = self.generate_batch_data(
batch_data,
voc_size=-1,
mask_id=-1,
return_attn_bias=True,
return_max_len=False,
return_num_token=False)
yield batch_data
return wrapper
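# Typical usage of DataProcessor.data_generator (sketch): it returns a generator
# factory; calling the returned wrapper yields padded batches from prepare_batch_data.
#
#     generator = processor.data_generator(batch_size=32, phase='dev', epoch=1)
#     for batch in generator():
#         ...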
class InputExample(object):
"""A single training/test example"""
def __init__(self, guid, context_text, response_text, kn_text, label_text):
self.guid = guid
self.context_text = context_text
self.response_text = response_text
self.kn_text = kn_text
self.label_text = label_text
class InputFeatures(object):
"""input features datas"""
def __init__(self, context_ids, context_pos_ids, segment_ids, kn_ids, label_ids):
self.context_ids = context_ids
self.context_pos_ids = context_pos_ids
self.segment_ids = segment_ids
self.kn_ids = kn_ids
self.label_ids = label_ids
class MatchProcessor(DataProcessor):
"""Processor for the Match data set (GLUE version)."""
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_data(os.path.join(data_dir, "train.txt")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_data(os.path.join(data_dir, "dev.txt")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_data(os.path.join(data_dir, "test.txt")), "test")
@classmethod
    def get_labels(cls):
"""See base class."""
return ["0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
guid = "%s-%s" % (set_type, i)
context_text = line[1]
label_text = line[0]
response_text = line[2]
if 'kn' in self.task_name:
kn_text = "%s [SEP] %s" % (line[3], line[4])
else:
kn_text = None
examples.append(
InputExample(
guid=guid, context_text=context_text, response_text=response_text, \
kn_text=kn_text, label_text=label_text))
return examples
def convert_tokens_to_ids(tokens):
"""
    convert tokens to vocabulary ids; out-of-vocabulary tokens map to [UNK]
"""
ids = []
for token in tokens:
if token in VOC_DICT:
ids.append(VOC_DICT[token])
else:
ids.append(VOC_DICT['[UNK]'])
return ids
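# Example: with VOC_DICT = {'[UNK]': 0, '你': 1, '好': 2} (hypothetical ids),
# convert_tokens_to_ids(['你', '好', '呀']) returns [1, 2, 0]: the OOV token
# '呀' falls back to [UNK].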
def convert_single_example(ex_index, example, label_list, max_seq_length):
"""Converts a single `InputExample` into a single `InputFeatures`."""
label_map = {}
for (i, label) in enumerate(label_list):
label_map[label] = i
if example.context_text:
tokens_context = example.context_text
tokens_context = tokens_context.split()
else:
tokens_context = []
if example.response_text:
tokens_response = example.response_text
tokens_response = tokens_response.split()
else:
tokens_response = []
if example.kn_text:
tokens_kn = example.kn_text
tokens_kn = tokens_kn.split()
tokens_kn = tokens_kn[0: min(len(tokens_kn), max_seq_length)]
else:
tokens_kn = []
tokens_response = tokens_response[0: min(50, len(tokens_response))]
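    # the response is capped at 50 tokens above; if the concatenation still
    # overflows, keep only the most recent context tokens so that
    # "[CLS] context [SEP] response [SEP]" fits within max_seq_length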
if len(tokens_context) > max_seq_length - len(tokens_response) - 3:
tokens_context = tokens_context[len(tokens_context) \
+ len(tokens_response) - max_seq_length + 3:]
context_tokens = []
segment_ids = []
context_tokens.append("[CLS]")
segment_ids.append(0)
context_tokens.extend(tokens_context)
segment_ids.extend([0] * len(tokens_context))
context_tokens.append("[SEP]")
segment_ids.append(0)
context_tokens.extend(tokens_response)
segment_ids.extend([1] * len(tokens_response))
context_tokens.append("[SEP]")
segment_ids.append(1)
context_ids = convert_tokens_to_ids(context_tokens)
context_pos_ids = list(range(len(context_ids)))
label_ids = label_map[example.label_text]
if tokens_kn:
kn_ids = convert_tokens_to_ids(tokens_kn)
else:
kn_ids = []
feature = InputFeatures(
context_ids=context_ids,
context_pos_ids=context_pos_ids,
segment_ids=segment_ids,
kn_ids = kn_ids,
label_ids=label_ids)
#if ex_index < 5:
# print("*** Example ***")
# print("guid: %s" % (example.guid))
# print("context tokens: %s" % " ".join(context_tokens))
# print("context_ids: %s" % " ".join([str(x) for x in context_ids]))
# print("context_pos_ids: %s" % " ".join([str(x) for x in context_pos_ids]))
# print("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
# print("kn_ids: %s" % " ".join([str(x) for x in kn_ids]))
# print("label: %s (id = %d)" % (example.label_text, label_ids))
return feature
#!/usr/bin/env python
# -*- coding: utf-8 -*-
######################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################
"""
File: __init__.py
"""
#!/usr/bin/env python
# -*- coding: utf-8 -*-
######################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################
"""
File: retrieval_model.py
"""
import six
import json
import numpy as np
import paddle.fluid as fluid
from source.encoders.transformer import encoder, pre_process_layer
class RetrievalModel(object):
    """Transformer-encoder retrieval model with an optional memory network over the knowledge."""
def __init__(self,
context_ids,
context_pos_ids,
context_segment_ids,
context_attn_mask,
kn_ids,
emb_size=1024,
n_layer=12,
n_head=1,
voc_size=10005,
max_position_seq_len=512,
sent_types=2,
hidden_act='relu',
prepostprocess_dropout=0.1,
attention_dropout=0.1,
weight_sharing=True):
self._emb_size = emb_size
self._n_layer = n_layer
self._n_head = n_head
self._voc_size = voc_size
self._sent_types = sent_types
self._max_position_seq_len = max_position_seq_len
self._hidden_act = hidden_act
self._weight_sharing = weight_sharing
self._prepostprocess_dropout = prepostprocess_dropout
self._attention_dropout = attention_dropout
self._context_emb_name = "context_word_embedding"
self._memory_emb_name = "memory_word_embedding"
self._context_pos_emb_name = "context_pos_embedding"
self._context_segment_emb_name = "context_segment_embedding"
if kn_ids:
self._memory_emb_name = "memory_word_embedding"
self._build_model(context_ids, context_pos_ids, \
context_segment_ids, context_attn_mask, kn_ids)
def _build_memory_network(self, kn_ids, rnn_hidden_size=128):
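        """
        Encode the knowledge ids with a bidirectional GRU memory network.
        Returns the concatenated forward/backward GRU states and a 256-d
        projection of them that is later used to score attention.
        """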
kn_emb_out = fluid.layers.embedding(
input=kn_ids,
size=[self._voc_size, self._emb_size],
dtype='float32')
para_attr = fluid.ParamAttr(initializer=fluid.initializer.Normal(0.0, 0.02))
bias_attr = fluid.ParamAttr(
initializer=fluid.initializer.Normal(0.0, 0.02))
fc_fw = fluid.layers.fc(input=kn_emb_out,
size=rnn_hidden_size * 3,
param_attr=para_attr,
bias_attr=False)
fc_bw = fluid.layers.fc(input=kn_emb_out,
size=rnn_hidden_size * 3,
param_attr=para_attr,
bias_attr=False)
gru_forward = fluid.layers.dynamic_gru(
input=fc_fw,
size=rnn_hidden_size,
param_attr=para_attr,
bias_attr=bias_attr,
candidate_activation='relu')
gru_backward = fluid.layers.dynamic_gru(
input=fc_bw,
size=rnn_hidden_size,
is_reverse=True,
param_attr=para_attr,
bias_attr=bias_attr,
candidate_activation='relu')
memory_encoder_out = fluid.layers.concat(
input=[gru_forward, gru_backward], axis=1)
memory_encoder_proj_out = fluid.layers.fc(input=memory_encoder_out,
size=256,
bias_attr=False)
return memory_encoder_out, memory_encoder_proj_out
def _build_model(self,
context_ids,
context_pos_ids,
context_segment_ids,
context_attn_mask,
kn_ids):
context_emb_out = fluid.layers.embedding(
input=context_ids,
size=[self._voc_size, self._emb_size],
param_attr=fluid.ParamAttr(name=self._context_emb_name),
is_sparse=False)
context_position_emb_out = fluid.layers.embedding(
input=context_pos_ids,
size=[self._max_position_seq_len, self._emb_size],
param_attr=fluid.ParamAttr(name=self._context_pos_emb_name), )
context_segment_emb_out = fluid.layers.embedding(
input=context_segment_ids,
size=[self._sent_types, self._emb_size],
param_attr=fluid.ParamAttr(name=self._context_segment_emb_name), )
context_emb_out = context_emb_out + context_position_emb_out
context_emb_out = context_emb_out + context_segment_emb_out
context_emb_out = pre_process_layer(
context_emb_out, 'nd', self._prepostprocess_dropout, name='context_pre_encoder')
n_head_context_attn_mask = fluid.layers.stack(
x=[context_attn_mask] * self._n_head, axis=1)
n_head_context_attn_mask.stop_gradient = True
self._context_enc_out = encoder(
enc_input=context_emb_out,
attn_bias=n_head_context_attn_mask,
n_layer=self._n_layer,
n_head=self._n_head,
d_key=self._emb_size // self._n_head,
d_value=self._emb_size // self._n_head,
d_model=self._emb_size,
d_inner_hid=self._emb_size * 4,
prepostprocess_dropout=self._prepostprocess_dropout,
attention_dropout=self._attention_dropout,
relu_dropout=0,
hidden_act=self._hidden_act,
preprocess_cmd="an",
postprocess_cmd="dan",
name='context_encoder')
if kn_ids:
self.memory_encoder_out, self.memory_encoder_proj_out = \
self._build_memory_network(kn_ids)
def get_context_output(self, context_next_sent_index, task_name):
if "kn" in task_name:
cls_feats = self.get_context_response_memory(context_next_sent_index)
else:
cls_feats = self.get_pooled_output(context_next_sent_index)
return cls_feats
def get_context_response_memory(self, context_next_sent_index):
context_out = self.get_pooled_output(context_next_sent_index)
kn_context = self.attention(context_out, \
self.memory_encoder_out, self.memory_encoder_proj_out)
cls_feats = fluid.layers.concat(input=[context_out, kn_context], axis=1)
return cls_feats
def attention(self, hidden_mem, encoder_vec, encoder_vec_proj):
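        """
        Additive attention of the pooled context over the knowledge memory:
        expand the context vector to every memory step, score each step with a
        single-unit fc followed by a sequence softmax, and sum-pool the
        weighted memory states into one context vector.
        """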
concated = fluid.layers.sequence_expand(
x=hidden_mem, y=encoder_vec_proj)
concated = encoder_vec_proj + concated
concated = fluid.layers.tanh(x=concated)
attention_weights = fluid.layers.fc(input=concated,
size=1,
act=None,
bias_attr=False)
attention_weights = fluid.layers.sequence_softmax(
input=attention_weights)
        weights_reshape = fluid.layers.reshape(x=attention_weights, shape=[-1])
        scaled = fluid.layers.elementwise_mul(
            x=encoder_vec, y=weights_reshape, axis=0)
context = fluid.layers.sequence_pool(input=scaled, pool_type='sum')
return context
def get_sequence_output(self):
return (self._context_enc_out, self._response_enc_out)
def get_pooled_output(self, context_next_sent_index):
context_out = self.get_pooled(context_next_sent_index)
return context_out
def get_pooled(self, next_sent_index):
"""Get the first feature of each sequence for classification"""
reshaped_emb_out = fluid.layers.reshape(
x=self._context_enc_out, shape=[-1, self._emb_size], inplace=True)
next_sent_index = fluid.layers.cast(x=next_sent_index, dtype='int32')
next_sent_feat = fluid.layers.gather(
input=reshaped_emb_out, index=next_sent_index)
next_sent_feat = fluid.layers.fc(
input=next_sent_feat,
size=self._emb_size,
act="tanh",
param_attr=fluid.ParamAttr(
name="pooled_fc.w_0",
initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
bias_attr="pooled_fc.b_0")
return next_sent_feat
def get_pooled_output_no_share(self, context_next_sent_index, response_next_sent_index):
"""get pooled embedding"""
self._context_reshaped_emb_out = fluid.layers.reshape(
x=self._context_enc_out, shape=[-1, self._emb_size], inplace=True)
context_next_sent_index = fluid.layers.cast(x=context_next_sent_index, dtype='int32')
context_out = fluid.layers.gather(
input=self._context_reshaped_emb_out, index=context_next_sent_index)
context_out = fluid.layers.fc(
input=context_out,
size=self._emb_size,
act="tanh",
param_attr=fluid.ParamAttr(
name="pooled_context_fc.w_0",
initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
bias_attr="pooled_context_fc.b_0")
self._response_reshaped_emb_out = fluid.layers.reshape(
x=self._response_enc_out, shape=[-1, self._emb_size], inplace=True)
response_next_sent_index = fluid.layers.cast(x=response_next_sent_index, dtype='int32')
response_next_sent_feat = fluid.layers.gather(
input=self._response_reshaped_emb_out, index=response_next_sent_index)
response_next_sent_feat = fluid.layers.fc(
input=response_next_sent_feat,
size=self._emb_size,
act="tanh",
param_attr=fluid.ParamAttr(
name="pooled_response_fc.w_0",
initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
bias_attr="pooled_response_fc.b_0")
return context_out, response_next_sent_feat
#!/usr/bin/env python
# -*- coding: utf-8 -*-
######################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################
"""
File: __init__.py
"""
#!/usr/bin/env python
# -*- coding: utf-8 -*-
######################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################
"""
File: utils.py
"""
from __future__ import print_function
import os
import six
import ast
import copy
import numpy as np
import paddle.fluid as fluid
def init_checkpoint(exe, init_checkpoint_path, main_program):
assert os.path.exists(
init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path
fluid.io.load_persistables(
exe, init_checkpoint_path, main_program=main_program)
print("Load model from {}".format(init_checkpoint_path))
def init_pretraining_params(exe, pretraining_params_path, main_program):
assert os.path.exists(pretraining_params_path
), "[%s] cann't be found." % pretraining_params_path
def existed_params(var):
if not isinstance(var, fluid.framework.Parameter):
return False
return os.path.exists(os.path.join(pretraining_params_path, var.name))
fluid.io.load_vars(
exe,
pretraining_params_path,
main_program=main_program,
predicate=existed_params)
print("Load pretraining parameters from {}".format(pretraining_params_path))
#!/usr/bin/env python
# -*- coding: utf-8 -*-
######################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################
"""
File: __init__.py
"""
#!/usr/bin/env python
# -*- coding: utf-8 -*-
######################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################
"""
File: build_candidate_set_from_corpus.py
"""
from __future__ import print_function
import sys
import json
import random
import collections
reload(sys)
sys.setdefaultencoding('utf8')
def build_candidate_set_from_corpus(corpus_file, candidate_set_file):
"""
build candidate set from corpus
"""
candidate_set_gener = {}
candidate_set_mater = {}
candidate_set_list = []
slot_dict = {"topic_a": 1, "topic_b": 1}
with open(corpus_file, 'r') as f:
for i, line in enumerate(f):
conversation = json.loads(line.strip(), encoding="utf-8", \
object_pairs_hook=collections.OrderedDict)
chat_path = conversation["goal"]
knowledge = conversation["knowledge"]
session = conversation["conversation"]
topic_a = chat_path[0][1]
topic_b = chat_path[0][2]
domain_a = None
domain_b = None
cover_att_list = [[["topic_a", topic_a], ["topic_b", topic_b]]] * len(session)
for j, [s, p, o] in enumerate(knowledge):
p_key = ""
if topic_a.replace(' ', '') == s.replace(' ', ''):
p_key = "topic_a_" + p.replace(' ', '')
if u"领域" == p:
domain_a = o
elif topic_b.replace(' ', '') == s.replace(' ', ''):
p_key = "topic_b_" + p.replace(' ', '')
if u"领域" == p:
domain_b = o
for k, utterance in enumerate(session):
if k % 2 == 1: continue
if o in utterance and o != topic_a and o != topic_b and p_key != "":
cover_att_list[k].append([p_key, o])
slot_dict[p_key] = 1
assert domain_a is not None and domain_b is not None
for j, utterance in enumerate(session):
if j % 2 == 1: continue
key = '_'.join([domain_a, domain_b, str(j)])
cover_att = sorted(cover_att_list[j], lambda x, y: cmp(len(x[1]), len(y[1])), reverse=True)
utterance_gener = utterance
for [p_key, o] in cover_att:
utterance_gener = utterance_gener.replace(o, p_key)
if "topic_a_topic_a_" not in utterance_gener and \
"topic_a_topic_b_" not in utterance_gener and \
"topic_b_topic_a_" not in utterance_gener and \
"topic_b_topic_b_" not in utterance_gener:
if key in candidate_set_gener:
candidate_set_gener[key].append(utterance_gener)
else:
candidate_set_gener[key] = [utterance_gener]
utterance_mater = utterance
for [p_key, o] in [["topic_a", topic_a], ["topic_b", topic_b]]:
utterance_mater = utterance_mater.replace(o, p_key)
if key in candidate_set_mater:
candidate_set_mater[key].append(utterance_mater)
else:
candidate_set_mater[key] = [utterance_mater]
candidate_set_list.append(utterance_mater)
fout = open(candidate_set_file, 'w')
fout.write(json.dumps(candidate_set_gener, ensure_ascii=False, encoding="utf-8") + "\n")
fout.write(json.dumps(candidate_set_mater, ensure_ascii=False, encoding="utf-8") + "\n")
fout.write(json.dumps(candidate_set_list, ensure_ascii=False, encoding="utf-8") + "\n")
fout.write(json.dumps(slot_dict, ensure_ascii=False, encoding="utf-8"))
fout.close()
def main():
"""
main
"""
build_candidate_set_from_corpus(sys.argv[1], sys.argv[2])
if __name__ == '__main__':
try:
main()
except KeyboardInterrupt:
print("\nExited from the program ealier!")
#!/usr/bin/env python
# -*- coding: utf-8 -*-
######################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################
"""
File: build_dict.py
"""
from __future__ import print_function
import sys
reload(sys)
sys.setdefaultencoding('utf8')
def build_dict(corpus_file, dict_file):
"""
build words dict
"""
dict = {}
max_frequency = 1
for line in open(corpus_file, 'r'):
conversation = line.strip().split('\t')
for i in range(1, len(conversation), 1):
words = conversation[i].split(' ')
for word in words:
if word in dict:
dict[word] = dict[word] + 1
if dict[word] > max_frequency:
max_frequency = dict[word]
else:
dict[word] = 1
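    # give the special tokens the largest pseudo-frequencies so that, after the
    # frequency sort below, they occupy the first four ids:
    # [PAD]=0, [UNK]=1, [CLS]=2, [SEP]=3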
dict["[PAD]"] = max_frequency + 4
dict["[UNK]"] = max_frequency + 3
dict["[CLS]"] = max_frequency + 2
dict["[SEP]"] = max_frequency + 1
words = sorted(dict.items(), lambda x, y: cmp(x[1], y[1]), reverse=True)
fout = open(dict_file, 'w')
for word, frequency in words:
fout.write(word + '\n')
fout.close()
def main():
"""
main
"""
if len(sys.argv) < 3:
print("Usage: " + sys.argv[0] + " corpus_file dict_file")
exit()
build_dict(sys.argv[1], sys.argv[2])
if __name__ == '__main__':
try:
main()
except KeyboardInterrupt:
print("\nExited from the program ealier!")
#!/usr/bin/env python
# -*- coding: utf-8 -*-
######################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################
"""
File: construct_candidate.py
"""
from __future__ import print_function
import sys
import json
import random
import collections
reload(sys)
sys.setdefaultencoding('utf8')
def load_candidate_set(candidate_set_file):
"""
load candidate set
"""
candidate_set = []
for line in open(candidate_set_file):
candidate_set.append(json.loads(line.strip(), encoding="utf-8"))
return candidate_set
def candidate_selection(candidate_set, knowledge_dict, slot_dict, candidate_num=10):
    """
    candidate selection: fill slot marks in candidates with knowledge values
    """
random.shuffle(candidate_set)
candidate_legal = []
for candidate in candidate_set:
is_legal = True
for slot in slot_dict:
if slot in ["topic_a", "topic_b"]:
continue
if slot in candidate:
if slot not in knowledge_dict:
is_legal = False
break
w_ = random.choice(knowledge_dict[slot])
candidate = candidate.replace(slot, w_)
for slot in ["topic_a", "topic_b"]:
if slot in candidate:
if slot not in knowledge_dict:
is_legal = False
break
w_ = random.choice(knowledge_dict[slot])
candidate = candidate.replace(slot, w_)
if is_legal and candidate not in candidate_legal:
candidate_legal.append(candidate)
if len(candidate_legal) >= candidate_num:
break
return candidate_legal
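# Example (hypothetical values): given the generalized candidate
# "我 说 的 是 topic_a 。" and knowledge_dict = {"topic_a": ["休 · 劳瑞"]},
# the slot is filled to produce "我 说 的 是 休 · 劳瑞 。".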
def get_candidate_for_conversation(conversation, candidate_set, candidate_num=10):
"""
get candidate for conversation
"""
candidate_set_gener, candidate_set_mater, candidate_set_list, slot_dict = candidate_set
chat_path = conversation["goal"]
knowledge = conversation["knowledge"]
history = conversation["history"]
topic_a = chat_path[0][1]
topic_b = chat_path[0][2]
domain_a = None
domain_b = None
knowledge_dict = {"topic_a":[topic_a], "topic_b":[topic_b]}
for i, [s, p, o] in enumerate(knowledge):
p_key = ""
if topic_a.replace(' ', '') == s.replace(' ', ''):
p_key = "topic_a_" + p.replace(' ', '')
if u"领域" == p:
domain_a = o
elif topic_b.replace(' ', '') == s.replace(' ', ''):
p_key = "topic_b_" + p.replace(' ', '')
if u"领域" == p:
domain_b = o
if p_key == "":
continue
if p_key in knowledge_dict:
knowledge_dict[p_key].append(o)
else:
knowledge_dict[p_key] = [o]
assert domain_a is not None and domain_b is not None
key = '_'.join([domain_a, domain_b, str(len(history))])
candidate_legal = []
if key in candidate_set_gener:
        candidate_legal.extend(candidate_selection(candidate_set_gener[key],
knowledge_dict, slot_dict,
candidate_num = candidate_num - len(candidate_legal)))
if len(candidate_legal) < candidate_num and key in candidate_set_mater:
        candidate_legal.extend(candidate_selection(candidate_set_mater[key],
knowledge_dict, slot_dict,
candidate_num = candidate_num - len(candidate_legal)))
if len(candidate_legal) < candidate_num:
        candidate_legal.extend(candidate_selection(candidate_set_list,
knowledge_dict, slot_dict,
candidate_num = candidate_num - len(candidate_legal)))
return candidate_legal
def construct_candidate_for_corpus(corpus_file, candidate_set_file, candidate_file, candidate_num=10):
"""
construct candidate for corpus
    example of data in corpus_file:
{
"goal": [["START", "休 · 劳瑞", "蕾切儿 · 哈伍德"]],
"knowledge": [["休 · 劳瑞", "评论", "完美 的 男人"]],
"history": ["你 对 明星 有没有 到 迷恋 的 程度 呢 ?",
"一般 吧 , 毕竟 年纪 不 小 了 , 只是 追星 而已 。"]
}
    example of data in candidate_file:
{
"goal": [["START", "休 · 劳瑞", "蕾切儿 · 哈伍德"]],
"knowledge": [["休 · 劳瑞", "评论", "完美 的 男人"]],
"history": ["你 对 明星 有没有 到 迷恋 的 程度 呢 ?",
"一般 吧 , 毕竟 年纪 不 小 了 , 只是 追星 而已 。"],
"candidate": ["我 说 的 是 休 · 劳瑞 。",
"我 说 的 是 休 · 劳瑞 。"]
}
"""
candidate_set = load_candidate_set(candidate_set_file)
fout_text = open(candidate_file, 'w')
with open(corpus_file, 'r') as f:
for i, line in enumerate(f):
conversation = json.loads(line.strip(), encoding="utf-8", \
object_pairs_hook=collections.OrderedDict)
candidates = get_candidate_for_conversation(conversation,
candidate_set,
candidate_num=candidate_num)
conversation["candidate"] = candidates
conversation = json.dumps(conversation, ensure_ascii=False, encoding="utf-8")
fout_text.write(conversation + "\n")
fout_text.close()
def main():
"""
main
"""
construct_candidate_for_corpus(sys.argv[1], sys.argv[2], sys.argv[3], int(sys.argv[4]))
if __name__ == '__main__':
try:
main()
except KeyboardInterrupt:
print("\nExited from the program ealier!")
#!/usr/bin/env python
# -*- coding: utf-8 -*-
######################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################
"""
File: conversation_client.py
"""
from __future__ import print_function
import sys
import socket
reload(sys)
sys.setdefaultencoding('utf8')
SERVER_IP = "127.0.0.1"
SERVER_PORT = 8601
def conversation_client(text):
"""
conversation_client
"""
mysocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
mysocket.connect((SERVER_IP, SERVER_PORT))
mysocket.sendall(text.encode())
result = mysocket.recv(4096).decode()
mysocket.close()
return result
def main():
"""
main
"""
if len(sys.argv) < 2:
print("Usage: " + sys.argv[0] + " eval_file")
exit()
for line in open(sys.argv[1]):
response = conversation_client(line.strip())
print(response)
if __name__ == '__main__':
try:
main()
except KeyboardInterrupt:
print("\nExited from the program ealier!")
#!/usr/bin/env python
# -*- coding: utf-8 -*-
######################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######################################################################
"""
File: conversation_server.py
"""
from __future__ import print_function
import sys
sys.path.append("../")
import socket
from thread import start_new_thread
from tools.conversation_strategy import load
from tools.conversation_strategy import predict
reload(sys)
sys.setdefaultencoding('utf8')
SERVER_IP = "127.0.0.1"
SERVER_PORT = 8601
print("starting conversation server ...")
print("binding socket ...")
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Bind socket to local host and port
try:
s.bind((SERVER_IP, SERVER_PORT))
except socket.error as msg:
print("Bind failed. Error Code : " + str(msg[0]) + " Message " + msg[1])
exit()
# Start listening on socket
s.listen(10)
print("bind socket success !")
print("loading model...")
model = load()
print("load model success !")
print("start conversation server success !")
def clientthread(conn, addr):
"""
client thread
"""
logstr = "addr:" + addr[0] + "_" + str(addr[1])
try:
# Receiving from client
param = conn.recv(4096).decode()
logstr += "\tparam:" + param
if param is not None:
response = predict(model, param.strip())
logstr += "\tresponse:" + response
conn.sendall(response.encode())
conn.close()
print(logstr + "\n")
except Exception as e:
print(logstr + "\n", e)
while True:
conn, addr = s.accept()
start_new_thread(clientthread, (conn, addr))
s.close()