beam_search_api.py 14.3 KB
Newer Older
1 2 3
import paddle.fluid as fluid
import paddle.fluid.layers as layers
from paddle.fluid.framework import Variable
Y
yangyaming 已提交
4
import contextlib
5 6
from paddle.fluid.layer_helper import LayerHelper, unique_name
import paddle.fluid.core as core
Y
yangyaming 已提交
7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30


class DecoderType:
    """Integer tags telling a StateCell which kind of decoder it serves."""

    # TRAINING decoders keep state in RNN memories, BEAM_SEARCH decoders
    # keep state in tensor arrays (see StateCell._switch_decoder).
    TRAINING, BEAM_SEARCH = 1, 2


class InitState(object):
    """Description of the initial value of one decoder state.

    All constructor arguments are stored unchanged on the instance. Only
    ``init`` and ``need_reorder`` are exposed through properties; ``shape``,
    ``value`` and ``dtype`` are kept but not read anywhere in this class.

    Args:
        init: The object used as the initial state (may be None).
        shape: Stored as ``_shape``.
        value: Stored as ``_value``.
        need_reorder: Flag forwarded to the RNN memory by MemoryState.
        dtype: Stored as ``_dtype``.
    """

    def __init__(self,
                 init=None,
                 shape=None,
                 value=0.0,
                 need_reorder=False,
                 dtype='float32'):
        self._init, self._shape, self._value = init, shape, value
        self._need_reorder, self._dtype = need_reorder, dtype

    @property
    def value(self):
        """Return the ``init`` object given at construction time."""
        # NOTE: despite the name, this is the `init` argument, not the
        # scalar `value` argument.
        return self._init

    @property
    def need_reorder(self):
        """Return the ``need_reorder`` flag given at construction time."""
        return self._need_reorder

Y
yangyaming 已提交
35 36 37 38 39

class MemoryState(object):
    """A named state stored as a memory of an RNN object.

    Args:
        state_name: Name of the state; kept for reference only.
        rnn_obj: Object providing ``memory(init=..., need_reorder=...)`` and
            ``update_memory(mem, new_value)``.
        init_state: Object exposing ``value`` and ``need_reorder``
            (an InitState in this file).
    """

    def __init__(self, state_name, rnn_obj, init_state):
        self._state_name = state_name
        self._rnn_obj = rnn_obj
        boot_value = init_state.value
        reorder = init_state.need_reorder
        # The memory handle is created once and reused by get/update.
        self._state_mem = rnn_obj.memory(init=boot_value, need_reorder=reorder)

    def get_state(self):
        """Return the memory handle created in the constructor."""
        return self._state_mem

    def update_state(self, state):
        """Ask the RNN object to replace the memory content with ``state``."""
        self._rnn_obj.update_memory(self._state_mem, state)


class ArrayState(object):
    """A named state stored in a tensor array addressed by a step counter.

    The constructor creates, in ``block``, a LOD_TENSOR_ARRAY variable and an
    int64 counter variable, fills the counter with 0 (forced onto CPU), and
    writes the initial state into the array at the counter's position.

    Args:
        state_name: Name of the state; kept for reference only.
        block: Program block in which the variables and ops are created.
        init_state: Object whose ``value`` is the initial state variable.
    """

    def __init__(self, state_name, block, init_state):
        self._state_name = state_name
        self._block = block

        initial_value = init_state.value
        var_type = core.VarDesc.VarType

        # Array that holds one state entry per step.
        self._state_array = block.create_var(
            name=unique_name('array_state_array'),
            type=var_type.LOD_TENSOR_ARRAY,
            dtype=initial_value.dtype)

        # Index of the most recently written entry.
        self._counter = block.create_var(
            name=unique_name('array_state_counter'),
            type=var_type.LOD_TENSOR,
            dtype='int64')

        # Start the counter at zero.
        counter_attrs = {
            'shape': [1],
            'dtype': self._counter.dtype,
            'value': 0.0,
            'force_cpu': True
        }
        block.append_op(
            type='fill_constant',
            inputs={},
            outputs={'Out': [self._counter]},
            attrs=counter_attrs)

        self._counter.stop_gradient = True

        # Store the initial state at the counter's starting position.
        block.append_op(
            type='write_to_array',
            inputs={'X': initial_value,
                    'I': self._counter},
            outputs={'Out': self._state_array})

    def get_state(self):
        """Read the array entry at the current counter position."""
        return layers.array_read(array=self._state_array, i=self._counter)

    def update_state(self, state):
        """Advance the counter by one, then write ``state`` at that index."""
        layers.increment(x=self._counter, value=1, in_place=True)
        layers.array_write(state, array=self._state_array, i=self._counter)


