# beam_search_api.py
# (Extracted from a repository blame view; commits shown there are by yangyaming.)
import paddle.v2.fluid as fluid
import paddle.v2.fluid.layers as layers
3
from paddle.v2.fluid.framework import Variable
Y
yangyaming 已提交
4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
import contextlib
from paddle.v2.fluid.layer_helper import LayerHelper, unique_name
import paddle.v2.fluid.core as core


class DecoderType:
    """Integer tags identifying which kind of decoder is driving a StateCell.

    TrainingDecoder instances report ``TRAINING`` through their ``type``
    property and BeamSearchDecoder instances report ``BEAM_SEARCH``.
    """

    TRAINING, BEAM_SEARCH = 1, 2


class InitState(object):
    """Describe the initial value of a single decoder state.

    Args:
        init: Object used as the initial state; this is what ``value``
            returns. Defaults to ``None``.
        shape: Stored as ``_shape``; not read by this class.
        value: Stored as ``_value``; not read by this class.
        need_reorder: Stored as ``_need_reorder``; not read by this class.
        dtype: Stored as ``_dtype``; not read by this class.

    NOTE(review): ``shape``, ``value`` and ``dtype`` look intended for
    building an initial tensor when ``init`` is not supplied, but no such
    construction happens here -- confirm before relying on them.
    """

    def __init__(self,
                 init=None,
                 shape=None,
                 value=0.0,
                 need_reorder=False,
                 dtype='float32'):
        self._init, self._shape = init, shape
        self._value, self._dtype = value, dtype
        self._need_reorder = need_reorder

    @property
    def value(self):
        """Return the ``init`` object given at construction (may be None)."""
        # This is `_init`, not the scalar `_value` argument.
        initial = self._init
        return initial


class MemoryState(object):
    """State storage backed by a memory of an RNN object.

    Args:
        state_name: Name of the state; stored but not otherwise used here.
        rnn_obj: Object exposing ``memory(init=...)`` and
            ``update_memory(mem, new_value)``.
        init_state: Object whose ``value`` attribute seeds the memory.
    """

    def __init__(self, state_name, rnn_obj, init_state):
        self._state_name = state_name  # each is a rnn.memory
        self._rnn_obj = rnn_obj
        make_memory = self._rnn_obj.memory
        self._state_mem = make_memory(init=init_state.value)

    def get_state(self):
        """Return the memory created at construction time."""
        memory = self._state_mem
        return memory

    def update_state(self, state):
        """Ask the RNN object to update the memory with ``state``."""
        rnn = self._rnn_obj
        rnn.update_memory(self._state_mem, state)


class ArrayState(object):
    """State storage backed by a tensor array plus an integer counter.

    The constructor creates two variables in ``block`` (a LOD_TENSOR_ARRAY
    and an int64 LOD_TENSOR counter), appends a ``fill_constant`` op that
    sets the counter to 0, and appends a ``write_to_array`` op that stores
    the initial state at the counter position.

    Args:
        state_name: Name of the state; stored but not otherwise used here.
        block: Block object providing ``create_var`` and ``append_op``.
        init_state: Object whose ``value`` attribute is the initial state
            variable (its ``dtype`` is used for the array).
    """

    def __init__(self, state_name, block, init_state):
        self._state_name = state_name
        self._block = block

        var_types = core.VarDesc.VarType

        self._state_array = block.create_var(
            name=unique_name('array_state_array'),
            type=var_types.LOD_TENSOR_ARRAY,
            dtype=init_state.value.dtype)

        counter = block.create_var(
            name=unique_name('array_state_counter'),
            type=var_types.LOD_TENSOR,
            dtype='int64')
        self._counter = counter

        # Initialize the counter to zero.
        fill_attrs = {
            'shape': [1],
            'dtype': counter.dtype,
            'value': 0.0,
            'force_cpu': True,
        }
        block.append_op(
            type='fill_constant',
            inputs={},
            outputs={'Out': [counter]},
            attrs=fill_attrs)

        counter.stop_gradient = True

        # Write the initial state at the counter position.
        write_inputs = {'X': init_state.value, 'I': counter}
        block.append_op(
            type='write_to_array',
            inputs=write_inputs,
            outputs={'Out': self._state_array})

    def get_state(self):
        """Read the array entry at the current counter position."""
        return layers.array_read(array=self._state_array, i=self._counter)

    def update_state(self, state):
        """Advance the counter by one in place, then write ``state`` there."""
        counter = self._counter
        layers.increment(x=counter, value=1, in_place=True)
        layers.array_write(state, array=self._state_array, i=counter)


