# beam_search_api.py
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.framework import Variable
# (repository blame-view residue: commit by yangyaming)
import contextlib
# (repository blame-view residue: line-number gutter)
from paddle.fluid.layer_helper import LayerHelper, unique_name
import paddle.fluid.core as core
# (repository blame-view residue: commit by yangyaming)


class DecoderType:
    """Integer tags naming the kind of decoder an object represents.

    Decoder objects in this file expose one of these values through their
    ``type`` property, and ``StateCell`` compares against them to decide
    which state holder to create.
    """

    # Declared together via tuple unpacking; the numeric values are unchanged.
    TRAINING, BEAM_SEARCH = 1, 2


class InitState(object):
    """Describe how one named decoder state is initialised.

    Args:
        init: Object used as the initial state. It is returned unchanged by
            the ``value`` property.
        shape: Stored as ``_shape``; not read by any method of this class.
        value: Stored as ``_value``; not read by any method of this class.
            Note that the ``value`` property returns ``init``, not this
            argument.
        need_reorder: Flag returned by the ``need_reorder`` property.
        dtype: Stored as ``_dtype``; not read by any method of this class.
    """

    def __init__(self,
                 init=None,
                 shape=None,
                 value=0.0,
                 need_reorder=False,
                 dtype='float32'):
        self._init = init
        self._shape = shape
        self._value = value
        self._need_reorder = need_reorder
        self._dtype = dtype

    @property
    def value(self):
        """Return the ``init`` object given to the constructor."""
        # NOTE(review): an earlier comment here said "may create a
        # LoDTensor", but nothing is created: ``_init`` is handed back as-is
        # (``None`` when no ``init`` was supplied). Confirm whether
        # shape/value/dtype were meant to build a default tensor.
        return self._init

    @property
    def need_reorder(self):
        """Return the ``need_reorder`` flag given to the constructor."""
        return self._need_reorder

# (repository blame-view residue: commit by yangyaming)

class MemoryState(object):
    """State holder backed by a memory obtained from an RNN-like object.

    Args:
        state_name: Name of the state; stored on the instance only.
        rnn_obj: Object providing ``memory(init=..., need_reorder=...)`` and
            ``update_memory(mem, value)``.
        init_state: Object exposing ``value`` and ``need_reorder``
            attributes, which are forwarded to ``rnn_obj.memory``.
    """

    def __init__(self, state_name, rnn_obj, init_state):
        self._state_name = state_name  # each is a rnn.memory
        self._rnn_obj = rnn_obj
        # The memory is created once, at construction time.
        self._state_mem = self._rnn_obj.memory(
            init=init_state.value, need_reorder=init_state.need_reorder)

    def get_state(self):
        """Return the memory object created in the constructor."""
        return self._state_mem

    def update_state(self, state):
        """Ask the RNN object to update the held memory with ``state``."""
        self._rnn_obj.update_memory(self._state_mem, state)


class ArrayState(object):
    """State holder backed by a tensor array plus an index counter.

    The constructor creates, in ``block``, a ``LOD_TENSOR_ARRAY`` variable
    and an ``int64`` counter variable, fills the counter with 0, and writes
    the initial state into the array at the counter's index.

    Args:
        state_name: Name of the state; stored on the instance only.
        block: Program block in which the variables and ops are created.
        init_state: Object whose ``value`` attribute is the initial state
            variable (its ``dtype`` is used for the array).
    """

    def __init__(self, state_name, block, init_state):
        self._state_name = state_name
        self._block = block

        self._state_array = self._block.create_var(
            name=unique_name('array_state_array'),
            type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
            dtype=init_state.value.dtype)

        self._counter = self._block.create_var(
            name=unique_name('array_state_counter'),
            type=core.VarDesc.VarType.LOD_TENSOR,
            dtype='int64')

        # initialize counter
        self._block.append_op(
            type='fill_constant',
            inputs={},
            outputs={'Out': [self._counter]},
            attrs={
                'shape': [1],
                'dtype': self._counter.dtype,
                'value': 0.0,
                'force_cpu': True
            })

        self._counter.stop_gradient = True

        # write initial state
        self._block.append_op(
            type='write_to_array',
            inputs={'X': init_state.value,
                    'I': self._counter},
            outputs={'Out': self._state_array})

    def get_state(self):
        """Read and return the array entry at the current counter index."""
        return layers.array_read(array=self._state_array, i=self._counter)

    def update_state(self, state):
        """Advance the counter by one, then write ``state`` at that index."""
        # Order matters: the counter is incremented in place first so the
        # new state lands in the next slot rather than overwriting the
        # current one.
        layers.increment(x=self._counter, value=1, in_place=True)
        layers.array_write(state, array=self._state_array, i=self._counter)


