#coding=utf-8

import sys
import time
import math
import numpy as np

import reader


class BeamSearch(object):
    """
    Generate sequences by beam search.

    For every input sample the search keeps up to ``beam_size`` partial
    hypotheses ("paths"). Each path is a dict with two keys: ``'seq'`` (list
    of generated word ids, without the start token) and ``'score'`` (sum of
    the log-probabilities of the generated words).

    State kept on the instance while searching:
        candidate_path: per-sample list of unfinished paths.
        complete_path: per-sample list of paths that were ended by ``<e>``.
        final_path: result of the most recent ``search_one_batch`` call made
            from ``search``.
    """

    def __init__(self,
                 inferer,
                 trg_dict,
                 pos_size,
                 padding_num,
                 batch_size=1,
                 beam_size=1,
                 max_len=100):
        """
        Args:
            inferer: object exposing ``infer(input, field='value')`` whose
                result can be indexed as ``result[row_list, :]``.
            trg_dict: target dictionary; must contain the keys ``'<s>'`` and
                ``'<e>'`` (presumably maps word -> id; verify against
                ``reader``).
            pos_size: also used as the id of the padding position.
            padding_num: number of padding slots; the model window length is
                ``padding_num + 1``.
            batch_size: number of samples searched together.
            beam_size: number of hypotheses kept per sample.
            max_len: maximum number of generation steps.
        """
        self.inferer = inferer
        self.trg_dict = trg_dict
        self.reverse_trg_dict = reader.get_reverse_dict(trg_dict)
        # The word padding id is the first id past the dictionary.
        self.word_padding = len(trg_dict)
        self.pos_size = pos_size
        self.pos_padding = pos_size
        self.padding_num = padding_num
        self.win_len = padding_num + 1
        self.max_len = max_len
        self.batch_size = batch_size
        self.beam_size = beam_size

    def get_beam_input(self, batch, sample_list):
        """
        Get input for generation at the current iteration.

        For every unfinished path of every sample in ``sample_list`` one
        input row is built: the sample's own fields followed by the last
        ``win_len`` target words and their positions. Short paths are
        left-padded and prefixed with ``<s>`` (position 0).

        Args:
            batch: list of samples; each sample is a list of input fields.
            sample_list: ids (indices into ``batch``) of unfinished samples.

        Returns:
            A list with one entry per candidate path, ordered by sample id
            and then by path index.
        """
        beam_input = []

        for sample_id in sample_list:
            for path in self.candidate_path[sample_id]:
                seq = path['seq']
                seq_len = len(seq)

                if seq_len < self.win_len:
                    pad_len = self.win_len - seq_len - 1
                    cur_trg = ([self.word_padding] * pad_len +
                               [self.trg_dict['<s>']] + seq)
                    # list(...) keeps the concatenation valid on Python 3,
                    # where range() is not a list.
                    cur_trg_pos = ([self.pos_padding] * pad_len +
                                   list(range(seq_len + 1)))
                else:
                    cur_trg = seq[-self.win_len:]
                    cur_trg_pos = list(
                        range(seq_len + 1 - self.win_len, seq_len + 1))

                beam_input.append(batch[sample_id] + [cur_trg] + [cur_trg_pos])

        return beam_input

    def get_prob(self, beam_input):
        """
        Get the probabilities of all possible tokens.

        The inferer output is read at rows ``0, win_len, 2 * win_len, ...``,
        i.e. one row per entry of ``beam_input``.
        """
        row_list = [j * self.win_len for j in range(len(beam_input))]
        prob = self.inferer.infer(beam_input, field='value')[row_list, :]
        return prob

    def _top_k(self, prob, k):
        """
        Get indices of the words with the k highest probabilities,
        best first.
        """
        # Reverse first, then slice: unlike ``[-k:]`` this also yields an
        # empty result for k == 0.
        return prob.argsort()[::-1][:k]

    def beam_expand(self, prob, sample_list):
        """
        In every iteration step, the model predicts the possible next words.
        For each candidate path, the top beam_size words are selected as
        candidates.

        Args:
            prob: 2-D numpy array, one row per candidate path in the order
                produced by ``get_beam_input``.
            sample_list: ids of unfinished samples.

        Returns:
            A list indexed by sample id; each entry is a list of dicts with
            keys ``'word'``, ``'score'`` (path score plus log-probability of
            the word) and ``'seq_id'`` (index of the extended path).
        """
        top_words = np.apply_along_axis(self._top_k, 1, prob, self.beam_size)

        # Independent lists per sample (``[[]] * n`` would alias one list).
        candidate_words = [[] for _ in self.candidate_path]
        idx = 0

        for sample_id in sample_list:
            for seq_id, path in enumerate(self.candidate_path[sample_id]):
                for w in top_words[idx, :]:
                    word_prob = prob[idx, w]
                    # math.log raises ValueError for 0; a zero-probability
                    # word simply gets the worst possible score.
                    if word_prob > 0:
                        log_prob = math.log(word_prob)
                    else:
                        log_prob = float('-inf')
                    candidate_words[sample_id].append({
                        'word': w,
                        'score': path['score'] + log_prob,
                        'seq_id': seq_id
                    })
                idx += 1

        return candidate_words

    def beam_shrink(self, candidate_words, sample_list):
        """
        Pruning process of the beam search. During the process, the beam_size
        most probable sequences are selected for the beam in the next
        generation.

        Candidates ending in ``<e>`` are moved to ``self.complete_path``
        instead of the new beam. A sample that already has at least
        beam_size complete paths, all scoring higher than its best new
        candidate, gets an empty new beam.

        Args:
            candidate_words: output of ``beam_expand``.
            sample_list: ids of unfinished samples.

        Returns:
            A list indexed by sample id holding the new unfinished paths.
        """
        new_path = [[] for _ in self.candidate_path]
        end_id = self.trg_dict['<e>']

        for sample_id in sample_list:
            beam_words = sorted(
                candidate_words[sample_id],
                key=lambda x: x['score'],
                reverse=True)[:self.beam_size]

            complete_seq_min_score = None
            complete_path_num = len(self.complete_path[sample_id])

            if complete_path_num > 0:
                complete_seq_min_score = min(
                    p['score'] for p in self.complete_path[sample_id])
                if complete_path_num >= self.beam_size:
                    beam_words_max_score = beam_words[0]['score']
                    # Nothing in the beam can beat the finished sequences:
                    # leave the new beam empty for this sample.
                    if beam_words_max_score < complete_seq_min_score:
                        continue

            for w in beam_words:
                prev_seq = self.candidate_path[sample_id][w['seq_id']]['seq']

                if w['word'] == end_id:
                    # complete_path_num is the count taken before this loop.
                    if (complete_path_num < self.beam_size or
                            complete_seq_min_score <= w['score']):
                        self.complete_path[sample_id].append({
                            'seq': prev_seq,
                            'score': w['score']
                        })

                        if (complete_seq_min_score is None or
                                complete_seq_min_score > w['score']):
                            complete_seq_min_score = w['score']
                else:
                    new_path[sample_id].append({
                        'seq': prev_seq + [w['word']],
                        'score': w['score']
                    })

        return new_path

    def search_one_batch(self, batch):
        """
        Perform beam search on one mini-batch.

        Args:
            batch: list of samples; each sample is a list of input fields.

        Returns:
            A list with one entry per sample: its complete and unfinished
            paths sorted by score (best first), truncated to beam_size.
        """
        real_size = len(batch)
        # Build a separate container per sample so that no state is shared
        # between samples.
        self.candidate_path = [[{
            'seq': [],
            'score': 0.
        }] for _ in range(real_size)]
        self.complete_path = [[] for _ in range(real_size)]

        if real_size == 0:
            # Nothing to search; avoid calling the inferer with no input.
            return []

        sample_list = list(range(real_size))

        for _ in range(self.max_len):
            beam_input = self.get_beam_input(batch, sample_list)
            prob = self.get_prob(beam_input)

            candidate_words = self.beam_expand(prob, sample_list)
            new_path = self.beam_shrink(candidate_words, sample_list)
            self.candidate_path = new_path
            # Samples whose beam became empty are finished.
            sample_list = [
                sample_id for sample_id in sample_list
                if len(new_path[sample_id]) > 0
            ]

            if len(sample_list) == 0:
                break

        final_path = []
        for i in range(real_size):
            top_path = sorted(
                self.complete_path[i] + self.candidate_path[i],
                key=lambda x: x['score'],
                reverse=True)[:self.beam_size]
            final_path.append(top_path)
        return final_path

    def search(self, infer_data):
        """
        Perform beam search on all data.

        The data is processed in slices of ``batch_size`` samples; for every
        sample the best sequence is printed to stdout as space-separated
        words (an empty line if no hypothesis is available).
        """

        def _to_sentence(seq):
            raw_sentence = [self.reverse_trg_dict[word_id] for word_id in seq]
            sentence = " ".join(raw_sentence)
            return sentence

        for pos in range(0, len(infer_data), self.batch_size):
            batch = infer_data[pos:min(pos + self.batch_size, len(infer_data))]
            self.final_path = self.search_one_batch(batch)
            for top_path in self.final_path:
                best_seq = top_path[0]['seq'] if top_path else []
                # Single-argument print(...) behaves the same on Python 2
                # and Python 3.
                print(_to_sentence(best_seq))
            sys.stdout.flush()