class StateCell(object):
    """Keep the named states and inputs shared between a decoder and its
    user-supplied state updater.

    States start out as ``InitState`` objects. The first time a state is
    accessed after ``enter_decoder`` (via ``get_state``, ``compute_state``
    or ``update_states``), each state is lazily bound to decoder-specific
    storage: ``MemoryState`` for ``DecoderType.TRAINING`` decoders and
    ``ArrayState`` for ``DecoderType.BEAM_SEARCH`` decoders. The current
    value of each state is then read back from that storage.

    Args:
        cell_size: Accepted for interface compatibility; not used here.
        inputs: Dict of input name -> value (``None`` acts as a
            placeholder until ``compute_state`` fills it in).
        states: Dict of state name -> ``InitState``.
        name: Optional name forwarded to ``LayerHelper``.

    Raises:
        ValueError: If any entry of ``states`` is not an ``InitState``.
    """

    def __init__(self, cell_size, inputs, states, name=None):
        self._helper = LayerHelper('state_cell', name=name)
        self._cur_states = {}
        self._state_names = []
        for state_name, state in states.items():
            if not isinstance(state, InitState):
                raise ValueError('state must be an InitState object.')
            self._cur_states[state_name] = state
            self._state_names.append(state_name)
        self._inputs = inputs  # inputs is place holder here
        self._cur_decoder_obj = None
        self._in_decoder = False
        # state name -> {id(decoder): MemoryState or ArrayState}
        self._states_holder = {}
        self._switched_decoder = False

    def enter_decoder(self, decoder_obj):
        """Attach this cell to ``decoder_obj``.

        Raises:
            ValueError: If the cell is already attached to a decoder.
        """
        if self._in_decoder or self._cur_decoder_obj is not None:
            raise ValueError('StateCell has already entered a decoder.')
        self._in_decoder = True
        self._cur_decoder_obj = decoder_obj
        self._switched_decoder = False

    def _switch_decoder(self):  # lazy switch
        """Bind every state to storage owned by the current decoder and
        load the current value of each state from that storage.

        Raises:
            ValueError: If the cell is not inside a decoder, if switching
                was already done, if a state that has no storage yet is not
                an ``InitState``, or if the decoder type is unknown.
        """
        if not self._in_decoder:
            raise ValueError('StateCell must be enter a decoder.')

        if self._switched_decoder:
            raise ValueError('StateCell already done switching.')

        decoder = self._cur_decoder_obj
        decoder_key = id(decoder)

        for state_name in self._state_names:
            if state_name not in self._states_holder:
                state = self._cur_states[state_name]

                if not isinstance(state, InitState):
                    raise ValueError('Current type of state is %s, should be '
                                     'an InitState object.' % type(state))

                self._states_holder[state_name] = {}

                if decoder.type == DecoderType.TRAINING:
                    holder = MemoryState(state_name, decoder.dynamic_rnn,
                                         state)
                elif decoder.type == DecoderType.BEAM_SEARCH:
                    holder = ArrayState(state_name, decoder.parent_block(),
                                        state)
                else:
                    raise ValueError('Unknown decoder type, only support '
                                     '[TRAINING, BEAM_SEARCH]')
                self._states_holder[state_name][decoder_key] = holder

            # Read back, since current state should be LoDTensor
            self._cur_states[state_name] = \
                self._states_holder[state_name][decoder_key].get_state()

        self._switched_decoder = True

    def get_state(self, state_name):
        """Return the current value of ``state_name``.

        Triggers the lazy decoder switch when needed.

        Raises:
            ValueError: If ``state_name`` is unknown.
        """
        if self._in_decoder and not self._switched_decoder:
            self._switch_decoder()

        if state_name not in self._cur_states:
            raise ValueError(
                'Unknown state %s. Please make sure _switch_decoder() '
                'invoked.' % state_name)

        return self._cur_states[state_name]

    def get_input(self, input_name):
        """Return the value registered for ``input_name``.

        Raises:
            ValueError: If the input is unknown or still ``None``.
        """
        if input_name not in self._inputs or self._inputs[input_name] is None:
            raise ValueError('Invalid input %s.' % input_name)
        return self._inputs[input_name]

    def set_state(self, state_name, state_value):
        """Overwrite the current value of ``state_name``."""
        self._cur_states[state_name] = state_value

    def register_updater(self, state_updater):
        """Register the callable that ``compute_state`` invokes with this
        cell as its only argument."""
        self._state_updater = state_updater

    def compute_state(self, inputs):
        """Store ``inputs`` and run the registered updater.

        Args:
            inputs: Dict of input name -> value; every name must already be
                a key of the ``inputs`` dict given at construction.

        Raises:
            ValueError: If an input name is unknown.
        """
        if self._in_decoder and not self._switched_decoder:
            self._switch_decoder()

        for input_name, input_value in inputs.items():
            if input_name not in self._inputs:
                raise ValueError('Unknown input %s. '
                                 'Please make sure %s in input '
                                 'place holder.' % (input_name, input_name))
            self._inputs[input_name] = input_value

        self._state_updater(self)

    def update_states(self):
        """Push the current value of every state into the storage bound to
        the current decoder.

        Raises:
            ValueError: If a state has no storage for the current decoder.
        """
        if self._in_decoder and not self._switched_decoder:
            # The original code called the boolean flag
            # `self._switched_decoder()` here, which raised TypeError.
            self._switch_decoder()

        decoder_key = id(self._cur_decoder_obj)
        for state_name, decoder_state in self._states_holder.items():
            if decoder_key not in decoder_state:
                raise ValueError('Unknown decoder object, please make sure '
                                 'switch_decoder been invoked.')
            decoder_state[decoder_key].update_state(
                self._cur_states[state_name])

    def leave_decoder(self, decoder_obj):
        """Detach this cell from ``decoder_obj``.

        Raises:
            ValueError: If the cell is not inside a decoder or is attached
                to a different decoder object.
        """
        if not self._in_decoder:
            raise ValueError('StateCell not in decoder, '
                             'invlid leaving operation.')

        if self._cur_decoder_obj != decoder_obj:
            raise ValueError('Inconsist decoder object in StateCell.')

        self._in_decoder = False
        self._cur_decoder_obj = None
        # The flag is only consulted while `_in_decoder` is True, and
        # `enter_decoder` resets it to False, so this value is inert.
        self._switched_decoder = True