class StateCell(object):
    """Container for the named states and inputs used by a decoder.

    A StateCell starts with one InitState per state name. After it has been
    attached to a decoder with ``enter_decoder``, the first access through
    ``get_state``, ``compute_state`` or ``update_states`` lazily wraps every
    InitState in a decoder-specific holder (MemoryState for TRAINING decoders,
    ArrayState for BEAM_SEARCH decoders) and replaces the current state with
    whatever the holder's ``get_state()`` returns.

    Args:
        cell_size: Accepted for interface compatibility; not used here.
        inputs: Mapping from input name to value. Acts as a placeholder
            that ``compute_state`` fills in.
        states: Mapping from state name to InitState.
        name: Optional name forwarded to LayerHelper.

    Raises:
        ValueError: If any value in ``states`` is not an InitState.
    """

    def __init__(self, cell_size, inputs, states, name=None):
        self._helper = LayerHelper('state_cell', name=name)
        self._cur_states = {}
        self._state_names = []
        # state name -> {id(decoder): MemoryState or ArrayState}
        self._states_holder = {}
        for state_name, state in states.items():
            if not isinstance(state, InitState):
                raise ValueError('state must be an InitState object.')
            self._cur_states[state_name] = state
            self._state_names.append(state_name)
        self._inputs = inputs  # inputs is place holder here
        self._cur_decoder_obj = None
        self._in_decoder = False
        self._switched_decoder = False
        self._state_updater = None

    def enter_decoder(self, decoder_obj):
        """Attach this cell to ``decoder_obj``.

        Raises:
            ValueError: If the cell is already attached to a decoder.
        """
        if self._in_decoder or self._cur_decoder_obj is not None:
            raise ValueError('StateCell has already entered a decoder.')
        self._in_decoder = True
        self._cur_decoder_obj = decoder_obj
        self._switched_decoder = False

    def _switch_decoder(self):  # lazy switch
        """Create the decoder-specific state holders and read states back.

        Raises:
            ValueError: If the cell is not in a decoder, has already been
                switched, holds a non-InitState for a state without a holder,
                or the decoder's type is not a known DecoderType.
        """
        if not self._in_decoder:
            raise ValueError('StateCell must be enter a decoder.')

        if self._switched_decoder:
            raise ValueError('StateCell already done switching.')

        decoder = self._cur_decoder_obj
        decoder_key = id(decoder)

        for state_name in self._state_names:
            if state_name not in self._states_holder:
                state = self._cur_states[state_name]

                if not isinstance(state, InitState):
                    raise ValueError('Current type of state is %s, should be '
                                     'an InitState object.' % type(state))

                if decoder.type == DecoderType.TRAINING:
                    holder = MemoryState(state_name, decoder.dynamic_rnn,
                                         state)
                elif decoder.type == DecoderType.BEAM_SEARCH:
                    holder = ArrayState(state_name, decoder.parent_block(),
                                        state)
                else:
                    raise ValueError('Unknown decoder type, only support '
                                     '[TRAINING, BEAM_SEARCH]')

                # Register only after the holder was built, so a failure
                # above does not leave an empty entry behind.
                self._states_holder[state_name] = {decoder_key: holder}

            # Read back, since current state should be LoDTensor
            self._cur_states[state_name] = \
                self._states_holder[state_name][decoder_key].get_state()

        self._switched_decoder = True

    def get_state(self, state_name):
        """Return the current value of ``state_name``.

        Triggers the lazy decoder switch when inside a decoder.

        Raises:
            ValueError: If ``state_name`` is unknown.
        """
        if self._in_decoder and not self._switched_decoder:
            self._switch_decoder()

        if state_name not in self._cur_states:
            raise ValueError(
                'Unknown state %s. Please make sure _switch_decoder() '
                'invoked.' % state_name)

        return self._cur_states[state_name]

    def get_input(self, input_name):
        """Return the value bound to ``input_name``.

        Raises:
            ValueError: If the name is unknown or still bound to None.
        """
        if input_name not in self._inputs or self._inputs[input_name] is None:
            raise ValueError('Invalid input %s.' % input_name)
        return self._inputs[input_name]

    def set_state(self, state_name, state_value):
        """Set the current value of ``state_name`` to ``state_value``."""
        self._cur_states[state_name] = state_value

    def state_updater(self, updater):
        """Register ``updater`` as the state-update function (decorator).

        ``updater`` is stored and later invoked by ``compute_state`` with
        this cell as its only argument. The returned wrapper forwards to
        ``updater`` after checking the argument type.

        Raises (from the returned wrapper):
            TypeError: If the argument is not a StateCell.
        """
        self._state_updater = updater

        def _decorator(state_cell):
            # The guard used to be inverted: it rejected the owning cell and
            # accepted everything else, contradicting the message below.
            if not isinstance(state_cell, StateCell):
                raise TypeError('Updater should only accept a StateCell object '
                                'as argument.')
            updater(state_cell)

        return _decorator

    def compute_state(self, inputs):
        """Bind ``inputs`` to the placeholders and run the state updater.

        Args:
            inputs: Mapping from input name to value; every name must
                already exist in the placeholder mapping.

        Raises:
            ValueError: If an input name is not a known placeholder.
        """
        if self._in_decoder and not self._switched_decoder:
            self._switch_decoder()

        for input_name, input_value in inputs.items():
            if input_name not in self._inputs:
                raise ValueError('Unknown input %s. '
                                 'Please make sure %s in input '
                                 'place holder.' % (input_name, input_name))
            self._inputs[input_name] = input_value

        self._state_updater(self)

    def update_states(self):
        """Push every current state into the current decoder's holder.

        Raises:
            ValueError: If a state has no holder for the current decoder.
        """
        if self._in_decoder and not self._switched_decoder:
            # Previously this called the boolean flag `_switched_decoder`
            # instead of the method, which raised TypeError.
            self._switch_decoder()

        decoder_key = id(self._cur_decoder_obj)
        for state_name, decoder_state in self._states_holder.items():
            if decoder_key not in decoder_state:
                raise ValueError('Unknown decoder object, please make sure '
                                 'switch_decoder been invoked.')
            decoder_state[decoder_key].update_state(
                self._cur_states[state_name])

    def leave_decoder(self, decoder_obj):
        """Detach this cell from ``decoder_obj``.

        Raises:
            ValueError: If the cell is not in a decoder, or is attached to
                a different decoder object.
        """
        if not self._in_decoder:
            raise ValueError('StateCell not in decoder, '
                             'invlid leaving operation.')

        if self._cur_decoder_obj != decoder_obj:
            raise ValueError('Inconsist decoder object in StateCell.')

        self._in_decoder = False
        self._cur_decoder_obj = None
        self._switched_decoder = True