class StateCell(object):
    """Hold named states and inputs, and bind them to one decoder at a time.

    A decoder registers itself with ``enter_decoder`` and releases the cell
    with ``leave_decoder``. The first call to ``get_state``,
    ``compute_state`` or ``update_states`` after entering triggers a lazy
    switch (``_switch_decoder``) that wraps each state in a holder matching
    the decoder type and replaces the current state with what the holder
    returns.

    Args:
        cell_size: Accepted for interface compatibility; not used by this
            class.
        inputs: Dict of input name to value, used as a placeholder table
            that ``compute_state`` fills in.
        states: Dict of state name to ``InitState``.
        name: Optional name forwarded to ``LayerHelper``.

    Raises:
        ValueError: If any entry of ``states`` is not an ``InitState``.
    """

    def __init__(self, cell_size, inputs, states, name=None):
        self._helper = LayerHelper('state_cell', name=name)
        self._cur_states = {}
        self._state_names = []
        for state_name, state in states.items():
            if not isinstance(state, InitState):
                raise ValueError('state must be an InitState object.')
            self._cur_states[state_name] = state
            self._state_names.append(state_name)
        self._inputs = inputs  # inputs is place holder here
        self._cur_decoder_obj = None
        self._in_decoder = False
        # state name -> {id(decoder): state holder}
        self._states_holder = {}
        self._switched_decoder = False

    def enter_decoder(self, decoder_obj):
        """Bind the cell to ``decoder_obj``.

        Raises:
            ValueError: If the cell is already bound to a decoder.
        """
        if self._in_decoder or self._cur_decoder_obj is not None:
            raise ValueError('StateCell has already entered a decoder.')
        self._in_decoder = True
        self._cur_decoder_obj = decoder_obj
        self._switched_decoder = False

    def _switch_decoder(self):  # lazy switch
        """Create state holders for the current decoder and read states back.

        Raises:
            ValueError: If no decoder has been entered, if the switch was
                already done, if a state is not an ``InitState`` when its
                holder is first created, or if the decoder type is unknown.
        """
        if not self._in_decoder:
            raise ValueError('StateCell must be enter a decoder.')

        if self._switched_decoder:
            raise ValueError('StateCell already done switching.')

        decoder_key = id(self._cur_decoder_obj)

        for state_name in self._state_names:
            # NOTE(review): holders are created only the first time a state
            # name is seen. If a second decoder object enters later, no
            # holder is added under its id and the read-back below raises
            # KeyError -- confirm whether reuse across decoders is intended.
            if state_name not in self._states_holder:
                state = self._cur_states[state_name]

                if not isinstance(state, InitState):
                    raise ValueError('Current type of state is %s, should be '
                                     'an InitState object.' % type(state))

                self._states_holder[state_name] = {}

                if self._cur_decoder_obj.type == DecoderType.TRAINING:
                    self._states_holder[state_name][decoder_key] = \
                        MemoryState(state_name,
                                    self._cur_decoder_obj.dynamic_rnn,
                                    state)
                elif self._cur_decoder_obj.type == DecoderType.BEAM_SEARCH:
                    self._states_holder[state_name][decoder_key] = \
                        ArrayState(state_name,
                                   self._cur_decoder_obj.parent_block(),
                                   state)
                else:
                    raise ValueError('Unknown decoder type, only support '
                                     '[TRAINING, BEAM_SEARCH]')

            # Read back, since current state should be LoDTensor
            self._cur_states[state_name] = \
                self._states_holder[state_name][decoder_key].get_state()

        self._switched_decoder = True

    def get_state(self, state_name):
        """Return the current value of ``state_name``.

        Performs the lazy decoder switch first if it is still pending.

        Raises:
            ValueError: If ``state_name`` is not a known state.
        """
        if self._in_decoder and not self._switched_decoder:
            self._switch_decoder()

        if state_name not in self._cur_states:
            raise ValueError(
                'Unknown state %s. Please make sure _switch_decoder() '
                'invoked.' % state_name)

        return self._cur_states[state_name]

    def get_input(self, input_name):
        """Return the input registered under ``input_name``.

        Raises:
            ValueError: If the name is unknown or its value is ``None``.
        """
        if input_name not in self._inputs or self._inputs[input_name] is None:
            raise ValueError('Invalid input %s.' % input_name)
        return self._inputs[input_name]

    def set_state(self, state_name, state_value):
        """Set the current value of ``state_name`` to ``state_value``."""
        self._cur_states[state_name] = state_value

    def register_updater(self, state_updater):
        """Store the callable that ``compute_state`` invokes with this cell."""
        self._state_updater = state_updater

    def compute_state(self, inputs):
        """Fill in the given inputs and run the registered updater.

        Args:
            inputs: Dict of input name to value; every name must already be
                a key of the placeholder dict given to the constructor.

        Raises:
            ValueError: If an input name is not in the placeholder dict.
        """
        if self._in_decoder and not self._switched_decoder:
            self._switch_decoder()

        for input_name, input_value in inputs.items():
            if input_name not in self._inputs:
                raise ValueError('Unknown input %s. '
                                 'Please make sure %s in input '
                                 'place holder.' % (input_name, input_name))
            self._inputs[input_name] = input_value

        self._state_updater(self)

    def update_states(self):
        """Push every current state into its holder for the current decoder.

        Raises:
            ValueError: If a state has no holder for the current decoder.
        """
        if self._in_decoder and not self._switched_decoder:
            # Fix: this previously called ``self._switched_decoder()``, i.e.
            # it tried to call the boolean flag and raised TypeError.
            self._switch_decoder()

        decoder_key = id(self._cur_decoder_obj)
        for state_name, decoder_state in self._states_holder.items():
            if decoder_key not in decoder_state:
                raise ValueError('Unknown decoder object, please make sure '
                                 'switch_decoder been invoked.')
            decoder_state[decoder_key].update_state(
                self._cur_states[state_name])

    def leave_decoder(self, decoder_obj):
        """Release the cell from ``decoder_obj``.

        Raises:
            ValueError: If the cell is not in a decoder, or is bound to a
                different decoder object.
        """
        if not self._in_decoder:
            raise ValueError('StateCell not in decoder, '
                             'invlid leaving operation.')

        if self._cur_decoder_obj != decoder_obj:
            raise ValueError('Inconsist decoder object in StateCell.')

        self._in_decoder = False
        self._cur_decoder_obj = None
        # Left as True; ``enter_decoder`` resets it to False on next entry.
        self._switched_decoder = True