Y
yangyaming 已提交
208 209 210 211 212 213 214 215 216 217 218 219 220


class TrainingDecoder(object):
    """Decoder used for training, built on top of ``layers.DynamicRNN``.

    The decoder attaches itself to ``state_cell`` on construction and
    detaches when its ``block()`` context finishes. ``block()`` may be
    entered only once; step-level methods are only valid inside it, and the
    decoder output (``__call__``) is only available after it.

    Args:
        state_cell: Object providing ``enter_decoder`` / ``leave_decoder``.
        name: Optional name forwarded to ``LayerHelper``.
    """

    # Lifecycle markers stored in `_status`.
    BEFORE_DECODER, IN_DECODER, AFTER_DECODER = 0, 1, 2

    def __init__(self, state_cell, name=None):
        self._helper = LayerHelper('training_decoder', name=name)
        self._status = TrainingDecoder.BEFORE_DECODER
        self._dynamic_rnn = layers.DynamicRNN()
        self._type = DecoderType.TRAINING
        self._state_cell = state_cell
        state_cell.enter_decoder(self)

    @contextlib.contextmanager
    def block(self):
        """Context manager wrapping the dynamic RNN block.

        Raises:
            ValueError: If the block has already been entered once.
        """
        not_started = self._status == TrainingDecoder.BEFORE_DECODER
        if not not_started:
            raise ValueError('decoder.block() can only be invoked once')
        self._status = TrainingDecoder.IN_DECODER
        with self._dynamic_rnn.block():
            yield
        # Only reached when the wrapped body finished without raising.
        self._status = TrainingDecoder.AFTER_DECODER
        self._state_cell.leave_decoder(self)

    @property
    def state_cell(self):
        """The attached state cell; only accessible inside ``block()``."""
        self._assert_in_decoder_block('state_cell')
        return self._state_cell

    @property
    def dynamic_rnn(self):
        """The underlying dynamic RNN object."""
        return self._dynamic_rnn

    @property
    def type(self):
        """The decoder type tag (``DecoderType.TRAINING``)."""
        return self._type

    def step_input(self, x):
        """Forward ``x`` to the dynamic RNN's ``step_input``."""
        self._assert_in_decoder_block('step_input')
        rnn = self._dynamic_rnn
        return rnn.step_input(x)

    def static_input(self, x):
        """Forward ``x`` to the dynamic RNN's ``static_input``."""
        self._assert_in_decoder_block('static_input')
        rnn = self._dynamic_rnn
        return rnn.static_input(x)

    def __call__(self, *args, **kwargs):
        """Return the dynamic RNN's result; valid only after ``block()``.

        Raises:
            ValueError: If called before the block has completed.
        """
        if self._status == TrainingDecoder.AFTER_DECODER:
            return self._dynamic_rnn(*args, **kwargs)
        raise ValueError('Output of training decoder can only be visited '
                         'outside the block.')

    def output(self, *outputs):
        """Register ``outputs`` with the dynamic RNN."""
        self._assert_in_decoder_block('output')
        self._dynamic_rnn.output(*outputs)

    def _assert_in_decoder_block(self, method):
        """Raise ValueError unless the decoder is currently inside block()."""
        if self._status == TrainingDecoder.IN_DECODER:
            return
        raise ValueError('%s should be invoked inside block of '
                         'TrainingDecoder object.' % method)