Y
yangyaming 已提交
222 223 224 225 226 227 228 229 230 231 232 233 234


class TrainingDecoder(object):
    """Training-time decoder that wraps a ``layers.DynamicRNN``.

    The decoder moves through three phases: before, inside and after its
    ``block()`` context. Step-level methods are only valid inside the block;
    calling the decoder for its output is only valid after the block.

    Args:
        state_cell: StateCell to attach to this decoder; its
            ``enter_decoder`` is called during construction.
        name: Optional name forwarded to LayerHelper.
    """

    # Lifecycle phases of the decoder.
    BEFORE_DECODER, IN_DECODER, AFTER_DECODER = 0, 1, 2

    def __init__(self, state_cell, name=None):
        self._helper = LayerHelper('training_decoder', name=name)
        self._status = TrainingDecoder.BEFORE_DECODER
        self._dynamic_rnn = layers.DynamicRNN()
        self._type = DecoderType.TRAINING
        self._state_cell = state_cell
        state_cell.enter_decoder(self)

    @contextlib.contextmanager
    def block(self):
        """Context manager wrapping the dynamic RNN's block.

        On normal exit the decoder is marked finished and the state cell is
        detached.

        Raises:
            ValueError: If the block has already been entered once.
        """
        already_used = self._status != TrainingDecoder.BEFORE_DECODER
        if already_used:
            raise ValueError('decoder.block() can only be invoked once')
        self._status = TrainingDecoder.IN_DECODER
        with self._dynamic_rnn.block():
            yield
        self._status = TrainingDecoder.AFTER_DECODER
        self._state_cell.leave_decoder(self)

    @property
    def state_cell(self):
        """The attached StateCell; only accessible inside the block."""
        self._assert_in_decoder_block('state_cell')
        return self._state_cell

    @property
    def dynamic_rnn(self):
        """The underlying dynamic RNN object."""
        return self._dynamic_rnn

    @property
    def type(self):
        """The decoder type tag set at construction."""
        return self._type

    def step_input(self, x):
        """Forward ``x`` to the dynamic RNN's ``step_input``."""
        self._assert_in_decoder_block('step_input')
        rnn = self._dynamic_rnn
        return rnn.step_input(x)

    def static_input(self, x):
        """Forward ``x`` to the dynamic RNN's ``static_input``."""
        self._assert_in_decoder_block('static_input')
        rnn = self._dynamic_rnn
        return rnn.static_input(x)

    def __call__(self, *args, **kwargs):
        """Call the dynamic RNN to fetch its output; valid after the block.

        Raises:
            ValueError: If invoked before the block has finished.
        """
        if self._status == TrainingDecoder.AFTER_DECODER:
            return self._dynamic_rnn(*args, **kwargs)
        raise ValueError('Output of training decoder can only be visited '
                         'outside the block.')

    def output(self, *outputs):
        """Forward ``outputs`` to the dynamic RNN's ``output``."""
        self._assert_in_decoder_block('output')
        self._dynamic_rnn.output(*outputs)

    def _assert_in_decoder_block(self, method):
        """Raise ValueError unless the decoder is currently inside its block."""
        if self._status == TrainingDecoder.IN_DECODER:
            return
        raise ValueError('%s should be invoked inside block of '
                         'TrainingDecoder object.' % method)
