Unverified commit a72760df authored by Yibing Liu, committed by GitHub

Upload mmpms (#2621)

* Upload mmpms

* Adapt to py2
Parent f043ee3c
# IJCAI2019-MMPMS
## 1. Introduction
This is an implementation of the **MMPMS** model for the one-to-many problem in open-domain conversation. MMPMS employs a *multi-mapping* mechanism with multiple mapping modules to capture the one-to-many responding regularities between an input post and its diverse responses. MMPMS also incorporates a *posterior mapping selection* module to identify the mapping module corresponding to the target response for accurate optimization. Experiments on the Weibo and Reddit conversation datasets demonstrate the capacity of MMPMS to generate multiple diverse and informative responses. For more details, see the IJCAI-2019 paper: [Generating Multiple Diverse Responses with Multi-Mapping and Posterior Mapping Selection](https://arxiv.org/abs/1906.01781).
<p align="center">
<img src="./images/architechture.png" width="500">
</p>
## 2. Quick Start
### Requirements
- Python >= 3.6
- PaddlePaddle >= 1.3.2 and <= 1.4.1
- NLTK
### Data Preparation
Prepare a one-turn conversation dataset (e.g. [Weibo](https://www.aclweb.org/anthology/P15-1152) or [Reddit](https://www.ijcai.org/proceedings/2018/0643.pdf)), and put the train/valid/test data files into the `data` folder:
```
data/
├── dial.train
├── dial.valid
└── dial.test
```
In each data file, every line is a post-response pair formatted as `post \t response`.
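For example, two hypothetical lines of `dial.train` (the separator between post and response is a tab):
```
how was your weekend ?	it was great , i went hiking .
do you like this movie ?	yes , the plot is amazing .
```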
Prepare a pre-trained word embedding file (e.g. [sgns.weibo.300d.txt](https://pan.baidu.com/s/1zbuUJEEEpZRNHxZ7Gezzmw) for Weibo or [glove.840B.300d.txt](http://nlp.stanford.edu/data/glove.840B.300d.zip) for Reddit), and put it into the `data` folder. The first line of the pre-trained word embedding file should be formatted as `num_words embedding_dim`.
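For example, the first lines of a 300-dimensional embedding file would look like this (the words and values here are only illustrative):
```
100000 300
hello 0.0121 -0.2543 0.1170 ... (300 values in total)
world -0.0930 0.1381 0.2017 ... (300 values in total)
```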
Preprocess the data by running:
```python
python preprocess.py
```
The vocabulary and the preprocessed data will be saved in the same `data` folder:
```
data/
├── dial.train.pkl
├── dial.valid.pkl
├── dial.test.pkl
└── vocab.json
```
### Train
To train a model, run:
```python
python run.py --data_dir DATA_DIR
```
The logs and model parameters will be saved to the `./output` folder by default.
### Test
Generate the text results to `RESULT_FILE` with the saved model in `MODEL_DIR` by running:
```python
python run.py --infer --model_dir MODEL_DIR --result_file RESULT_FILE
```
The `RESULT_FILE` will be a JSON file containing, for each example, the input post, the target response, and the predicted response from each mapping module.
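Judging from the fields fetched at inference time (`post`, `response` and `preds`), each entry of the JSON list has roughly the following shape (the values are illustrative):
```
[
  {
    "post": "do you like this movie ?",
    "response": "yes , the plot is amazing .",
    "preds": ["yes , i like it .", "what movie is it ?", "..."]
  }
]
```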
Then evaluate the generation results with the following command:
```python
python eval.py RESULT_FILE
```
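`eval.py` reports corpus-level BLEU-1/2 and DIST-1/2, both for a single randomly chosen candidate and for all candidates of each post; its output has the following format (the numbers are placeholders):
```
Random 1 candidate: BLEU-1/2: 0.XXX/0.XXX DIST-1/2: 0.XXX/0.XXX
Random 5 candidates: BLEU-1/2: 0.XXX/0.XXX DIST-1/2: 0.XXX/0.XXX
```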
**Note**:
- The data files in the `data` folder are just samples to illustrate the data format. Remember to replace them with your data.
- To use a GPU for training or testing, set `GPU_ID` first: `export CUDA_VISIBLE_DEVICES=GPU_ID`.
## 3. Citation
If you use any source code included in this toolkit in your work, please cite the following paper:
```
@inproceedings{IJCAI2019-MMPMS,
title={Generating Multiple Diverse Responses with Multi-Mapping and Posterior Mapping Selection},
author={Chaotao Chen and Jinhua Peng and Fan Wang and Jun Xu and Hua Wu},
booktitle={Proceedings of the 28th International Joint Conference on Artificial Intelligence},
pages={ -- },
year={2019}
}
```
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import codecs
import sys
import json
from random import shuffle
from mmpms.utils.metrics import Metric, bleu, distinct
NUM_MULTI_RESPONSES = 5
def evaluate_generation(results):
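# Evaluate generated responses: report BLEU-1/2 and DIST-1/2 both for a single
# randomly chosen candidate and for all NUM_MULTI_RESPONSES candidates per post.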
tgt = [result["response"].split(" ") for result in results]
tgt_multi = [x for x in tgt for _ in range(NUM_MULTI_RESPONSES)]
preds = [
list(map(lambda s: s.split(" "), result["preds"])) for result in results
]
# Shuffle predictions
for n in range(len(preds)):
shuffle(preds[n])
# Single response generation
pred = [ps[0] for ps in preds]
bleu1, bleu2 = bleu(pred, tgt)
dist1, dist2 = distinct(pred)
print("Random 1 candidate: " + "BLEU-1/2: {:.3f}/{:.3f} ".format(
bleu1, bleu2) + "DIST-1/2: {:.3f}/{:.3f}".format(dist1, dist2))
# Multiple response generation
pred = [ps[:NUM_MULTI_RESPONSES] for ps in preds]
pred = [p for ps in pred for p in ps]
bleu1, bleu2 = bleu(pred, tgt_multi)
dist1, dist2 = distinct(pred)
print("Random {} candidates: ".format(
NUM_MULTI_RESPONSES) + "BLEU-1/2: {:.3f}/{:.3f} ".format(bleu1, bleu2)
+ "DIST-1/2: {:.3f}/{:.3f}".format(dist1, dist2))
def main():
result_file = sys.argv[1]
with codecs.open(result_file, "r", encoding="utf-8") as fp:
results = json.load(fp)
evaluate_generation(results)
if __name__ == '__main__':
main()
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import codecs
import os
import json
import time
import shutil
from collections import defaultdict
from mmpms.utils.logging import getLogger
from mmpms.utils.metrics import Metric, bleu, distinct
def evaluate(model, data_iter):
metrics_tracker = defaultdict(Metric)
for batch in data_iter:
metrics = model.evaluate(inputs=batch)
for k, v in metrics.items():
metrics_tracker[k].update(v, batch["size"])
return metrics_tracker
def flatten_batch(batch):
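# Convert a batch given as a dict of parallel lists into a list of per-example dicts.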
examples = []
for vs in zip(*batch.values()):
ex = dict(zip(batch.keys(), vs))
examples.append(ex)
return examples
def infer(model, data_iter, parse_dict, save_file=None):
results = []
for batch in data_iter:
result = model.infer(inputs=batch)
batch_result = {}
# denumericalization
for k, parse_fn in parse_dict.items():
if k in result:
batch_result[k] = parse_fn(result[k])
results += flatten_batch(batch_result)
if save_file is not None:
with codecs.open(save_file, "w", encoding="utf-8") as fp:
json.dump(results, fp, ensure_ascii=False, indent=2)
print("Saved inference results to '{}'".format(save_file))
return results
class Engine(object):
def __init__(self,
model,
valid_metric_name="-loss",
num_epochs=1,
save_dir=None,
log_steps=None,
valid_steps=None,
logger=None):
self.model = model
self.is_decreased_valid_metric = valid_metric_name[0] == "-"
self.valid_metric_name = valid_metric_name[1:]
self.num_epochs = num_epochs
self.save_dir = save_dir or "./"
self.log_steps = log_steps
self.valid_steps = valid_steps
if not os.path.exists(self.save_dir):
os.makedirs(self.save_dir)
self.logger = logger or getLogger(
os.path.join(self.save_dir, "run.log"))
best_valid_metric = float("inf") if self.is_decreased_valid_metric \
else -float("inf")
self.state = {
"epoch": 0,
"iteration": 0,
"best_valid_metric": best_valid_metric
}
@property
def epoch(self):
return self.state["epoch"]
@property
def iteration(self):
return self.state["iteration"]
@property
def best_valid_metric(self):
return self.state["best_valid_metric"]
def train_epoch(self, train_iter, valid_iter=None):
self.state["epoch"] += 1
num_batches = len(train_iter)
metrics_tracker = defaultdict(Metric)
for batch_id, batch in enumerate(train_iter, 1):
# Do a training iteration
start_time = time.time()
metrics = self.model.train(inputs=batch)
elapsed = time.time() - start_time
for k, v in metrics.items():
metrics_tracker[k].update(v, batch["size"])
metrics_tracker["time"].update(elapsed)
self.state["iteration"] += 1
if self.log_steps and batch_id % self.log_steps == 0:
metrics_message = [
"{}-{}".format(name.upper(), metric.val)
for name, metric in metrics_tracker.items()
]
message_prefix = "[Train][{}][{}/{}]".format(
self.epoch, batch_id, num_batches)
message = " ".join([message_prefix] + metrics_message)
self.logger.info(message)
if self.valid_steps and valid_iter is not None and \
batch_id % self.valid_steps == 0:
self.evaluate(valid_iter)
if valid_iter is not None:
self.evaluate(valid_iter)
def save(self, is_best):
model_file = os.path.join(self.save_dir,
"model_epoch_{}".format(self.epoch))
self.model.save(model_file)
self.logger.info("Saved model to '{}'".format(model_file))
if is_best:
best_model_file = os.path.join(self.save_dir, "best_model")
if os.path.isdir(model_file):
if os.path.exists(best_model_file):
shutil.rmtree(best_model_file)
shutil.copytree(model_file, best_model_file)
else:
shutil.copyfile(model_file, best_model_file)
self.logger.info("Saved best model to '{}' "
"with new best valid metric "
"{}-{}".format(best_model_file,
self.valid_metric_name.upper(),
self.best_valid_metric))
def load(self, model_dir):
self.model.load(model_dir)
self.logger.info("Loaded model checkpoint from {}".format(model_dir))
def evaluate(self, data_iter, is_save=True):
metrics_tracker = evaluate(self.model, data_iter)
metrics_message = [
"{}-{}".format(name.upper(), metric.avg)
for name, metric in metrics_tracker.items()
]
message_prefix = "[Valid][{}]".format(self.epoch)
message = " ".join([message_prefix] + metrics_message)
self.logger.info(message)
if is_save:
cur_valid_metric = metrics_tracker.get(self.valid_metric_name).avg
if self.is_decreased_valid_metric:
is_best = cur_valid_metric < self.best_valid_metric
else:
is_best = cur_valid_metric > self.best_valid_metric
if is_best:
self.state["best_valid_metric"] = cur_valid_metric
self.save(is_best)
#!/usr/bin/env python
"""
description
"""
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
UNK = "[unk]"
BOS = "[bos]"
EOS = "[eos]"
NUM = "[num]"
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import math
import numpy as np
from collections import defaultdict
import paddle
import paddle.fluid as fluid
def data2lodtensor(data, place):
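# Build a fluid.LoDTensor from a (possibly nested) list of token ids, recording
# the nesting structure as recursive sequence lengths (LoD).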
lod = []
while isinstance(data[0], list):
lod.append(list(map(len, data)))
data = [x for xs in data for x in xs]
array = np.array(data, dtype="int64")
if len(array.shape) == 1:
array = array[:, None]
tensor = fluid.LoDTensor()
tensor.set(array, place)
if len(lod) > 0:
tensor.set_recursive_sequence_lengths(lod)
return tensor
class DataLoader(object):
def __init__(self,
data,
batch_size,
shuffle=False,
buf_size=4096,
use_gpu=False):
def data_reader():
return data
if shuffle:
self.reader = paddle.batch(
paddle.reader.shuffle(
data_reader, buf_size=buf_size),
batch_size=batch_size)
else:
self.reader = paddle.batch(data_reader, batch_size=batch_size)
self.num_batches = int(math.ceil(len(data) / float(batch_size)))
self.place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
def __len__(self):
return self.num_batches
def __iter__(self):
for examples in self.reader():
batch_size = len(examples)
batch = defaultdict(list)
for ex in examples:
for k, v in ex.items():
batch[k].append(v)
batch = {k: data2lodtensor(v, self.place) for k, v in batch.items()}
batch["size"] = batch_size
yield batch
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import codecs
import re
import time
import json
import pickle
from collections import Counter
from mmpms.inputters.vocabulary import Vocabulary
from mmpms.inputters.constant import UNK, BOS, EOS, NUM
def tokenize(s):
s = re.sub(r'\d+', NUM, s).lower()
tokens = s.split(' ')
return tokens
class PostResponseDataset(object):
def __init__(self,
tokenize_fn=tokenize,
min_count=0,
max_vocab_size=None,
min_len=0,
max_len=100,
embed_file=None):
self.tokenize_fn = tokenize_fn
self.vocab = Vocabulary(
min_count=min_count, max_size=max_vocab_size, embed_file=embed_file)
self.min_len = min_len
self.max_len = max_len
def build_vocab(self, data_file):
examples = self.read(data_file)
counter = Counter()
print("Building vocabulary ...")
for example in examples:
counter.update(example["post"])
counter.update(example["response"])
self.vocab.build(counter)
def save_vocab(self, vocab_file):
vocab_dict = self.vocab.dump()
start = time.time()
with codecs.open(vocab_file, "w", encoding="utf-8") as fp:
json.dump(vocab_dict, fp, ensure_ascii=False)
elapsed = time.time() - start
print("Saved vocabulary to '{}' (elapsed {:.2f}s)".format(vocab_file,
elapsed))
def load_vocab(self, vocab_file):
print("Loading vocabulary from '{}' ...".format(vocab_file))
start = time.time()
with codecs.open(vocab_file, "r", encoding="utf-8") as fp:
vocab_dict = json.load(fp)
elapsed = time.time() - start
self.vocab.load(vocab_dict)
vocab_size = self.vocab.size()
print("Loaded vocabulary of size {} (elapsed {}s)".format(vocab_size,
elapsed))
def indices2string(self, indices):
tokens = [self.vocab.itos[idx] for idx in indices]
bos_token = self.vocab.bos_token
if bos_token and tokens[0] == bos_token:
tokens = tokens[1:]
eos_token = self.vocab.eos_token
string = []
for tok in tokens:
if tok != eos_token:
string.append(tok)
else:
break
string = " ".join(string)
return string
def tokens2indices(self, tokens):
indices = [
self.vocab.stoi.get(tok, self.vocab.unk_id) for tok in tokens
]
return indices
def numericalize(self, tokens):
element = tokens[0]
if isinstance(element, list):
return [self.numericalize(s) for s in tokens]
else:
return self.tokens2indices(tokens)
def denumericalize(self, indices):
element = indices[0]
if isinstance(element, list):
return [self.denumericalize(x) for x in indices]
else:
return self.indices2string(indices)
def build_examples(self, data_file):
print("Building examples from '{}' ...".format(data_file))
data = self.read(data_file)
examples = []
print("Numericalizing examples ...")
for ex in data:
example = {}
post, response = ex["post"], ex["response"]
post = self.numericalize(post)
response = self.numericalize(response)
example["post"] = post
example["response"] = [self.vocab.bos_id] + response
example["label"] = response + [self.vocab.eos_id]
examples.append(example)
return examples
def save_examples(self, examples, filename):
start = time.time()
with open(filename, "wb") as fp:
pickle.dump(examples, fp)
elapsed = time.time() - start
print("Saved examples to '{}' (elapsed {:.2f}s)".format(filename,
elapsed))
def load_examples(self, filename):
print("Loading examples from '{}' ...".format(filename))
start = time.time()
with open(filename, "rb") as fp:
examples = pickle.load(fp)
elapsed = time.time() - start
print("Loaded {} examples (elapsed {:.2f}s)".format(
len(examples), elapsed))
return examples
def read(self, data_file):
examples = []
ignored = 0
def filter_pred(utt):
""" Filter utterance. """
return self.min_len <= len(utt) <= self.max_len
print("Reading examples from '{}' ...".format(data_file))
with codecs.open(data_file, "r", encoding="utf-8") as f:
for line in f:
post, response = line.strip().split("\t")
post = self.tokenize_fn(post)
response = self.tokenize_fn(response)
if filter_pred(post) and filter_pred(response):
examples.append({"post": post, "response": response})
else:
ignored += 1
print("Read {} examples ({} filtered)".format(len(examples), ignored))
return examples
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from __future__ import division
import codecs
from mmpms.inputters.constant import UNK, BOS, EOS
class Vocabulary(object):
unk_token = UNK
bos_token = BOS
eos_token = EOS
def __init__(self, min_count=0, max_size=None, specials=[],
embed_file=None):
self.min_count = min_count
self.max_size = max_size
self.embed_file = embed_file
self.specials = [self.unk_token, self.bos_token, self.eos_token]
for token in specials:
if token not in self.specials:
self.specials.append(token)
self.itos = []
self.stoi = {}
self.embeddings = None
@property
def unk_id(self):
return self.stoi.get(self.unk_token)
@property
def bos_id(self):
return self.stoi.get(self.bos_token)
@property
def eos_id(self):
return self.stoi.get(self.eos_token)
def __len__(self):
return len(self.itos)
def size(self):
return len(self.itos)
def build(self, counter):
# frequencies of special tokens are not counted when building vocabulary
# in frequency order
for tok in self.specials:
del counter[tok]
if len(counter) == 0:
return
self.itos = list(self.specials)
if self.max_size is not None:
self.max_size += len(self.itos)
# sort by frequency, then alphabetically
tokens_frequencies = sorted(counter.items(), key=lambda tup: tup[0])
tokens_frequencies.sort(key=lambda tup: tup[1], reverse=True)
cover = 0
for token, count in tokens_frequencies:
if count < self.min_count or len(self.itos) == self.max_size:
break
self.itos.append(token)
cover += count
cover = cover / sum(count for _, count in tokens_frequencies)
self.stoi = {token: i for i, token in enumerate(self.itos)}
print("Built vocabulary of size {} ".format(self.size()) +
"(coverage: {:.3f})".format(cover))
if self.embed_file is not None:
self.embeddings = self.build_word_embeddings(self.embed_file)
def build_word_embeddings(self, embed_file):
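# Read a word2vec-style text file (first line: `num_words embedding_dim`) and
# collect vectors for in-vocabulary words; words absent from the file keep zero vectors.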
cover = 0
print("Building word embeddings from '{}' ...".format(embed_file))
with codecs.open(embed_file, "r", encoding="utf-8") as f:
num, dim = map(int, f.readline().strip().split())
embeds = [[0] * dim for _ in range(len(self.stoi))]  # one independent zero vector per word
for line in f:
cols = line.rstrip().split()
w, vs = cols[0], cols[1:]
if w in self.stoi:
try:
vs = [float(x) for x in vs]
except Exception:
vs = []
if len(vs) == dim:
embeds[self.stoi[w]] = vs
cover += 1
rate = cover / len(embeds)
print("Built {} {}-D pretrained word embeddings ".format(cover, dim) +
"(coverage: {:.3f})".format(rate))
return embeds
def dump(self):
vocab_dict = {"itos": self.itos, "embeddings": self.embeddings}
return vocab_dict
def load(self, vocab_dict):
self.itos = vocab_dict["itos"]
self.stoi = {tok: i for i, tok in enumerate(self.itos)}
self.embeddings = vocab_dict["embeddings"]
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from paddle.fluid.layers import *
from mmpms.layers.layers_wrapper import *
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
"""
Wrappers for fluid.layers. It helps to easily share parameters between layers.
"""
import operator
from collections import OrderedDict
import paddle.fluid.layers as layers
import paddle.fluid.unique_name as unique_name
from paddle.fluid.param_attr import ParamAttr
def update_attr(attr, name, prefix=None, suffix="W"):
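# Ensure the ParamAttr carries a concrete, uniquely generated name; the wrapping
# layer stores it at construction time, so every later call reuses the same parameter.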
if attr is False:
return False
if prefix:
name = prefix + "." + name
new_name = unique_name.generate(name + "." + suffix)
if attr is None:
attr = ParamAttr(name=new_name)
elif attr.name is None:
attr.name = new_name
return attr
class BaseLayer(object):
def __init__(self):
self._parameters = OrderedDict()
self._layers = OrderedDict()
def __getattr__(self, name):
if '_parameters' in self.__dict__:
_parameters = self.__dict__['_parameters']
if name in _parameters:
return _parameters[name]
if '_layers' in self.__dict__:
_layers = self.__dict__['_layers']
if name in _layers:
return _layers[name]
if name in self.__dict__:
return self.__dict__[name]
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, name))
def __setattr__(self, name, value):
def remove_from(*dicts):
for d in dicts:
if name in d:
del d[name]
if isinstance(value, ParamAttr):
self._parameters[name] = value
remove_from(self.__dict__, self._layers)
elif isinstance(value, BaseLayer):
self._layers[name] = value
remove_from(self.__dict__, self._parameters)
else:
object.__setattr__(self, name, value)
def __call__(self, *args, **kwargs):
raise NotImplementedError
class LayerList(BaseLayer):
def __init__(self, layers):
super(LayerList, self).__init__()
self += layers
def _get_abs_string_index(self, idx):
"""Get the absolute index for the list of layers"""
idx = operator.index(idx)
if not (-len(self) <= idx < len(self)):
raise IndexError('index {} is out of range'.format(idx))
if idx < 0:
idx += len(self)
return str(idx)
def __getitem__(self, idx):
if isinstance(idx, slice):
return self.__class__(list(self._layers.values())[idx])
else:
return self._layers[self._get_abs_string_index(idx)]
def __setitem__(self, idx, layer):
idx = self._get_abs_string_index(idx)
return setattr(self, str(idx), layer)
def __delitem__(self, idx):
if isinstance(idx, slice):
for k in range(len(self._layers))[idx]:
delattr(self, str(k))
else:
delattr(self, self._get_abs_string_index(idx))
# To preserve numbering, self._layers is being reconstructed with layers after deletion
str_indices = [str(i) for i in range(len(self._layers))]
self._layers = OrderedDict(
list(zip(str_indices, self._layers.values())))
def __len__(self):
return len(self._layers)
def __iter__(self):
return iter(self._layers.values())
def __iadd__(self, layers):
return self.extend(layers)
def extend(self, layers):
if not isinstance(layers, (list, tuple)):
raise TypeError("LayerList.extend should be called with a "
"list or tuple, but got " + type(layers).__name__)
offset = len(self)
for i, layer in enumerate(layers):
self._layers[str(offset + i)] = layer
return self
class Sequential(BaseLayer):
def __init__(self, *layers):
super(Sequential, self).__init__()
for idx, layer in enumerate(layers):
if not isinstance(layer, BaseLayer):
raise TypeError("{} is not a BaseLayer subclass".format(
type(layer)))
self._layers[str(idx)] = layer
def __call__(self, input):
for layer in self._layers.values():
input = layer(input)
return input
class Embedding(BaseLayer):
def __init__(self,
size,
is_sparse=False,
is_distributed=False,
padding_idx=None,
param_attr=None,
dtype='float32',
name=None):
super(Embedding, self).__init__()
self.name = name or "Embedding"
self.size = size
self.is_sparse = is_sparse
self.is_distributed = is_distributed
self.padding_idx = padding_idx
self.param_attr = update_attr(param_attr, self.name, suffix="W")
self.dtype = dtype
def __call__(self, input):
return layers.embedding(
input=input,
size=self.size,
is_sparse=self.is_sparse,
is_distributed=self.is_distributed,
padding_idx=self.padding_idx,
param_attr=self.param_attr,
dtype=self.dtype)
class FC(BaseLayer):
def __init__(self,
size,
num_flatten_dims=1,
param_attr=None,
bias_attr=None,
act=None,
is_test=False,
name=None):
super(FC, self).__init__()
self.name = name or "FC"
self.size = size
self.num_flatten_dims = num_flatten_dims
self.param_attr = update_attr(param_attr, self.name, suffix="W")
self.bias_attr = update_attr(bias_attr, self.name, suffix="b")
self.act = act
self.is_test = is_test
def __call__(self, input, name=None):
assert not isinstance(input, (list, tuple))
return layers.fc(input=input,
size=self.size,
num_flatten_dims=self.num_flatten_dims,
param_attr=self.param_attr,
bias_attr=self.bias_attr,
act=self.act,
is_test=self.is_test,
name=name)
class DynamicGRU(BaseLayer):
def __init__(self,
hidden_dim,
param_attr=None,
bias_attr=None,
input_param_attr=None,
input_bias_attr=None,
is_reverse=False,
gate_activation='sigmoid',
candidate_activation='tanh',
origin_mode=False,
name=None):
super(DynamicGRU, self).__init__()
self.name = name or "DynamicGRU"
self.hidden_dim = hidden_dim
self.param_attr = update_attr(param_attr, self.name, suffix="hidden.W")
self.bias_attr = update_attr(bias_attr, self.name, suffix="hidden.b")
self.input_param_attr = update_attr(
input_param_attr, self.name, suffix="input.W")
self.input_bias_attr = update_attr(
input_bias_attr, self.name, suffix="input.b")
self.is_reverse = is_reverse
self.gate_activation = gate_activation
self.candidate_activation = candidate_activation
self.origin_mode = origin_mode
def __call__(self, input, state=None):
gru_input = layers.fc(input=input,
size=self.hidden_dim * 3,
param_attr=self.input_param_attr,
bias_attr=self.input_bias_attr)
return layers.dynamic_gru(
input=gru_input,
size=self.hidden_dim,
param_attr=self.param_attr,
bias_attr=self.bias_attr,
is_reverse=self.is_reverse,
gate_activation=self.gate_activation,
candidate_activation=self.candidate_activation,
h_0=state,
origin_mode=self.origin_mode)
class GRU(BaseLayer):
def __init__(self,
hidden_dim,
num_layers=1,
bidirectional=False,
dropout=0.0,
name=None):
super(GRU, self).__init__()
if dropout > 0 and num_layers == 1:
raise ValueError(
"Non-zero dropout expects num_layers greater than 1")
self.name = name or "GRU"
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.bidirectional = bidirectional
self.num_directions = 2 if bidirectional else 1
self.dropout = dropout
rnns = []
for l in range(num_layers):
inners = []
inners.append(
DynamicGRU(
hidden_dim=hidden_dim, name="{}_l{}".format(self.name, l)))
if bidirectional:
inners.append(
DynamicGRU(
hidden_dim=hidden_dim,
name="{}_l{}_reverse".format(self.name, l),
is_reverse=True))
rnns.append(LayerList(inners))
self.rnns = LayerList(rnns)
def __call__(self, input, hidden=None):
if hidden is not None:
assert len(hidden) == self.num_layers
assert len(hidden[0]) == self.num_directions
else:
hidden = [[None] * self.num_directions] * self.num_layers
new_hidden = []
for l in range(self.num_layers):
layer_output = []
layer_hidden = []
for i, inner in enumerate(self.rnns[l]):
output = inner(input, hidden[l][i])
layer_output.append(output)
if inner.is_reverse:
layer_hidden.append(layers.sequence_first_step(output))
else:
layer_hidden.append(layers.sequence_last_step(output))
input = layers.concat(layer_output, axis=1)
if self.dropout > 0 and l + 1 < self.num_layers:
input = layers.dropout(
input,
dropout_prob=self.dropout,
dropout_implementation='upscale_in_train')
new_hidden.append(layers.concat(layer_hidden, axis=1))
return input, new_hidden
class GRUCell(BaseLayer):
def __init__(self,
hidden_dim,
param_attr=None,
bias_attr=None,
input_param_attr=None,
input_bias_attr=None,
activation='tanh',
gate_activation='sigmoid',
origin_mode=False,
name=None):
super(GRUCell, self).__init__()
self.name = name or "GRUCell"
self.hidden_dim = hidden_dim
self.param_attr = update_attr(param_attr, self.name, suffix="hidden.W")
self.bias_attr = update_attr(bias_attr, self.name, suffix="hidden.b")
self.input_param_attr = update_attr(
input_param_attr, self.name, suffix="input.W")
self.input_bias_attr = update_attr(
input_bias_attr, self.name, suffix="input.b")
self.activation = activation
self.gate_activation = gate_activation
self.origin_mode = origin_mode
def __call__(self, input, hidden):
gru_input = layers.fc(input=input,
size=self.hidden_dim * 3,
param_attr=self.input_param_attr,
bias_attr=self.input_bias_attr)
new_hidden, _, _ = layers.gru_unit(
input=gru_input,
hidden=hidden,
size=self.hidden_dim * 3,
param_attr=self.param_attr,
bias_attr=self.bias_attr,
activation=self.activation,
gate_activation=self.gate_activation,
origin_mode=self.origin_mode)
return new_hidden, new_hidden
class StackedGRUCell(BaseLayer):
def __init__(self, hidden_dim, num_layers=1, dropout=0.0, name=None):
super(StackedGRUCell, self).__init__()
if dropout > 0 and num_layers == 1:
raise ValueError(
"Non-zero dropout expects num_layers greater than 1")
self.name = name or "StackedGRUCell"
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.dropout = dropout
cells = [
GRUCell(
hidden_dim=hidden_dim, name="{}_l{}".format(self.name, l))
for l in range(self.num_layers)
]
self.cells = LayerList(cells)
def __call__(self, input, hidden):
assert len(hidden) == self.num_layers
new_hidden = []
for cell, hid in zip(self.cells, hidden):
input, new_hid = cell(input, hid)
new_hidden += [new_hid]
if self.dropout > 0:
input = layers.dropout(
input,
dropout_prob=self.dropout,
dropout_implementation='upscale_in_train')
output = new_hidden[-1]
return output, new_hidden
class Dropout(BaseLayer):
def __init__(self, dropout_prob, is_test=False, seed=None, name=None):
super(Dropout, self).__init__()
self.dropout_prob = dropout_prob
self.is_test = is_test
self.seed = seed
self.name = name
def __call__(self, input):
if self.dropout_prob > 0.0:
return layers.dropout(
input,
dropout_prob=self.dropout_prob,
is_test=self.is_test,
seed=self.seed,
name=self.name,
dropout_implementation='upscale_in_train')
else:
return input
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from __future__ import division
from __future__ import absolute_import
import math
import numpy as np
import paddle.fluid as fluid
import mmpms.layers as layers
from mmpms.models.model_base import Model
from mmpms.modules.embedder import Embedder
from mmpms.modules.encoder import GRUEncoder
from mmpms.modules.decoder import GRUDecoder
from mmpms.utils.misc import sequence_but, sequence_last
class MMPMS(Model):
def __init__(self, vocab, generator, hparams, optim_hparams, use_gpu=False):
self.vocab = vocab
self.generator = generator
self.vocab_size = self.vocab.size()
self.embed_dim = hparams.embed_dim
self.hidden_dim = hparams.hidden_dim
self.num_mappings = hparams.num_mappings
self.tau = hparams.tau
self.num_layers = hparams.num_layers
self.bidirectional = hparams.bidirectional
self.attn_mode = hparams.attn_mode
self.use_pretrained_embedding = hparams.use_pretrained_embedding
self.embed_init_scale = hparams.embed_init_scale
self.dropout = hparams.dropout
self.grad_clip = optim_hparams.grad_clip or 0
# Embedding
self.embedder = Embedder(
num_embeddings=self.vocab_size,
embedding_dim=self.embed_dim,
param_attr=fluid.ParamAttr(initializer=fluid.initializer.Uniform(
-self.embed_init_scale, self.embed_init_scale)),
name="embedder")
# Encoding
self.post_encoder = GRUEncoder(
hidden_dim=self.hidden_dim,
num_layers=self.num_layers,
bidirectional=self.bidirectional,
dropout=self.dropout,
name="post_encoder")
self.response_encoder = GRUEncoder(
hidden_dim=self.hidden_dim,
num_layers=self.num_layers,
bidirectional=self.bidirectional,
dropout=self.dropout,
name="response_encoder")
# Multi-Mapping
self.mappings = layers.LayerList([
layers.FC(size=self.hidden_dim, name="map_{}".format(i))
for i in range(self.num_mappings)
])
# Decoding
self.decoder = GRUDecoder(
hidden_dim=self.hidden_dim,
num_layers=self.num_layers,
attn_mode=self.attn_mode,
dropout=self.dropout,
name="decoder")
# Predictor
bound = math.sqrt(1 / self.hidden_dim)
if self.attn_mode == "none":
self.predictor = layers.Sequential(
layers.Dropout(dropout_prob=self.dropout),
layers.FC(
size=self.vocab_size,
act="softmax",
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(-bound, bound)),
name="predictor"))
else:
self.predictor = layers.Sequential(
layers.Dropout(dropout_prob=self.dropout),
layers.FC(size=self.hidden_dim, name="project"),
layers.FC(
size=self.vocab_size,
act="softmax",
param_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(-bound, bound)),
bias_attr=fluid.ParamAttr(
initializer=fluid.initializer.Uniform(-bound, bound)),
name="predictor"), )
# Optimizer
Optimizer = getattr(fluid.optimizer, optim_hparams.optimizer)
self.optimizer = Optimizer(learning_rate=optim_hparams.lr)
super(MMPMS, self).__init__(use_gpu=use_gpu)
# Embedding Initialization
if self.use_pretrained_embedding:
self.embedder.from_pretrained(self.vocab.embeddings, self.place,
self.embed_init_scale)
def gumbel_softmax(self, logits, tau, eps=1e-10):
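# Draw a differentiable, approximately one-hot sample over the mapping modules
# by adding Gumbel noise to the logits and applying a temperature-tau softmax.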
u = layers.uniform_random_batch_size_like(
logits, shape=[-1, self.num_mappings], min=0.0, max=1.0)
u.stop_gradient = True
gumbel = 0.0 - layers.log(eps - layers.log(u + eps))
y = logits + gumbel
return layers.softmax(y / tau)
def encode(self, post_inputs, response_inputs, is_training=False):
outputs = {}
post_enc_inputs = self.embedder(post_inputs)
post_outputs, post_hidden = self.post_encoder(post_enc_inputs)
post_hidden = post_hidden[-1]
# shape: (batch_size, num_mappings, hidden_dim)
candidate_hiddens = layers.stack(
[mapping(post_hidden) for mapping in self.mappings], axis=1)
response_enc_inputs = self.embedder(response_inputs)
_, response_hidden = self.response_encoder(response_enc_inputs)
response_hidden = response_hidden[-1]
# For simplicity, use the target responses in the same batch as negative examples
neg_response_hidden = layers.reverse(response_hidden, axis=0)
pos_logits = layers.reduce_sum(
post_hidden * response_hidden, dim=1, keep_dim=True)
neg_logits = layers.reduce_sum(
post_hidden * neg_response_hidden, dim=1, keep_dim=True)
outputs.update({"pos_logits": pos_logits, "neg_logits": neg_logits})
# shape: (batch_size, num_mappings)
similarity = layers.squeeze(
layers.matmul(
candidate_hiddens, layers.unsqueeze(
response_hidden, axes=[2])),
axes=[2])
post_probs = layers.softmax(similarity)
outputs.update({"post_probs": post_probs})
if is_training:
z = self.gumbel_softmax(
layers.log(post_probs + 1e-10), tau=self.tau)
else:
indices = layers.argmax(post_probs, axis=1)
z = layers.one_hot(
layers.reshape(
indices, shape=[-1, 1]), self.num_mappings)
# shape: (batch_size, hidden_size)
dec_hidden = layers.squeeze(
layers.matmul(
layers.unsqueeze(
z, axes=[1]), candidate_hiddens),
axes=[1])
state = {}
state["hidden"] = [dec_hidden] * self.num_layers
if self.attn_mode != "none":
state["memory"] = post_outputs
return outputs, state
def enumerate_encode(self, inputs, post_expand_lod):
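# Encode each post once, apply every mapping module to its final hidden state,
# and expand the encoder outputs num_mappings times so that all mapping modules
# decode in parallel at inference time.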
post_enc_inputs = self.embedder(inputs)
post_outputs, post_hidden = self.post_encoder(post_enc_inputs)
post_hidden = post_hidden[-1]
# shape: (batch_size*num_mappings, hidden_dim)
dec_hidden = layers.stack(
[mapping(post_hidden) for mapping in self.mappings], axis=1)
dec_hidden = layers.reshape(dec_hidden, shape=[-1, self.hidden_dim])
post_outputs = layers.expand(
post_outputs, expand_times=[1, self.num_mappings])
post_outputs = layers.sequence_reshape(
post_outputs, new_dim=self.hidden_dim)
post_outputs = layers.lod_reset(post_outputs, y=post_expand_lod)
state = {}
state["hidden"] = [dec_hidden] * self.num_layers
if self.attn_mode != "none":
state["memory"] = post_outputs
return state
def decode(self, inputs, state, is_infer=True):
dec_inputs = self.embedder(inputs)
if is_infer:
dec_outputs, new_state = self.decoder.step(dec_inputs, state=state)
else:
dec_outputs = self.decoder(dec_inputs, state=state)
probs = self.predictor(dec_outputs)
if is_infer:
return probs, new_state
else:
return probs
def collect_metrics(self, outputs, label):
metrics = {}
loss = 0
# Seq2Seq NLL Loss
probs = outputs["probs"]
nll = layers.cross_entropy(input=probs, label=label)
ppl = layers.mean(
layers.exp(layers.sequence_pool(
nll, pool_type="average")),
name="ppl")
nll = layers.mean(
layers.sequence_pool(
nll, pool_type="sum"), name="nll")
metrics.update({"nll": nll, "ppl": ppl})
loss += nll
# Matching Loss
pos_logits = outputs["pos_logits"]
pos_label = layers.fill_constant_batch_size_like(
pos_logits, shape=[-1, 1], dtype="float32", value=1)
pos_label.stop_gradient = True
neg_logits = outputs["neg_logits"]
neg_label = layers.fill_constant_batch_size_like(
neg_logits, shape=[-1, 1], dtype="float32", value=0)
neg_label.stop_gradient = True
pos_loss = layers.sigmoid_cross_entropy_with_logits(pos_logits,
pos_label)
neg_loss = layers.sigmoid_cross_entropy_with_logits(neg_logits,
neg_label)
match = layers.mean(pos_loss + neg_loss)
pos_acc = layers.mean(
layers.cast(
layers.less_than(neg_label, pos_logits), dtype="float32"))
neg_acc = layers.mean(
layers.cast(
layers.less_than(neg_logits, neg_label), dtype="float32"))
acc = (pos_acc + neg_acc) / 2.0
metrics.update({"match": match, "match_acc": acc})
loss += match
metrics["loss"] = loss
return metrics
def build_program(self):
self.startup_program = fluid.Program()
self.train_program = fluid.Program()
with fluid.program_guard(self.train_program, self.startup_program):
# Input
post = layers.data(
name="post", shape=[1], lod_level=1, dtype="int64")
response = layers.data(
name="response", shape=[1], lod_level=1, dtype="int64")
label = layers.data(
name="label", shape=[1], lod_level=1, dtype="int64")
pos_response = layers.data(
name="pos_response", shape=[1], lod_level=1, dtype="int64")
self.eval_program = self.train_program.clone(for_test=True)
# Encode
outputs, state = self.encode(
post_inputs=post,
response_inputs=pos_response,
is_training=True)
# Decode
probs = self.decode(response, state, is_infer=False)
outputs.update({"probs": probs})
# Metrics
metrics = self.collect_metrics(outputs, label)
loss = metrics["loss"]
if self.grad_clip > 0:
fluid.clip.set_gradient_clip(
clip=fluid.clip.GradientClipByGlobalNorm(
clip_norm=self.grad_clip),
program=self.train_program)
self.optimizer.minimize(loss)
self.train_fetch_dict = metrics
with fluid.program_guard(self.eval_program, self.startup_program):
# Encode
outputs, state = self.encode(
post_inputs=post,
response_inputs=pos_response,
is_training=False)
# Decode
probs = self.decode(response, state, is_infer=False)
outputs.update({"probs": probs})
# Metrics
metrics = self.collect_metrics(outputs, label)
self.eval_fetch_dict = metrics
self.eval_program = self.eval_program.clone(for_test=True)
self.infer_program = fluid.Program()
with fluid.program_guard(self.infer_program, self.startup_program):
# Input
post = layers.data(
name="post", shape=[1], lod_level=1, dtype="int64")
response = layers.data(
name="response", shape=[1], lod_level=1, dtype="int64")
init_ids = layers.data(
name="init_ids", shape=[1], lod_level=2, dtype="int64")
post_expand_lod = layers.data(
name="post_expand_lod", shape=[1, -1], dtype="int32")
# Encode
state = self.enumerate_encode(post, post_expand_lod)
# Infer
prediction_ids, prediction_scores = self.generator(self.decode,
state, init_ids)
self.infer_program = self.infer_program.clone(for_test=True)
self.infer_fetch_dict = {
"preds": prediction_ids,
"post": post,
"response": response
}
def train(self, inputs, train_state=None):
return self.execute(
program=self.train_program,
feed=self.set_feed(
inputs, mode="train"),
fetch_dict=self.train_fetch_dict)
def evaluate(self, inputs):
return self.execute(
program=self.eval_program,
feed=self.set_feed(
inputs, mode="evaluate"),
fetch_dict=self.eval_fetch_dict)
def infer(self, inputs):
batch_size = inputs["size"]
result = self.execute(
program=self.infer_program,
feed=self.set_feed(
inputs, mode="infer"),
fetch_dict=self.infer_fetch_dict,
return_numpy=False)
def select_top1_in_beam(T):
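# Keep only the first (highest-scoring) candidate of each beam in the decoded LoD tensor.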
lod = T.lod()
lens = T.recursive_sequence_lengths()[-1]
sents = np.split(np.array(T), lod[-1][1:-1])
top1_ids = lod[0][:-1]
data = np.concatenate([sents[i] for i in top1_ids])
recur_lens = [[1 for _ in top1_ids], [lens[i] for i in top1_ids]]
return fluid.create_lod_tensor(data, recur_lens, self.place)
preds = select_top1_in_beam(result["preds"])
lens = preds.recursive_sequence_lengths()
lens[0] = [self.num_mappings] * batch_size
preds.set_recursive_sequence_lengths(lens)
result["preds"] = preds
return result
def set_feed(self, inputs, mode="train"):
feed = {}
feed["post"] = inputs["post"]
feed["response"] = inputs["response"]
if mode == "infer":
start_id = self.generator.start_id
batch_size = inputs["size"]
batch_size = batch_size * self.num_mappings
init_ids_data = np.array(
[[start_id] for _ in range(batch_size)], dtype='int64')
init_recursive_seq_lens = [[1] * batch_size, [1] * batch_size]
init_ids = fluid.create_lod_tensor(
init_ids_data, init_recursive_seq_lens, self.place)
feed["init_ids"] = init_ids
post_lens = inputs["post"].recursive_sequence_lengths()[0]
post_expand_lens = [
l for l in post_lens for _ in range(self.num_mappings)
]
post_expand_lens.insert(0, 0)
post_expand_lod = np.cumsum(post_expand_lens)[None, :]
post_expand_lod = post_expand_lod.astype("int32")
feed["post_expand_lod"] = post_expand_lod
else:
feed["label"] = inputs["label"]
feed["pos_response"] = sequence_but(
inputs["response"], self.place, position="first")
return feed
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import paddle.fluid as fluid
class Model(object):
def __init__(self, use_gpu=False):
self.train_program = None
self.eval_program = None
self.infer_program = None
self.startup_program = None
self.train_fetch_dict = None
self.eval_fetch_dict = None
self.infer_fetch_dict = None
self.build_program()
assert self.startup_program is not None
assert self.train_program is not None
assert self.train_fetch_dict is not None
assert self.eval_program is not None
assert self.eval_fetch_dict is not None
self.place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
self.executor = fluid.Executor(self.place)
self.executor.run(self.startup_program)
def build_program(self):
raise NotImplementedError
def set_feed(self, inputs, mode):
raise NotImplementedError
def train(self, inputs):
raise NotImplementedError
def evaluate(self, inputs):
raise NotImplementedError
def infer(self, inputs):
raise NotImplementedError
def execute(self, program, feed, fetch_dict, return_numpy=True):
fetch_keys = list(fetch_dict.keys())
fetch_list = list(fetch_dict.values())
fetch_vals = self.executor.run(program=program,
feed=feed,
fetch_list=fetch_list,
return_numpy=return_numpy)
return dict(zip(fetch_keys, fetch_vals))
def save(self, model_dir):
""" Save model parameters. """
fluid.io.save_persistables(
executor=self.executor,
dirname=model_dir,
main_program=self.train_program)
def load(self, model_dir):
""" Load model parameters. """
fluid.io.load_persistables(
executor=self.executor,
dirname=model_dir,
main_program=self.train_program)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import mmpms.layers as layers
class Attention(object):
def __init__(self, mode="mlp", memory_dim=None, hidden_dim=None, name=None):
assert (mode in ["dot", "general", "mlp"]), (
"Unsupported attention mode: {}".format(mode))
self.name = name or "Attention"
self.mode = mode
if mode == "general":
self.query_fc = layers.FC(size=memory_dim,
bias_attr=False,
name="{}.query".format(self.name))
self.memory_dim = memory_dim
elif mode == "mlp":
assert hidden_dim is not None
self.query_fc = layers.FC(size=hidden_dim,
bias_attr=False,
name="{}.query".format(self.name))
self.memory_fc = layers.FC(size=hidden_dim,
name="{}.memory".format(self.name))
self.out_fc = layers.FC(size=1,
bias_attr=False,
name="{}.out".format(self.name))
def __call__(self, query, memory, memory_proj=None):
if self.mode == "dot":
assert query.shape[-1] == memory.shape[-1]
query_expand = layers.sequence_expand_as(x=query, y=memory)
attn = layers.reduce_sum(
layers.elementwise_mul(
x=query_expand, y=memory),
dim=-1,
keep_dim=True)
elif self.mode == "general":
assert self.memory_dim == memory.shape[-1]
query_proj = self.query_fc(query)
query_proj_expand = layers.sequence_expand_as(
x=query_proj, y=memory)
attn = layers.reduce_sum(
layers.elementwise_mul(
x=query_proj_expand, y=memory),
dim=-1,
keep_dim=True)
else:
if memory_proj is None:
memory_proj = self.memory_fc(memory)
query_proj = self.query_fc(query)
query_proj_expand = layers.sequence_expand_as(
x=query_proj, y=memory_proj)
hidden = layers.tanh(query_proj_expand + memory_proj)
attn = self.out_fc(hidden)
weights = layers.sequence_softmax(input=attn, use_cudnn=False)
weights_reshape = layers.reshape(x=weights, shape=[-1])
scaled = layers.elementwise_mul(x=memory, y=weights_reshape, axis=0)
weighted_memory = layers.sequence_pool(input=scaled, pool_type="sum")
return weighted_memory, weights
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import mmpms.layers as layers
from mmpms.modules.attention import Attention
class BaseDecoder(object):
def step(self, input, state):
""" step function """
raise NotImplementedError
def forward(self, input, state):
""" forward function """
drnn = layers.DynamicRNN()
def memory(memory_state):
if isinstance(memory_state, dict):
return {k: memory(v) for k, v in memory_state.items()}
elif isinstance(memory_state, (tuple, list)):
return type(memory_state)(memory(x) for x in memory_state)
else:
return drnn.memory(init=memory_state, need_reorder=True)
def update(pre_state, new_state):
if isinstance(new_state, dict):
for k in new_state.keys():
if k in pre_state:
update(pre_state[k], new_state[k])
elif isinstance(new_state, (tuple, list)):
for i in range(len(new_state)):
update(pre_state[i], new_state[i])
else:
drnn.update_memory(pre_state, new_state)
with drnn.block():
current_input = drnn.step_input(input)
pre_state = memory(state)
output, current_state = self.step(current_input, pre_state)
update(pre_state, current_state)
drnn.output(output)
rnn_output = drnn()
return rnn_output
def __call__(self, input, state):
return self.forward(input, state)
class GRUDecoder(BaseDecoder):
def __init__(self,
hidden_dim,
num_layers=1,
attn_mode="none",
attn_hidden_dim=None,
memory_dim=None,
dropout=0.0,
name=None):
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.attn_mode = None if attn_mode == "none" else attn_mode
self.attn_hidden_dim = attn_hidden_dim or hidden_dim // 2
self.memory_dim = memory_dim or hidden_dim
self.dropout = dropout
self.rnn = layers.StackedGRUCell(
hidden_dim=hidden_dim,
num_layers=num_layers,
dropout=dropout if self.num_layers > 1 else 0.0,
name=name)
if self.attn_mode:
self.attention = Attention(
mode=self.attn_mode,
memory_dim=self.memory_dim,
hidden_dim=self.attn_hidden_dim)
def step(self, input, state):
hidden = state["hidden"]
rnn_input_list = [input]
if self.attn_mode:
memory = state["memory"]
memory_proj = state.get("memory_proj")
query = hidden[-1]
context, _ = self.attention(
query=query, memory=memory, memory_proj=memory_proj)
rnn_input_list.append(context)
rnn_input = layers.concat(rnn_input_list, axis=1)
rnn_output, new_hidden = self.rnn(rnn_input, hidden)
new_state = {k: v for k, v in state.items() if k != "hidden"}
new_state["hidden"] = new_hidden
if self.attn_mode:
output = layers.concat([rnn_output, context], axis=1)
else:
output = rnn_output
return output, new_state
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from __future__ import division
import numpy as np
import paddle.fluid as fluid
import mmpms.layers as layers
class Embedder(layers.Embedding):
def __init__(self,
num_embeddings,
embedding_dim,
is_sparse=False,
is_distributed=False,
padding_idx=None,
param_attr=None,
dtype='float32',
name=None):
super(Embedder, self).__init__(
size=[num_embeddings, embedding_dim],
is_sparse=is_sparse,
is_distributed=is_distributed,
padding_idx=padding_idx,
param_attr=param_attr,
dtype=dtype,
name=name)
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
def from_pretrained(self, embeds, place, scale=0.05):
assert len(embeds) == self.num_embeddings
assert len(embeds[0]) == self.embedding_dim
embeds = np.array(embeds, dtype='float32')
num_known = 0
for i in range(len(embeds)):
if np.all(embeds[i] == 0):
embeds[i] = np.random.uniform(
low=-scale, high=scale, size=self.embedding_dim)
else:
num_known += 1
if self.padding_idx is not None:
embeds[self.padding_idx] = 0
embedding_param = fluid.global_scope().find_var(
self.param_attr.name).get_tensor()
embedding_param.set(embeds, place)
print("{} words have pretrained embeddings ".format(num_known) +
"(coverage: {:.3f})".format(num_known / self.num_embeddings))
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import mmpms.layers as layers
class GRUEncoder(object):
def __init__(self,
hidden_dim,
num_layers=1,
bidirectional=True,
dropout=0.0,
name=None):
num_directions = 2 if bidirectional else 1
assert hidden_dim % num_directions == 0
rnn_hidden_dim = hidden_dim // num_directions
self.hidden_dim = hidden_dim
self.rnn_hidden_dim = rnn_hidden_dim
self.num_layers = num_layers
self.bidirectional = bidirectional
self.dropout = dropout
self.gru = layers.GRU(hidden_dim=rnn_hidden_dim,
num_layers=num_layers,
bidirectional=bidirectional,
dropout=dropout if self.num_layers > 1 else 0.0,
name=name)
def __call__(self, inputs, hidden=None):
outputs, new_hidden = self.gru(inputs, hidden)
return outputs, new_hidden
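# A usage sketch (shapes are hypothetical; layers.GRU is the project's
# own wrapper, so the exact tensor layout follows mmpms.layers):
#
#     encoder = GRUEncoder(hidden_dim=1024, bidirectional=True)
#     outputs, last_hidden = encoder(embedded_inputs)
#
# With bidirectional=True each direction runs at hidden_dim // 2, so the
# concatenated outputs keep a total width of hidden_dim; dropout is only
# applied between layers, hence it is disabled when num_layers == 1.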
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import numpy as np
import paddle.fluid as fluid
import mmpms.layers as layers
def state_assign(new_state, old_state):
if isinstance(new_state, dict):
for k in new_state.keys():
state_assign(new_state[k], old_state[k])
elif isinstance(new_state, (tuple, list)):
assert len(new_state) == len(old_state)
for new_s, old_s in zip(new_state, old_state):
state_assign(new_s, old_s)
else:
layers.assign(new_state, old_state)
def state_sequence_expand(state, y):
if isinstance(state, dict):
return {k: state_sequence_expand(v, y) for k, v in state.items()}
elif isinstance(state, (tuple, list)):
return type(state)(state_sequence_expand(s, y) for s in state)
else:
if state.dtype != y.dtype:
return layers.sequence_expand(state, layers.cast(y, state.dtype))
else:
return layers.sequence_expand(state, y)
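# Both helpers walk arbitrarily nested decoder states (dicts, tuples and
# lists of tensors). A wiring sketch with illustrative names:
#
#     state = {"hidden": h, "context": c}
#     expanded = state_sequence_expand(state, selected_scores)
#     state_assign(expanded, state)  # write the expanded state back
#
# The beam search below relies on this pair to keep the decoder state
# aligned with the beams that survive each step.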
class BeamSearch(object):
def __init__(self,
vocab_size,
beam_size,
start_id,
end_id,
unk_id,
min_length=1,
max_length=30,
length_average=False,
ignore_unk=False,
ignore_repeat=False):
self.vocab_size = vocab_size
self.beam_size = beam_size
self.start_id = start_id
self.end_id = end_id
self.unk_id = unk_id
self.min_length = min_length
self.max_length = max_length
self.length_average = length_average
self.ignore_unk = ignore_unk
self.ignore_repeat = ignore_repeat
def __call__(self, step_fn, state, init_ids):
init_scores = layers.fill_constant_batch_size_like(
input=init_ids, shape=[-1, 1], dtype="float32", value=0)
init_scores = layers.lod_reset(init_scores, init_ids)
unk_scores = np.zeros(self.vocab_size, dtype="float32")
unk_scores[self.unk_id] = -1e9
unk_scores = layers.assign(unk_scores)
end_scores = np.zeros(self.vocab_size, dtype="float32")
end_scores[self.end_id] = -1e9
end_scores = layers.assign(end_scores)
array_len = layers.fill_constant(
shape=[1], dtype="int64", value=self.max_length)
min_array_len = layers.fill_constant(
shape=[1], dtype="int64", value=self.min_length)
counter = layers.zeros(shape=[1], dtype="int64", force_cpu=True)
# ids, scores as memory
ids_array = layers.create_array("int64")
scores_array = layers.create_array("float32")
layers.array_write(init_ids, array=ids_array, i=counter)
layers.array_write(init_scores, array=scores_array, i=counter)
cond = layers.less_than(x=counter, y=array_len)
while_op = layers.While(cond=cond)
with while_op.block():
pre_ids = layers.array_read(array=ids_array, i=counter)
pre_score = layers.array_read(array=scores_array, i=counter)
# use step_fn to update state and get score
score, new_state = step_fn(pre_ids, state)
score = layers.log(score)
if self.ignore_unk:
score = score + unk_scores
if self.ignore_repeat:
repeat_scores = layers.cast(
layers.one_hot(pre_ids, self.vocab_size), "float32") * -1e9
score = score + repeat_scores
min_cond = layers.less_than(x=counter, y=min_array_len)
with fluid.layers.control_flow.Switch() as switch:
with switch.case(min_cond):
layers.assign(score + end_scores, score)
score = layers.lod_reset(x=score, y=pre_score)
topk_scores, topk_indices = layers.topk(score, k=self.beam_size)
if self.length_average:
pre_num = layers.cast(counter, "float32")
cur_num = layers.increment(pre_num, value=1.0, in_place=False)
accu_scores = layers.elementwise_add(
x=layers.elementwise_div(topk_scores, cur_num),
y=layers.elementwise_div(
layers.elementwise_mul(
layers.reshape(
pre_score, shape=[-1]), pre_num),
cur_num),
axis=0)
else:
accu_scores = layers.elementwise_add(
x=topk_scores,
y=layers.reshape(
pre_score, shape=[-1]),
axis=0)
selected_ids, selected_scores, parent_idx = layers.beam_search(
pre_ids=pre_ids,
pre_scores=pre_score,
ids=topk_indices,
scores=accu_scores,
beam_size=self.beam_size,
end_id=self.end_id,
return_parent_idx=True)
layers.increment(x=counter, value=1, in_place=True)
# update the memories
layers.array_write(selected_ids, array=ids_array, i=counter)
layers.array_write(selected_scores, array=scores_array, i=counter)
state_assign(new_state, state)
length_cond = layers.less_than(x=counter, y=array_len)
not_finish_cond = layers.logical_not(
layers.is_empty(x=selected_ids))
layers.logical_and(x=length_cond, y=not_finish_cond, out=cond)
with fluid.layers.control_flow.Switch() as switch:
with switch.case(not_finish_cond):
new_state = state_sequence_expand(new_state,
selected_scores)
state_assign(new_state, state)
prediction_ids, prediction_scores = layers.beam_search_decode(
ids=ids_array,
scores=scores_array,
beam_size=self.beam_size,
end_id=self.end_id)
return prediction_ids, prediction_scores
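# A wiring sketch (names are illustrative): the search drives a model
# step function of the form
#
#     def step_fn(pre_ids, state):
#         ...
#         return probs, new_state  # probabilities over the vocabulary
#
# and is typically invoked as
#
#     generator = BeamSearch(vocab_size, beam_size=10, start_id=bos_id,
#                            end_id=eos_id, unk_id=unk_id)
#     prediction_ids, prediction_scores = generator(step_fn, state, init_ids)
#
# Note that step_fn must return probabilities, not log-probabilities:
# the search applies layers.log itself before accumulating beam scores.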
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import codecs
import json
import argparse
def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Unsupported value encountered.')
class HParams(dict):
def __getattr__(self, name):
if name in self.keys():
return self[name]
else:
for v in self.values():
if isinstance(v, HParams):
if name in v:
return v[name]
            raise AttributeError(
                "'HParams' object has no attribute '{}'".format(name))
def __setattr__(self, name, value):
self[name] = value
def save(self, filename):
with codecs.open(filename, "w", encoding="utf-8") as fp:
json.dump(self, fp, ensure_ascii=False, indent=4, sort_keys=False)
def load(self, filename):
with codecs.open(filename, "r", encoding="utf-8") as fp:
params_dict = json.load(fp)
for k, v in params_dict.items():
# Only load grouping hyperparameters
if isinstance(v, dict):
self[k] = HParams(v)
def parse_args(parser):
parsed = parser.parse_args()
args = HParams()
optional_args = parser._action_groups[1]
for action in optional_args._group_actions[1:]:
arg_name = action.dest
args[arg_name] = getattr(parsed, arg_name)
for group in parser._action_groups[2:]:
group_args = HParams()
for action in group._group_actions:
arg_name = action.dest
group_args[arg_name] = getattr(parsed, arg_name)
if len(group_args) > 0:
args[group.title] = group_args
return args
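if __name__ == "__main__":
    # A quick illustration (hypothetical values): attribute lookup on
    # HParams falls through to nested groups, which is how parse_args
    # exposes argparse argument groups such as "Data" or "Model".
    hp = HParams()
    hp.Data = HParams({"batch_size": 128})
    print(hp.batch_size)  # 128, resolved inside the "Data" group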
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from __future__ import absolute_import
import sys
import logging
def getLogger(log_path, name=None):
logger = logging.getLogger(name)
logger.propagate = False
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter("%(message)s")
sh = logging.StreamHandler(sys.stdout)
sh.setFormatter(formatter)
logger.addHandler(sh)
fh = logging.FileHandler(log_path, mode='w')
fh.setFormatter(formatter)
logger.addHandler(fh)
return logger
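if __name__ == "__main__":
    # Minimal usage sketch (writes to a hypothetical ./demo.log): every
    # message goes both to stdout and to the log file.
    log = getLogger("./demo.log", name="demo")
    log.info("hello from the demo logger")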
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from __future__ import division  # Metric.update relies on float division under Py2
from collections import Counter
from nltk.translate import bleu_score
from nltk.translate.bleu_score import SmoothingFunction
class Metric(object):
def __init__(self):
self.reset()
def update(self, val, num=1):
self.val = float(val)
self.num += num
p = num / self.num
self.avg = self.val * p + self.avg * (1 - p)
def reset(self):
self.val = 0
self.avg = 0
self.num = 0
def __repr__(self):
return "Metric(val={}, avg={}, num={})".format(self.val, self.avg,
self.num)
def state_dict(self):
return {"val": self.val, "avg": self.avg, "num": self.num}
def distinct(seqs):
batch_size = len(seqs)
unigrams_all, bigrams_all = Counter(), Counter()
for seq in seqs:
unigrams = Counter(seq)
bigrams = Counter(zip(seq, seq[1:]))
unigrams_all.update(unigrams)
bigrams_all.update(bigrams)
dist_1 = (len(unigrams_all) + 1e-12) / (sum(unigrams_all.values()) + 1e-5)
dist_2 = (len(bigrams_all) + 1e-12) / (sum(bigrams_all.values()) + 1e-5)
return dist_1, dist_2
def bleu(hyps, refs):
bleu_1 = []
bleu_2 = []
for hyp, ref in zip(hyps, refs):
try:
score = bleu_score.sentence_bleu(
[ref],
hyp,
smoothing_function=SmoothingFunction().method7,
weights=[1, 0, 0, 0])
        except Exception:
            score = 0
bleu_1.append(score)
try:
score = bleu_score.sentence_bleu(
[ref],
hyp,
smoothing_function=SmoothingFunction().method7,
weights=[0.5, 0.5, 0, 0])
        except Exception:
            score = 0
bleu_2.append(score)
bleu_1 = sum(bleu_1) / len(bleu_1)
bleu_2 = sum(bleu_2) / len(bleu_2)
return bleu_1, bleu_2
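if __name__ == "__main__":
    # Illustrative check on toy data: distinct-1/2 count unique n-grams
    # over the whole batch, while bleu averages sentence-level BLEU-1/2
    # of each hypothesis against its single reference.
    hyps = [["i", "am", "fine"], ["i", "am", "i", "am"]]
    refs = [["i", "am", "fine"], ["you", "are", "fine"]]
    print("distinct-1/2:", distinct(hyps))
    print("bleu-1/2:", bleu(hyps, refs))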
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import numpy as np
import paddle.fluid as fluid
def tensor2list(T):
lod = T.lod()
array = np.array(T)
if array.shape[-1] == 1:
array = array.squeeze(-1)
array = array.tolist()
for lod_i in lod[::-1]:
array = [array[start:end] for start, end in zip(lod_i, lod_i[1:])]
return array
def sequence_last(T, place):
lod = T.lod()[-1]
recursive_seq_lens = T.recursive_sequence_lengths()
array = np.array(T)
last_ids = np.array(lod[1:]) - 1
data = array[last_ids]
return fluid.create_lod_tensor(data, recursive_seq_lens[:-1], place)
def sequence_but(T, place, position="first"):
assert position in ["first", "last"]
lod = T.lod()[-1][1:-1]
recursive_seq_lens = T.recursive_sequence_lengths()
array = np.array(T)
if position == "first":
data = np.concatenate([a[1:] for a in np.split(array, lod)])
else:
data = np.concatenate([a[:-1] for a in np.split(array, lod)])
recursive_seq_lens[-1] = [l - 1 for l in recursive_seq_lens[-1]]
return fluid.create_lod_tensor(data, recursive_seq_lens, place)
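if __name__ == "__main__":
    # Illustrative check (CPU only): a LoDTensor holding two sequences
    # of lengths 2 and 3 round-trips into a nested Python list.
    place = fluid.CPUPlace()
    data = np.arange(5).reshape(5, 1).astype("int64")
    T = fluid.create_lod_tensor(data, [[2, 3]], place)
    print(tensor2list(T))  # expected: [[0, 1], [2, 3, 4]]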
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
import os
import argparse
from mmpms.inputters.dataset import PostResponseDataset
parser = argparse.ArgumentParser()
parser.add_argument("--data_dir", type=str, default="./data/")
parser.add_argument(
"--embed_file", type=str, default="./data/glove.840B.300d.txt")
parser.add_argument("--max_vocab_size", type=int, default=30000)
parser.add_argument("--min_len", type=int, default=3)
parser.add_argument("--max_len", type=int, default=30)
args = parser.parse_args()
vocab_file = os.path.join(args.data_dir, "vocab.json")
raw_train_file = os.path.join(args.data_dir, "dial.train")
raw_valid_file = os.path.join(args.data_dir, "dial.valid")
raw_test_file = os.path.join(args.data_dir, "dial.test")
train_file = raw_train_file + ".pkl"
valid_file = raw_valid_file + ".pkl"
test_file = raw_test_file + ".pkl"
dataset = PostResponseDataset(
max_vocab_size=args.max_vocab_size,
min_len=args.min_len,
max_len=args.max_len,
embed_file=args.embed_file)
# Build vocabulary
dataset.build_vocab(raw_train_file)
dataset.save_vocab(vocab_file)
# Build examples
valid_examples = dataset.build_examples(raw_valid_file)
dataset.save_examples(valid_examples, valid_file)
test_examples = dataset.build_examples(raw_test_file)
dataset.save_examples(test_examples, test_file)
train_examples = dataset.build_examples(raw_train_file)
dataset.save_examples(train_examples, train_file)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
################################################################################
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
################################################################################
from __future__ import absolute_import
import os
import argparse
from datetime import datetime
from mmpms.utils.args import str2bool
from mmpms.utils.args import parse_args
from mmpms.utils.logging import getLogger
from mmpms.utils.misc import tensor2list
from mmpms.inputters.dataset import PostResponseDataset
from mmpms.inputters.dataloader import DataLoader
from mmpms.models.mmpms import MMPMS
from mmpms.modules.generator import BeamSearch
from mmpms.engine import Engine
from mmpms.engine import evaluate
from mmpms.engine import infer
parser = argparse.ArgumentParser()
parser.add_argument("--args_file", type=str, default=None)
parser.add_argument("--use_gpu", type=str2bool, default=True)
parser.add_argument("--model_dir", type=str, default=None)
parser.add_argument("--eval", action="store_true")
parser.add_argument("--infer", action="store_true")
# Data
data_arg = parser.add_argument_group("Data")
data_arg.add_argument("--data_dir", type=str, default="./data/")
data_arg.add_argument("--vocab_file", type=str, default=None)
data_arg.add_argument("--train_file", type=str, default=None)
data_arg.add_argument("--valid_file", type=str, default=None)
data_arg.add_argument("--test_file", type=str, default=None)
data_arg.add_argument(
    "--embed_file", type=str, default="./data/glove.840B.300d.txt")
data_arg.add_argument("--max_vocab_size", type=int, default=30000)
data_arg.add_argument("--min_len", type=int, default=3)
data_arg.add_argument("--max_len", type=int, default=30)
# Model
model_arg = parser.add_argument_group("Model")
model_arg.add_argument("--embed_dim", type=int, default=300)
model_arg.add_argument("--hidden_dim", type=int, default=1024)
model_arg.add_argument("--num_mappings", type=int, default=20)
model_arg.add_argument("--tau", type=float, default=0.67)
model_arg.add_argument("--num_layers", type=int, default=1)
model_arg.add_argument("--bidirectional", type=str2bool, default=True)
model_arg.add_argument(
"--attn_mode",
type=str,
default='mlp',
choices=['none', 'mlp', 'dot', 'general'])
model_arg.add_argument(
"--use_pretrained_embedding", type=str2bool, default=True)
model_arg.add_argument("--embed_init_scale", type=float, default=0.03)
model_arg.add_argument("--dropout", type=float, default=0.3)
# Training
train_arg = parser.add_argument_group("Train")
train_arg.add_argument("--save_dir", type=str, default="./output/")
train_arg.add_argument("--num_epochs", type=int, default=10)
train_arg.add_argument("--shuffle", type=str2bool, default=True)
train_arg.add_argument("--log_steps", type=int, default=100)
train_arg.add_argument("--valid_steps", type=int, default=500)
train_arg.add_argument("--batch_size", type=int, default=128)
# Optimization
optim_arg = parser.add_argument_group("Optim")
optim_arg.add_argument("--optimizer", type=str, default="Adam")
optim_arg.add_argument("--lr", type=float, default=0.0002)
optim_arg.add_argument("--grad_clip", type=float, default=5.0)
# Inference
infer_arg = parser.add_argument_group("Inference")
infer_arg.add_argument("--beam_size", type=int, default=10)
infer_arg.add_argument("--min_infer_len", type=int, default=3)
infer_arg.add_argument("--max_infer_len", type=int, default=30)
infer_arg.add_argument("--length_average", type=str2bool, default=False)
infer_arg.add_argument("--ignore_unk", type=str2bool, default=True)
infer_arg.add_argument("--ignore_repeat", type=str2bool, default=True)
infer_arg.add_argument("--infer_batch_size", type=int, default=64)
infer_arg.add_argument("--result_file", type=str, default="./infer.result")
def main():
args = parse_args(parser)
if args.args_file:
args.load(args.args_file)
print("Loaded args from '{}'".format(args.args_file))
args.Data.vocab_file = args.Data.vocab_file or os.path.join(
args.Data.data_dir, "vocab.json")
args.Data.train_file = args.Data.train_file or os.path.join(
args.Data.data_dir, "dial.train.pkl")
args.Data.valid_file = args.Data.valid_file or os.path.join(
args.Data.data_dir, "dial.valid.pkl")
args.Data.test_file = args.Data.test_file or os.path.join(
args.Data.data_dir, "dial.test.pkl")
print("Args:")
print(args)
print()
# Dataset Definition
dataset = PostResponseDataset(
max_vocab_size=args.max_vocab_size,
min_len=args.min_len,
max_len=args.max_len,
embed_file=args.embed_file)
dataset.load_vocab(args.vocab_file)
# Generator Definition
generator = BeamSearch(
vocab_size=dataset.vocab.size(),
beam_size=args.beam_size,
start_id=dataset.vocab.bos_id,
end_id=dataset.vocab.eos_id,
unk_id=dataset.vocab.unk_id,
min_length=args.min_infer_len,
max_length=args.max_infer_len,
length_average=args.length_average,
ignore_unk=args.ignore_unk,
ignore_repeat=args.ignore_repeat)
# Model Definition
model = MMPMS(
vocab=dataset.vocab,
generator=generator,
hparams=args.Model,
optim_hparams=args.Optim,
use_gpu=args.use_gpu)
infer_parse_dict = {
"post": lambda T: dataset.denumericalize(tensor2list(T)),
"response": lambda T: dataset.denumericalize(tensor2list(T)),
"preds": lambda T: dataset.denumericalize(tensor2list(T)),
}
if args.infer:
if args.model_dir is not None:
model.load(args.model_dir)
print("Loaded model checkpoint from '{}'".format(args.model_dir))
infer_data = dataset.load_examples(args.test_file)
infer_loader = DataLoader(
data=infer_data,
batch_size=args.infer_batch_size,
shuffle=False,
use_gpu=args.use_gpu)
print("Inference starts ...")
infer_results = infer(
model, infer_loader, infer_parse_dict, save_file=args.result_file)
elif args.eval:
if args.model_dir is not None:
model.load(args.model_dir)
print("Loaded model checkpoint from '{}'".format(args.model_dir))
eval_data = dataset.load_examples(args.test_file)
eval_loader = DataLoader(
data=eval_data,
batch_size=args.batch_size,
shuffle=False,
use_gpu=args.use_gpu)
print("Evaluation starts ...")
eval_metrics_tracker = evaluate(model, eval_loader)
print(" ".join("{}-{}".format(name.upper(), value.avg)
for name, value in eval_metrics_tracker.items()))
else:
valid_data = dataset.load_examples(args.valid_file)
valid_loader = DataLoader(
data=valid_data,
batch_size=args.batch_size,
shuffle=False,
use_gpu=args.use_gpu)
train_data = dataset.load_examples(args.train_file)
train_loader = DataLoader(
data=train_data,
batch_size=args.batch_size,
shuffle=args.shuffle,
use_gpu=args.use_gpu)
# Save Directory Definition
date_str, time_str = datetime.now().strftime("%Y%m%d-%H%M%S").split("-")
result_str = "{}-{}".format(model.__class__.__name__, time_str)
args.save_dir = os.path.join(args.save_dir, date_str, result_str)
if not os.path.exists(args.save_dir):
os.makedirs(args.save_dir)
# Logger Definition
logger = getLogger(
os.path.join(args.save_dir, "train.log"), name="mmpms")
# Save args
args_file = os.path.join(args.save_dir, "args.json")
args.save(args_file)
logger.info("Saved args to '{}'".format(args_file))
# Executor Definition
exe = Engine(
model=model,
save_dir=args.save_dir,
log_steps=args.log_steps,
valid_steps=args.valid_steps,
logger=logger)
if args.model_dir is not None:
exe.load(args.model_dir)
# Train
logger.info("Training starts ...")
exe.evaluate(valid_loader, is_save=False)
for epoch in range(args.num_epochs):
exe.train_epoch(train_iter=train_loader, valid_iter=valid_loader)
logger.info("Training done!")
if __name__ == '__main__':
try:
main()
except KeyboardInterrupt:
print("\nExited from the program ealier!")