# (repository blame-view residue: commit by yangyaming)


class TrainingDecoder(object):
    """Decoder that runs a ``StateCell`` inside a ``layers.DynamicRNN`` block.

    The decoder moves through three statuses: before, inside and after its
    ``block()`` context. Step-level methods may only be called inside the
    block, and the decoder output may only be fetched after it.

    Args:
        state_cell: Cell object; ``enter_decoder(self)`` is called on it at
            construction and ``leave_decoder(self)`` when the block exits.
        name: Optional name forwarded to ``LayerHelper``.
    """

    BEFORE_DECODER = 0
    IN_DECODER = 1
    AFTER_DECODER = 2

    def __init__(self, state_cell, name=None):
        self._helper = LayerHelper('training_decoder', name=name)
        self._status = TrainingDecoder.BEFORE_DECODER
        self._dynamic_rnn = layers.DynamicRNN()
        self._type = DecoderType.TRAINING
        self._state_cell = state_cell
        self._state_cell.enter_decoder(self)

    @contextlib.contextmanager
    def block(self):
        """Context manager wrapping the dynamic RNN's block.

        Raises:
            ValueError: If the block has already been entered once.
        """
        if self._status != TrainingDecoder.BEFORE_DECODER:
            raise ValueError('decoder.block() can only be invoked once')
        self._status = TrainingDecoder.IN_DECODER
        with self._dynamic_rnn.block():
            yield
        # Reached only when the body finished without raising.
        self._status = TrainingDecoder.AFTER_DECODER
        self._state_cell.leave_decoder(self)

    @property
    def state_cell(self):
        """The bound state cell; only accessible inside the block."""
        self._assert_in_decoder_block('state_cell')
        return self._state_cell

    @property
    def dynamic_rnn(self):
        """The underlying dynamic RNN object."""
        return self._dynamic_rnn

    @property
    def type(self):
        """The decoder type tag (``DecoderType.TRAINING``)."""
        return self._type

    def step_input(self, x):
        """Forward ``x`` to the dynamic RNN's ``step_input``; block only."""
        self._assert_in_decoder_block('step_input')
        return self._dynamic_rnn.step_input(x)

    def static_input(self, x):
        """Forward ``x`` to the dynamic RNN's ``static_input``; block only."""
        self._assert_in_decoder_block('static_input')
        return self._dynamic_rnn.static_input(x)

    def __call__(self, *args, **kwargs):
        """Return the dynamic RNN's output by calling it with the arguments.

        Raises:
            ValueError: If called before the block has completed.
        """
        if self._status != TrainingDecoder.AFTER_DECODER:
            raise ValueError('Output of training decoder can only be visited '
                             'outside the block.')
        return self._dynamic_rnn(*args, **kwargs)

    def output(self, *outputs):
        """Forward ``outputs`` to the dynamic RNN's ``output``; block only."""
        self._assert_in_decoder_block('output')
        self._dynamic_rnn.output(*outputs)

    def _assert_in_decoder_block(self, method):
        """Raise ValueError unless the decoder is currently inside its block."""
        if self._status != TrainingDecoder.IN_DECODER:
            raise ValueError('%s should be invoked inside block of '
                             'TrainingDecoder object.' % method)
