# beam_search.py (committed by caoying03)
import os
import math
import numpy as np

import paddle.v2 as paddle

from utils import logger, load_reverse_dict

# Public API of this module: the only name exported by a star-import.
__all__ = ["BeamSearch"]


class BeamSearch(object):
    """
    Generating sequence by beam search.

    The search state (``candidate_paths`` and ``final_paths``) is stored on
    the instance and rebuilt by every call to ``gen_a_sentence``.

    NOTE: this class only implements generating one sentence at a time.
    """

    def __init__(self, inferer, word_dict_file, beam_size=1, max_gen_len=100):
        """
        constructor method.

        :param inferer: object of paddle.Inference that represents the entire
            network to forward compute the test batch
        :type inferer: paddle.Inference
        :param word_dict_file: path of word dictionary file
        :type word_dict_file: str
        :param beam_size: expansion width in each iteration
        :type beam_size: int
        :param max_gen_len: the maximum number of iterations
        :type max_gen_len: int
        :raises ValueError: if the dictionary does not contain the ending
            mark "<e>".
        """
        self.inferer = inferer
        self.beam_size = beam_size
        self.max_gen_len = max_gen_len
        self.ids_2_word = load_reverse_dict(word_dict_file)
        logger.info("dictionay len = %d" % (len(self.ids_2_word)))

        self.eos_id = self._find_word_id("<e>")
        if self.eos_id is None:
            msg = ("the word dictionay must contain an ending mark "
                   "in the text generation task.")
            # logger.fatal only logs; raise so construction really fails
            # instead of leaving eos_id unusable for later calls.
            logger.fatal(msg)
            raise ValueError(msg)
        # "<unk>" is optional: when it is absent nothing is filtered out of
        # the top-k words (see _top_k).
        self.unk_id = self._find_word_id("<unk>")

        self.candidate_paths = []
        self.final_paths = []

    def _find_word_id(self, word):
        """
        look up the id of a word in the reverse dictionary.

        :param word: the word to search for
        :type word: str
        :return: the id mapped to ``word``, or None if the word is absent
        """
        for idx, dict_word in self.ids_2_word.items():
            if dict_word == word:
                return idx
        return None

    def _top_k(self, softmax_out, k):
        """
        get indices of the words with k highest probabilities.

        NOTE: <unk> (when it is in the dictionary) is excluded if it is among
        the top k words, then the word with (k + 1)th highest probability is
        returned in its place.

        :param softmax_out: probability over the dictionary
        :type softmax_out: ndarray
        :param k: number of word indices to return
        :type k: int
        :return: indices of k words with highest probabilities, in descending
            order of probability.
        :rtype: ndarray
        """
        ids = softmax_out.argsort()[::-1]
        if self.unk_id is not None:
            ids = ids[ids != self.unk_id]
        return ids[:k]

    def _forward_batch(self, batch):
        """
        forward a test batch.

        :params batch: the input data batch, one ``[ids]`` entry per
            candidate path
        :type batch: list
        :return: probabilities of the predicted word
        :rtype: ndarray
        """
        return self.inferer.infer(input=batch, field=["value"])

    def _beam_expand(self, next_word_prob):
        """
        In every iteration step, the model predicts the possible next words.
        For each candidate sentence, each of its top k words is appended to
        the end of the sentence to form a new generated sentence. Sentences
        ending with the ending mark are moved into ``self.final_paths``.

        :param next_word_prob: probabilities of the next words, one row per
            candidate path
        :type next_word_prob: ndarray
        :return: the expanded new (unfinished) sentences.
        :rtype: list
        """
        assert len(next_word_prob) == len(self.candidate_paths), (
            "Wrong forward computing results!")
        new_paths = []
        for old_path, word_prob in zip(self.candidate_paths, next_word_prob):
            for w in self._top_k(word_prob, self.beam_size):
                prob = float(word_prob[w])
                # a zero probability (e.g. softmax underflow) would make
                # math.log raise; rank such paths last instead.
                step_log_prob = math.log(prob) if prob > 0 else float("-inf")
                path = {
                    "log_prob": old_path["log_prob"] + step_log_prob,
                    "ids": old_path["ids"] + [w],
                }
                if w == self.eos_id:
                    self.final_paths.append(path)
                else:
                    new_paths.append(path)
        return new_paths

    def _beam_shrink(self, new_paths):
        """
        keep the top beam_size generated sequences with the highest
        probabilities at the end of every generation iteration.

        :param new_paths: all possible generated (unfinished) sentences
        :type new_paths: list
        :return: a state flag to indicate whether to stop beam search
        :rtype: bool
        """
        if not new_paths:
            # every expansion ended with the ending mark: nothing to extend.
            self.candidate_paths = []
            return True

        if len(self.final_paths) >= self.beam_size:
            max_candidate_log_prob = max(p["log_prob"] for p in new_paths)
            min_complete_path_log_prob = min(
                p["log_prob"] for p in self.final_paths)
            # extending a candidate can only lower its log probability, so no
            # candidate can beat any of the completed paths any more.
            if min_complete_path_log_prob >= max_candidate_log_prob:
                return True

        new_paths.sort(key=lambda x: x["log_prob"], reverse=True)
        self.candidate_paths = new_paths[:self.beam_size]
        return False

    def gen_a_sentence(self, input_sentence):
        """
        generating sequence for a given input

        :param input_sentence: word ids of one input sentence
        :type input_sentence: list
        :return: the generated word sequences, each formatted as
            "<log_prob>\\t<space separated words>", best first
        :rtype: list
        """
        # start from a clean state even if a previous call was interrupted.
        self.candidate_paths = [{"log_prob": 0., "ids": list(input_sentence)}]
        self.final_paths = []
        input_len = len(input_sentence)

        for _ in range(self.max_gen_len):
            next_word_prob = self._forward_batch(
                [[x["ids"]] for x in self.candidate_paths])
            new_paths = self._beam_expand(next_word_prob)

            if new_paths:
                # drop completed paths that are worse than every candidate.
                min_candidate_log_prob = min(p["log_prob"] for p in new_paths)
                self.final_paths = [
                    p for p in self.final_paths
                    if p["log_prob"] >= min_candidate_log_prob
                ]

            if self._beam_shrink(new_paths):
                self.candidate_paths = []
                break

        best_paths = sorted(
            self.final_paths + self.candidate_paths,
            key=lambda x: x["log_prob"],
            reverse=True)[:self.beam_size]
        self.final_paths = []

        def _to_str(x):
            text = " ".join(self.ids_2_word[idx]
                            for idx in x["ids"][input_len:])
            return "%.4f\t%s" % (x["log_prob"], text)

        return [_to_str(x) for x in best_paths]