Commit 51b6ecfc authored by yangyaming

Replace callback with decorator.

Parent 2b022f0b
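
Note on the change: the updater callback for StateCell is now attached with a decorator at the point of definition, instead of being passed to a separate register_updater(callback) call. Below is a minimal, standalone sketch of the registration pattern this commit adopts; MiniStateCell is a hypothetical stand-in for illustration only, not the PaddlePaddle StateCell API:

# Hypothetical stand-in for StateCell; only the registration mechanics match.
class MiniStateCell(object):
    def __init__(self):
        self.state = 0.0
        self.input = None
        self._state_updater = None

    def state_updater(self, updater):
        # Decorator form of registration: store the callback and return it,
        # so `@cell.state_updater` replaces `cell.register_updater(updater)`.
        self._state_updater = updater
        return updater

    def compute_state(self, x):
        self.input = x
        self._state_updater(self)  # the cell invokes the stored callback
        return self.state

cell = MiniStateCell()

@cell.state_updater
def update(cell):
    cell.state = 0.5 * cell.state + cell.input

print(cell.compute_state(1.0))  # 1.0
print(cell.compute_state(1.0))  # 1.5

The registration happens as a side effect of decoration, so call sites no longer need to mention the updater twice.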
@@ -181,7 +181,8 @@ def seq_to_seq_net(embedding_dim, encoder_size, decoder_size, source_dict_dim,
         context = fluid.layers.sequence_pool(input=scaled, pool_type='sum')
         return context

-    def updater(state_cell):
+    @state_cell.state_updater
+    def state_updater(state_cell):
         current_word = state_cell.get_input('x')
         encoder_vec = state_cell.get_input('encoder_vec')
         encoder_proj = state_cell.get_input('encoder_proj')
@@ -194,8 +195,6 @@ def seq_to_seq_net(embedding_dim, encoder_size, decoder_size, source_dict_dim,
         state_cell.set_state('h', h)
         state_cell.set_state('c', c)

-    state_cell.register_updater(updater)
-
     if not is_generating:
         trg_word_idx = fluid.layers.data(
             name='target_sequence', shape=[1], dtype='int64', lod_level=1)
@@ -233,7 +232,68 @@ def seq_to_seq_net(embedding_dim, encoder_size, decoder_size, source_dict_dim,
         return avg_cost, feeding_list
     else:
-        pass
+        init_ids = fluid.layers.data(
+            name="init_ids", shape=[1], dtype="int64", lod_level=2)
+        init_scores = fluid.layers.data(
+            name="init_scores", shape=[1], dtype="float32", lod_level=2)
+
+        '''
+        src_embedding = fluid.layers.embedding(
+            input=src_word_idx,
+            size=[source_dict_dim, embedding_dim],
+            dtype='float32')
+        '''
+        src_embedding = fluid.layers.embedding(
+            input=src_word_idx,
+            size=[source_dict_dim, embedding_dim],
+            dtype='float32',
+            param_attr=fluid.ParamAttr())
+
+        decoder = BeamSearchDecoder(state_cell, max_len=max_length)
+
+        with decoder.block():
+            # encoder_vec = prev_scores
+            # encoder_proj = prev_scores
+            prev_ids = decoder.read_array(init=init_ids, is_ids=True)
+            prev_scores = decoder.read_array(init=init_scores, is_scores=True)
+            # need to make sure this embedding weight is shared with the
+            # trained target embedding
+            prev_ids_embedding = fluid.layers.embedding(
+                input=prev_ids,
+                size=[target_dict_dim, embedding_dim],
+                dtype='float32')
+            prev_h = decoder.state_cell.get_state('h')
+            prev_c = decoder.state_cell.get_state('c')
+            prev_h_expanded = fluid.layers.sequence_expand(prev_h, prev_scores)
+            prev_c_expanded = fluid.layers.sequence_expand(prev_c, prev_scores)
+            decoder.state_cell.set_state('h', prev_h_expanded)
+            decoder.state_cell.set_state('c', prev_c_expanded)
+            decoder.state_cell.compute_state(inputs={
+                'x': prev_ids_embedding,
+                'encoder_vec': None,
+                'encoder_proj': None
+            })
+            current_state = decoder.state_cell.get_state('h')
+            scores = fluid.layers.fc(input=current_state,
+                                     size=target_dict_dim,
+                                     act='softmax')
+            topk_scores, topk_indices = fluid.layers.topk(scores, k=beam_size)
+            selected_ids, selected_scores = fluid.layers.beam_search(
+                prev_ids,
+                topk_indices,
+                topk_scores,
+                beam_size,
+                end_id=10,
+                level=0)
+            decoder.state_cell.update_states()
+            decoder.update_array(prev_ids, selected_ids)
+            decoder.update_array(prev_scores, selected_scores)
+
+        translation_ids, translation_scores = decoder()
+
+        feeding_list = [
+            "source_sequence", "target_sequence", "init_ids", "init_scores"
+        ]
+
+        return translation_ids, translation_scores, feeding_list


 def to_lodtensor(data, place):
@@ -345,7 +405,43 @@ def train():


 def infer():
-    pass
+    translation_ids, translation_scores, feeding_list = seq_to_seq_net(
+        args.embedding_dim,
+        args.encoder_size,
+        args.decoder_size,
+        args.dict_size,
+        args.dict_size,
+        True,
+        beam_size=args.beam_size,
+        max_length=args.max_length)
+
+    fluid.memory_optimize(fluid.default_main_program(), print_log=False)
+
+    test_batch_generator = paddle.v2.batch(
+        paddle.v2.reader.shuffle(
+            paddle.v2.dataset.wmt14.test(args.dict_size), buf_size=1000),
+        batch_size=args.batch_size)
+
+    place = core.CUDAPlace(0) if args.use_gpu else core.CPUPlace()
+    exe = Executor(place)
+    exe.run(framework.default_startup_program())
+
+    for batch_id, data in enumerate(test_batch_generator()):
+        src_seq, word_num = to_lodtensor(map(lambda x: x[0], data), place)
+        trg_seq, word_num = to_lodtensor(map(lambda x: x[1], data), place)
+        lbl_seq, _ = to_lodtensor(map(lambda x: x[2], data), place)
+
+        fetch_outs = exe.run(framework.default_main_program(),
+                             feed={
+                                 feeding_list[0]: src_seq,
+                                 feeding_list[1]: trg_seq,
+                                 feeding_list[2]: lbl_seq
+                             },
+                             fetch_list=[translation_ids, translation_scores])
+
+        translation_ids_val = np.array(fetch_outs[0])
+        print('batch_id=%d, translation_ids: %s' % (batch_id,
+                                                    translation_ids_val))


 if __name__ == '__main__':
...
@@ -108,6 +108,7 @@ class StateCell(object):
         self._in_decoder = False
         self._states_holder = {}
         self._switched_decoder = False
+        self._state_updater = None

     def enter_decoder(self, decoder_obj):
         if self._in_decoder == True or self._cur_decoder_obj is not None:
@@ -172,8 +173,16 @@ class StateCell(object):
     def set_state(self, state_name, state_value):
         self._cur_states[state_name] = state_value

-    def register_updater(self, state_updater):
-        self._state_updater = state_updater
+    def state_updater(self, updater):
+        self._state_updater = updater
+
+        def _decorator(state_cell):
+            if not isinstance(state_cell, StateCell):
+                raise TypeError('Updater should only accept a StateCell '
+                                'object as argument.')
+            updater(state_cell)
+
+        return _decorator

     def compute_state(self, inputs):
         if self._in_decoder and not self._switched_decoder:
...
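
One subtlety of the new state_updater method, worth noting for reviewers: it stores the original updater but returns the _decorator wrapper, so after decoration the name at the call site is rebound to the wrapper, while the cell keeps calling the stored original (presumably as self._state_updater(self) inside compute_state, whose body is elided above). A stripped-down sketch of that rebinding in plain Python, omitting the type check:

class Cell(object):
    def __init__(self):
        self._state_updater = None

    def state_updater(self, updater):
        # Store the original function, return a wrapper.
        self._state_updater = updater

        def _decorator(state_cell):
            updater(state_cell)

        return _decorator

cell = Cell()

@cell.state_updater
def my_updater(state_cell):
    print('updating %r' % state_cell)

print(my_updater is cell._state_updater)  # False: the name is the wrapper now
cell._state_updater(cell)  # the cell still calls the original function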
@@ -55,18 +55,17 @@ def encoder():
     return encoder_out


-def updater(state_cell):
-    current_word = state_cell.get_input('x')
-    prev_h = state_cell.get_state('h')
-    h = pd.fc(input=[current_word, prev_h], size=decoder_size, act='tanh')
-    state_cell.set_state('h', h)
-
-
 def decoder_train(context):
     h = InitState(init=context)
     state_cell = StateCell(
         cell_size=decoder_size, inputs={'x': None}, states={'h': h})
-    state_cell.register_updater(updater)
+
+    @state_cell.state_updater
+    def updater(state_cell):
+        current_word = state_cell.get_input('x')
+        prev_h = state_cell.get_state('h')
+        h = pd.fc(input=[current_word, prev_h], size=decoder_size, act='tanh')
+        state_cell.set_state('h', h)

     # decoder
     trg_language_word = pd.data(
...