# (repository blame-view residue: commit by yangyaming)


class BeamSearchDecoder(object):
    """Decoder that runs a ``StateCell`` inside a ``layers.While`` loop.

    The loop condition is ``counter < max_len``. On each pass through
    ``block()`` the counter is incremented, every (value, array) pair
    registered through ``update_array`` is written at the new counter
    index, and the condition is re-evaluated.

    Args:
        state_cell: Cell object; ``enter_decoder(self)`` is called on it at
            construction and ``leave_decoder(self)`` when the block exits.
        max_len: Value used to fill the ``int64`` loop-bound constant.
        name: Optional name forwarded to ``LayerHelper``.
    """

    BEFORE_BEAM_SEARCH_DECODER = 0
    IN_BEAM_SEARCH_DECODER = 1
    AFTER_BEAM_SEARCH_DECODER = 2

    def __init__(self, state_cell, max_len, name=None):
        self._helper = LayerHelper('beam_search_decoder', name=name)
        self._counter = layers.zeros(shape=[1], dtype='int64')
        self._counter.stop_gradient = True
        self._type = DecoderType.BEAM_SEARCH
        self._max_len = layers.fill_constant(
            shape=[1], dtype='int64', value=max_len)
        # NOTE(review): a second, separate max_len constant is built for the
        # initial condition, while block() compares against self._max_len.
        # Left unchanged so the generated ops stay the same.
        self._cond = layers.less_than(
            x=self._counter,
            y=layers.fill_constant(
                shape=[1], dtype='int64', value=max_len))
        self._while_op = layers.While(self._cond)
        self._state_cell = state_cell
        self._state_cell.enter_decoder(self)
        self._status = BeamSearchDecoder.BEFORE_BEAM_SEARCH_DECODER
        self._zero_idx = layers.fill_constant(
            shape=[1], value=0, dtype='int64', force_cpu=True)
        # name of a value returned by read_array -> its backing array
        self._array_dict = {}
        # (value, array) pairs to write at the end of every loop step
        self._array_link = []
        self._ids_array = None
        self._scores_array = None

    @contextlib.contextmanager
    def block(self):
        """Context manager wrapping the while-loop body.

        Raises:
            ValueError: If the block has already been entered once.
        """
        if self._status != BeamSearchDecoder.BEFORE_BEAM_SEARCH_DECODER:
            raise ValueError('block() can only be invoke once.')

        self._status = BeamSearchDecoder.IN_BEAM_SEARCH_DECODER

        with self._while_op.block():
            yield

            # Loop epilogue, appended after the user's body: advance the
            # counter, flush pending array writes at the new index, then
            # refresh the loop condition.
            layers.increment(x=self._counter, value=1.0, in_place=True)

            for value, array in self._array_link:
                layers.array_write(x=value, i=self._counter, array=array)

            layers.less_than(x=self._counter, y=self._max_len, cond=self._cond)

        self._status = BeamSearchDecoder.AFTER_BEAM_SEARCH_DECODER
        self._state_cell.leave_decoder(self)

    @property
    def type(self):
        """The decoder type tag (``DecoderType.BEAM_SEARCH``)."""
        return self._type

    # init must be provided
    def read_array(self, init, is_ids=False, is_scores=False):
        """Create an array seeded with ``init`` and read it at the counter.

        The array is created in the parent block and ``init`` is written at
        index zero there; the read happens at the current counter index.

        Args:
            init: Variable written into the new array at index zero.
            is_ids: Mark the array as the ids array used by ``__call__``.
            is_scores: Mark the array as the scores array used by
                ``__call__``.

        Returns:
            The value read from the array at the current counter.

        Raises:
            ValueError: If called outside the block, or if both ``is_ids``
                and ``is_scores`` are set.
            TypeError: If ``init`` is not a ``Variable``.
        """
        self._assert_in_decoder_block('read_array')

        if is_ids and is_scores:
            # Fix: the two literal fragments used to join as "andscores".
            raise ValueError('Shouldn\'t mark current array be ids array and '
                             'scores array at the same time.')

        if not isinstance(init, Variable):
            raise TypeError('The input argument `init` must be a Variable.')

        parent_block = self.parent_block()
        array = parent_block.create_var(
            name=unique_name('beam_search_decoder_array'),
            type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
            dtype=init.dtype)
        parent_block.append_op(
            type='write_to_array',
            inputs={'X': init,
                    'I': self._zero_idx},
            outputs={'Out': array})

        if is_ids:
            self._ids_array = array
        elif is_scores:
            self._scores_array = array

        read_value = layers.array_read(array=array, i=self._counter)
        # Remember which array produced this value so update_array can
        # resolve it by the value's name.
        self._array_dict[read_value.name] = array
        return read_value

    def update_array(self, array, value):
        """Schedule ``value`` to be written to the array behind ``array``.

        Args:
            array: A value previously returned by ``read_array``; it is
                looked up by its ``name`` to find the backing array.
            value: Variable to write at the end of each loop step.

        Raises:
            ValueError: If called outside the block, or if ``array`` did not
                come from ``read_array``.
            TypeError: If either argument is not a ``Variable``.
        """
        self._assert_in_decoder_block('update_array')

        if not isinstance(array, Variable):
            raise TypeError(
                'The input argument `array` of  must be a Variable.')
        if not isinstance(value, Variable):
            raise TypeError('The input argument `value` of must be a Variable.')

        array = self._array_dict.get(array.name, None)
        if array is None:
            raise ValueError('Please invoke read_array before update_array.')
        self._array_link.append((value, array))

    def __call__(self):
        """Return ``layers.beam_search_decode`` over the ids/scores arrays.

        Raises:
            ValueError: If called before the block has completed.
        """
        if self._status != BeamSearchDecoder.AFTER_BEAM_SEARCH_DECODER:
            raise ValueError('Output of BeamSearchDecoder object can '
                             'only be visited outside the block.')
        return layers.beam_search_decode(
            ids=self._ids_array, scores=self._scores_array)

    @property
    def state_cell(self):
        """The bound state cell; only accessible inside the block."""
        self._assert_in_decoder_block('state_cell')
        return self._state_cell

    def parent_block(self):
        """Return the parent of the main program's current block.

        Raises:
            ValueError: If the current block has a negative parent index.
        """
        program = self._helper.main_program
        parent_block_idx = program.current_block().parent_idx
        if parent_block_idx < 0:
            raise ValueError('Invlid block with index %d.' % parent_block_idx)
        return program.block(parent_block_idx)

    def _assert_in_decoder_block(self, method):
        """Raise ValueError unless the decoder is currently inside its block."""
        if self._status != BeamSearchDecoder.IN_BEAM_SEARCH_DECODER:
            raise ValueError('%s should be invoked inside block of '
                             'BeamSearchDecoder object.' % method)