Y
yangyaming 已提交
282 283 284


class BeamSearchDecoder(object):
    """Beam-search decoder driven by a ``layers.While`` loop.

    The loop runs while an int64 step counter is less than ``max_len``.
    Arrays registered through ``read_array`` / ``update_array`` are written
    at the end of every iteration, and the arrays marked as ids and scores
    are passed to ``layers.beam_search_decode`` when the decoder is called
    after its block.

    Args:
        state_cell: StateCell to attach to this decoder; its
            ``enter_decoder`` is called during construction.
        max_len: Value used for the loop bound constant.
        name: Optional name forwarded to LayerHelper.
    """

    # Lifecycle phases of the decoder.
    BEFORE_BEAM_SEARCH_DECODER = 0
    IN_BEAM_SEARCH_DECODER = 1
    AFTER_BEAM_SEARCH_DECODER = 2

    def __init__(self, state_cell, max_len, name=None):
        self._helper = LayerHelper('beam_search_decoder', name=name)
        self._counter = layers.zeros(shape=[1], dtype='int64')
        self._counter.stop_gradient = True
        self._type = DecoderType.BEAM_SEARCH
        self._max_len = layers.fill_constant(
            shape=[1], dtype='int64', value=max_len)
        self._cond = layers.less_than(
            x=self._counter,
            y=layers.fill_constant(
                shape=[1], dtype='int64', value=max_len))
        self._while_op = layers.While(self._cond)
        self._state_cell = state_cell
        self._state_cell.enter_decoder(self)
        self._status = BeamSearchDecoder.BEFORE_BEAM_SEARCH_DECODER
        self._zero_idx = layers.fill_constant(
            shape=[1], value=0, dtype='int64', force_cpu=True)
        # Name of a value returned by read_array -> the array it came from.
        self._array_dict = {}
        # (value, array) pairs to write back at the end of each iteration.
        self._array_link = []
        self._ids_array = None
        self._scores_array = None

    @contextlib.contextmanager
    def block(self):
        """Context manager wrapping the body of the while loop.

        After the caller's body, the step counter is incremented, every
        (value, array) pair registered via ``update_array`` is written at the
        new counter position, and the loop condition is recomputed. On normal
        exit the decoder is marked finished and the state cell is detached.

        Raises:
            ValueError: If the block has already been entered once.
        """
        if self._status != BeamSearchDecoder.BEFORE_BEAM_SEARCH_DECODER:
            raise ValueError('block() can only be invoke once.')

        self._status = BeamSearchDecoder.IN_BEAM_SEARCH_DECODER

        with self._while_op.block():
            yield

            layers.increment(x=self._counter, value=1.0, in_place=True)

            for value, array in self._array_link:
                layers.array_write(x=value, i=self._counter, array=array)

            layers.less_than(x=self._counter, y=self._max_len, cond=self._cond)

        self._status = BeamSearchDecoder.AFTER_BEAM_SEARCH_DECODER
        self._state_cell.leave_decoder(self)

    @property
    def type(self):
        """The decoder type tag set at construction."""
        return self._type

    # init must be provided
    def read_array(self, init, is_ids=False, is_scores=False):
        """Create an array seeded with ``init`` and read its current entry.

        A new array variable is created in the parent block and ``init`` is
        written to it at index zero there. The entry at the current step
        counter is then read and returned.

        Args:
            init: Variable used as the first element of the array.
            is_ids: Mark the array as the ids array used for decoding.
            is_scores: Mark the array as the scores array used for decoding.

        Returns:
            The value read from the array at the current counter.

        Raises:
            ValueError: If called outside the block, or if both ``is_ids``
                and ``is_scores`` are set.
            TypeError: If ``init`` is not a Variable.
        """
        self._assert_in_decoder_block('read_array')

        if is_ids and is_scores:
            # A space was missing between the two literal fragments, which
            # produced "...ids array andscores array...".
            raise ValueError('Shouldn\'t mark current array be ids array and '
                             'scores array at the same time.')

        if not isinstance(init, Variable):
            raise TypeError('The input argument `init` must be a Variable.')

        parent_block = self.parent_block()
        array = parent_block.create_var(
            name=unique_name('beam_search_decoder_array'),
            type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
            dtype=init.dtype)
        parent_block.append_op(
            type='write_to_array',
            inputs={'X': init,
                    'I': self._zero_idx},
            outputs={'Out': array})

        if is_ids:
            self._ids_array = array
        elif is_scores:
            self._scores_array = array

        read_value = layers.array_read(array=array, i=self._counter)
        # Remember the source array so update_array can find it by name.
        self._array_dict[read_value.name] = array
        return read_value

    def update_array(self, array, value):
        """Schedule ``value`` to be written to the array behind ``array``.

        Args:
            array: A value previously returned by ``read_array``.
            value: Variable to write at the end of the current iteration.

        Raises:
            ValueError: If called outside the block, or if ``array`` was not
                obtained from ``read_array``.
            TypeError: If ``array`` or ``value`` is not a Variable.
        """
        self._assert_in_decoder_block('update_array')

        if not isinstance(array, Variable):
            raise TypeError(
                'The input argument `array` of  must be a Variable.')
        if not isinstance(value, Variable):
            raise TypeError('The input argument `value` of must be a Variable.')

        array = self._array_dict.get(array.name, None)
        if array is None:
            raise ValueError('Please invoke read_array before update_array.')
        self._array_link.append((value, array))

    def __call__(self):
        """Decode the ids and scores arrays; valid only after the block.

        Raises:
            ValueError: If invoked before the block has finished.
        """
        if self._status != BeamSearchDecoder.AFTER_BEAM_SEARCH_DECODER:
            raise ValueError('Output of BeamSearchDecoder object can '
                             'only be visited outside the block.')
        return layers.beam_search_decode(
            ids=self._ids_array, scores=self._scores_array)

    @property
    def state_cell(self):
        """The attached StateCell; only accessible inside the block."""
        self._assert_in_decoder_block('state_cell')
        return self._state_cell

    def parent_block(self):
        """Return the parent of the main program's current block.

        Raises:
            ValueError: If the current block has no parent (index < 0).
        """
        program = self._helper.main_program
        parent_block_idx = program.current_block().parent_idx
        if parent_block_idx < 0:
            raise ValueError('Invlid block with index %d.' % parent_block_idx)
        return program.block(parent_block_idx)

    def _assert_in_decoder_block(self, method):
        """Raise ValueError unless the decoder is currently inside its block."""
        if self._status != BeamSearchDecoder.IN_BEAM_SEARCH_DECODER:
            raise ValueError('%s should be invoked inside block of '
                             'BeamSearchDecoder object.' % method)