Y
yangyaming 已提交
268 269 270


class BeamSearchDecoder(object):
    """Decoder that runs a ``layers.While`` loop for beam search.

    The loop runs while an int64 counter is less than ``max_len``. Arrays
    registered through ``read_array`` live in the parent block; values
    linked through ``update_array`` are written back at the end of every
    loop step. After the block, calling the decoder returns
    ``layers.beam_search_decode`` applied to the arrays that were marked as
    ids and scores.

    Args:
        state_cell: Object providing ``enter_decoder`` / ``leave_decoder``.
        max_len: Upper bound for the loop counter.
        name: Optional name forwarded to ``LayerHelper``.
    """

    # Lifecycle markers stored in `_status`.
    BEFORE_BEAM_SEARCH_DECODER = 0
    IN_BEAM_SEARCH_DECODER = 1
    AFTER_BEAM_SEARCH_DECODER = 2

    def __init__(self, state_cell, max_len, name=None):
        self._helper = LayerHelper('beam_search_decoder', name=name)
        self._counter = layers.zeros(shape=[1], dtype='int64')
        self._counter.stop_gradient = True
        self._type = DecoderType.BEAM_SEARCH
        self._max_len = layers.fill_constant(
            shape=[1], dtype='int64', value=max_len)
        self._cond = layers.less_than(
            x=self._counter,
            y=layers.fill_constant(
                shape=[1], dtype='int64', value=max_len))
        self._while_op = layers.While(self._cond)
        self._state_cell = state_cell
        self._state_cell.enter_decoder(self)
        self._status = BeamSearchDecoder.BEFORE_BEAM_SEARCH_DECODER
        self._zero_idx = layers.fill_constant(
            shape=[1], value=0, dtype='int64', force_cpu=True)
        # name of the variable returned by read_array -> backing array
        self._array_dict = {}
        # (value, array) pairs written back at the end of every step
        self._array_link = []
        self._ids_array = None
        self._scores_array = None

    @contextlib.contextmanager
    def block(self):
        """Context manager wrapping the body of the while loop.

        After the user body, the counter is incremented, every linked
        value is written to its array at the new counter position, and the
        loop condition is recomputed.

        Raises:
            ValueError: If the block has already been entered once.
        """
        if self._status != BeamSearchDecoder.BEFORE_BEAM_SEARCH_DECODER:
            raise ValueError('block() can only be invoke once.')

        self._status = BeamSearchDecoder.IN_BEAM_SEARCH_DECODER

        with self._while_op.block():
            yield

            layers.increment(x=self._counter, value=1.0, in_place=True)

            for value, array in self._array_link:
                layers.array_write(x=value, i=self._counter, array=array)

            layers.less_than(x=self._counter, y=self._max_len, cond=self._cond)

        self._status = BeamSearchDecoder.AFTER_BEAM_SEARCH_DECODER
        self._state_cell.leave_decoder(self)

    @property
    def type(self):
        """The decoder type tag (``DecoderType.BEAM_SEARCH``)."""
        return self._type

    # init must be provided
    def read_array(self, init, is_ids=False, is_scores=False):
        """Create an array seeded with ``init`` and read its current entry.

        The array is created in the parent block with ``init`` written at
        index 0; the returned variable is the entry at the loop counter.

        Args:
            init: Variable used as the first array element.
            is_ids: Mark this array as the ids array used by ``__call__``.
            is_scores: Mark this array as the scores array used by
                ``__call__``.

        Returns:
            The variable read from the array at the current counter.

        Raises:
            ValueError: If called outside ``block()`` or if both
                ``is_ids`` and ``is_scores`` are set.
            TypeError: If ``init`` is not a ``Variable``.
        """
        self._assert_in_decoder_block('read_array')

        if is_ids and is_scores:
            # The original literals concatenated to "...andscores...".
            raise ValueError('Shouldn\'t mark current array be ids array and '
                             'scores array at the same time.')

        if not isinstance(init, Variable):
            raise TypeError('The input argument `init` must be a Variable.')

        parent_block = self.parent_block()
        array = parent_block.create_var(
            name=unique_name('beam_search_decoder_array'),
            type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
            dtype=init.dtype)
        parent_block.append_op(
            type='write_to_array',
            inputs={'X': init,
                    'I': self._zero_idx},
            outputs={'Out': array})

        if is_ids:
            self._ids_array = array
        elif is_scores:
            self._scores_array = array

        read_value = layers.array_read(array=array, i=self._counter)
        self._array_dict[read_value.name] = array
        return read_value

    def update_array(self, array, value):
        """Schedule ``value`` to be written at the end of the current step.

        Args:
            array: The variable previously returned by ``read_array``; its
                name is used to look up the backing array.
            value: Variable to write.

        Raises:
            ValueError: If called outside ``block()`` or if ``array`` did
                not come from ``read_array``.
            TypeError: If ``array`` or ``value`` is not a ``Variable``.
        """
        self._assert_in_decoder_block('update_array')

        if not isinstance(array, Variable):
            raise TypeError('The input argument `array` of update_array '
                            'must be a Variable.')
        if not isinstance(value, Variable):
            raise TypeError('The input argument `value` of update_array '
                            'must be a Variable.')

        array = self._array_dict.get(array.name, None)
        if array is None:
            raise ValueError('Please invoke read_array before update_array.')
        self._array_link.append((value, array))

    def __call__(self):
        """Return ``layers.beam_search_decode`` over the ids/scores arrays.

        Raises:
            ValueError: If called before the block has completed, or if no
                array was marked with ``is_ids`` / ``is_scores``.
        """
        if self._status != BeamSearchDecoder.AFTER_BEAM_SEARCH_DECODER:
            raise ValueError('Output of BeamSearchDecoder object can '
                             'only be visited outside the block.')
        # Fail early with a clear message instead of passing None on.
        if self._ids_array is None or self._scores_array is None:
            raise ValueError('Both an ids array and a scores array must be '
                             'registered through read_array (is_ids=True / '
                             'is_scores=True) before decoding.')
        return layers.beam_search_decode(
            ids=self._ids_array, scores=self._scores_array)

    @property
    def state_cell(self):
        """The attached state cell; only accessible inside ``block()``."""
        self._assert_in_decoder_block('state_cell')
        return self._state_cell

    def parent_block(self):
        """Return the parent of the program's current block.

        Raises:
            ValueError: If the current block has a negative parent index.
        """
        program = self._helper.main_program
        parent_block_idx = program.current_block().parent_idx
        if parent_block_idx < 0:
            raise ValueError('Invlid block with index %d.' % parent_block_idx)
        return program.block(parent_block_idx)

    def _assert_in_decoder_block(self, method):
        """Raise ValueError unless the decoder is currently inside block()."""
        if self._status != BeamSearchDecoder.IN_BEAM_SEARCH_DECODER:
            raise ValueError('%s should be invoked inside block of '
                             'BeamSearchDecoder object.' % method)