Commit 07a8f0ef authored by qiaolongfei

refine code, remove beam_search.py

Parent bf6fd470
 import sys
 import paddle.v2 as paddle
-import paddle.v2.layers.beam_search as beam_search
 
 
 def seqToseq_net(source_dict_dim, target_dict_dim, is_generating):
@@ -106,13 +106,13 @@ def seqToseq_net(source_dict_dim, target_dict_dim, is_generating):
         # GeneratedInputs, which is initialized by a start mark, such as <s>,
         # and must be included in generation.
 
-        trg_embedding = beam_search.GeneratedInputV2(
+        trg_embedding = paddle.layer.GeneratedInputV2(
             size=target_dict_dim,
             embedding_name='_target_language_embedding',
             embedding_size=word_vector_dim)
         group_inputs.append(trg_embedding)
 
-        beam_gen = beam_search.beam_search(
+        beam_gen = paddle.layer.beam_search(
             name=decoder_group_name,
             step=gru_decoder_with_attention,
             input=group_inputs,
......
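For orientation, here is a minimal sketch of how the generation branch of the demo calls the relocated API after this commit. It restates what the hunk above already shows; the surrounding names (`target_dict_dim`, `word_vector_dim`, `group_inputs`, `decoder_group_name`, `gru_decoder_with_attention`) come from the demo, and the beam-search parameter values are illustrative assumptions, not values taken from this diff.

    import paddle.v2 as paddle

    # Sketch of the post-commit generation wiring: GeneratedInputV2 and
    # beam_search now live in paddle.layer instead of the removed
    # paddle.v2.layers.beam_search module.
    trg_embedding = paddle.layer.GeneratedInputV2(
        size=target_dict_dim,
        embedding_name='_target_language_embedding',
        embedding_size=word_vector_dim)
    group_inputs.append(trg_embedding)

    beam_gen = paddle.layer.beam_search(
        name=decoder_group_name,
        step=gru_decoder_with_attention,
        input=group_inputs,
        bos_id=0,        # assumed start-mark id
        eos_id=1,        # assumed end-mark id
        beam_size=3,     # illustrative beam width
        max_length=250)  # illustrative length cap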
@@ -87,19 +87,19 @@ class Layer(object):
         """
         self.__context__ = context
 
-        # 1. short cut if this layer is parsed before.
+        # STEP: short cut if this layer is parsed before.
         if self.context_name() in context:
             if self.use_context_name():
                 return context[self.context_name()]
             else:
                 return context[self.name]
 
-        # 2. parse extra_parent that is not used by this layer but must
+        # STEP: parse extra_parent that is not used by this layer but must
         # be parsed before this layer.
         for p in self.__extra_parent__:
             p.to_proto(context=context)
 
-        # 3. parse parent that is used by this layer, get the result and
+        # STEP: parse parent that is used by this layer, get the result and
         # insert into kwargs of the next layer's to_proto_impl method.
         kwargs = dict()
         for layer_name in self.__parent_layers__:
@@ -112,13 +112,13 @@ class Layer(object):
                 self.__parent_layers__[layer_name])
             kwargs[layer_name] = v1_layer
 
-        # 4. parse myself and add myself into context.
-        ret_val = self.to_proto_impl(context=context, **kwargs)
-        if self.context_name() is not None and self.context_name(
-        ) not in context:
+        # STEP: parse myself and add myself into context.
+        ret_val = self.to_proto_impl(**kwargs)
+        if self.context_name() is not None \
+                and self.context_name() not in context:
             context[self.context_name()] = ret_val
 
-        # 5. parse children that should be pased after this layer.
+        # STEP: parse children that should be pased after this layer.
         for layer, pnames in self.__children_layers__:
             drop = False
@@ -131,7 +131,7 @@ class Layer(object):
                 continue
             layer.to_proto(context=context)
 
-        # 6. return v1 layer result.g
+        # STEP: return v1 layer result
         if self.context_name() is None:
             return ret_val
         elif self.use_context_name():
@@ -139,7 +139,7 @@ class Layer(object):
         else:
             return context[self.name]
 
-    def to_proto_impl(self, context=None, **kwargs):
+    def to_proto_impl(self, **kwargs):
         raise NotImplementedError()
 
     def context_name(self):
@@ -203,7 +203,7 @@ def __convert_to_v2__(method_name,
         if wrapper is not None:
             __init__ = wrapper(__init__)
 
-        def to_proto_impl(self, context=None, **kwargs):
+        def to_proto_impl(self, **kwargs):
             args = dict()
             for each in kwargs:
                 args[each] = kwargs[each]
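The change above removes the `context` keyword from every `to_proto_impl`; a layer that still needs the parse context now reads it from `self.__context__`, which `Layer.to_proto` assigns before dispatching. Below is a hedged sketch of a custom v2 layer written against the new contract; the class `IdentityLayerV2` and its body are hypothetical and only illustrate where the parsed parents and the context now come from.

    from paddle.v2.config_base import Layer


    class IdentityLayerV2(Layer):
        """Hypothetical layer: forwards its single parent unchanged."""

        def __init__(self, name, input):
            # Parent layers are declared exactly as before; Layer.to_proto()
            # parses them and hands the v1 results in through **kwargs.
            super(IdentityLayerV2, self).__init__(
                name=name, parent_layers={'input': input})

        def to_proto_impl(self, **kwargs):
            # 'context' is no longer an argument.  If the already-parsed
            # layers are needed, use self.__context__, which to_proto()
            # filled in before calling this method.
            already_parsed = self.__context__
            v1_input = kwargs['input']   # the parsed parent layer
            return v1_input              # a real layer would build on it here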
@@ -33,22 +33,25 @@ The primary usage shows below.
 
 import collections
 import inspect
-from config_base import Layer, __convert_to_v2__
+import re
 import paddle.trainer_config_helpers as conf_helps
-from paddle.trainer.config_parser import \
-    RecurrentLayerGroupWithoutOutLinksBegin, RecurrentLayerGroupSetOutLink, \
-    RecurrentLayerGroupEnd, model_type
 from paddle.trainer_config_helpers.config_parser_utils import \
     parse_network_config as __parse__
 from paddle.trainer_config_helpers.default_decorators import wrap_act_default
 from paddle.trainer_config_helpers.default_decorators import \
     wrap_bias_attr_default
 from paddle.trainer_config_helpers.default_decorators import wrap_name_default
+from paddle.trainer_config_helpers.layers import RecurrentLayerGroupSetGenerator, Generator
 from paddle.trainer_config_helpers.layers import layer_support
+from paddle.trainer.config_parser import \
+    RecurrentLayerGroupWithoutOutLinksBegin, RecurrentLayerGroupSetOutLink, \
+    RecurrentLayerGroupEnd, model_type
 
 import activation
-import re
+import attr
 import data_type
+from config_base import Layer, __convert_to_v2__
 
 __all__ = ['parse_network', 'data']
@@ -111,7 +114,7 @@ class DataLayerV2(Layer):
 
         super(DataLayerV2, self).__init__(name=name, parent_layers=dict())
 
-    def to_proto_impl(self, context=None, **kwargs):
+    def to_proto_impl(self, **kwargs):
         args = dict()
         args['size'] = self.type.dim
         for each in kwargs:
@@ -134,6 +137,16 @@ class DataLayerV2(Layer):
 
 class MemoryV2(Layer):
     def __init__(self, name, extra_input=None, **kwargs):
+        """
+        Init memory object, if memory is inited inside recurrent_group step
+        function, it may depend on a boot_layer that should be initialized
+        outside recurrent_group, so we:
+        1. add RecurrentLayerInput to extra_parent of self.
+        2. add boot_layer to the extra_parent of RecurrentLayerInput.
+
+        :param extra_input: list of RecurrentLayerInput
+        :type extra_input: [RecurrentLayerInput]
+        """
         self.name = name
         super(MemoryV2, self).__init__(name=name, parent_layers=dict())
         self.__kwargs__ = kwargs
@@ -164,7 +177,7 @@ class MemoryV2(Layer):
                 extra.append_extra_parent(kwargs['boot_layer'])
             self.__boot_layer_name__ = kwargs['boot_layer'].name
 
-    def to_proto_impl(self, context=None, **kwargs):
+    def to_proto_impl(self, **kwargs):
         args = dict()
         for each in kwargs:
             args[each] = kwargs[each]
@@ -172,7 +185,7 @@ class MemoryV2(Layer):
             args[each] = self.__kwargs__[each]
 
         if self.__boot_layer_name__ is not None:
-            args['boot_layer'] = context[self.__boot_layer_name__]
+            args['boot_layer'] = self.__context__[self.__boot_layer_name__]
 
         size = args.get('size', None)
         if size is not None:
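The docstring added to `MemoryV2` above explains why a memory's `boot_layer` must be parsed outside the recurrent group: the boot layer becomes an extra parent of the group's `RecurrentLayerInput`. A hedged sketch of the situation it covers, a decoder memory booted from an encoder state, is shown below; `encoder_last`, `decoder_dim`, `trg_word_embedding`, and the layer names are illustrative assumptions, not part of this commit.

    # Illustrative only: a step function whose memory is booted from a layer
    # built outside the recurrent group.  The extra_parent bookkeeping added
    # in this commit is what guarantees encoder_last is parsed before the
    # group itself.
    def decoder_step(current_word):
        decoder_state = paddle.layer.memory(
            name='decoder_state',
            size=decoder_dim,          # assumed decoder size
            boot_layer=encoder_last)   # defined outside recurrent_group
        out = paddle.layer.fc(
            input=[current_word, decoder_state],
            size=decoder_dim,
            act=paddle.activation.Tanh(),
            name='decoder_state')      # same name, so it updates the memory
        return out

    decoded = paddle.layer.recurrent_group(
        step=decoder_step, input=[trg_word_embedding], name='decoder_group')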
@@ -205,6 +218,66 @@ class StaticInputV2(object):
         # assert input.size is not None or size is not None
+
+
+class BaseGeneratedInputV2(object):
+    def __init__(self):
+        self.bos_id = None
+        self.eos_id = None
+
+    def before_real_step(self):
+        raise NotImplementedError()
+
+    def after_real_step(self, *args):
+        raise NotImplementedError()
+
+
+class GeneratedInputV2(BaseGeneratedInputV2):
+    def __init__(self, size, embedding_name, embedding_size):
+        super(GeneratedInputV2, self).__init__()
+        self.size = size
+        self.embedding_name = embedding_name
+        self.embedding_size = embedding_size
+
+    def after_real_step(self, input):
+        return max_id(input=input, name='__beam_search_predict__')
+
+    def before_real_step(self):
+        predict_id = memory(
+            name='__beam_search_predict__',
+            size=self.size,
+            boot_with_const_id=self.bos_id)
+
+        trg_emb = embedding(
+            input=predict_id,
+            size=self.embedding_size,
+            param_attr=attr.ParamAttr(name=self.embedding_name))
+        return trg_emb
+
+
+class RecurrentLayerGroupSetGeneratorV2(Layer):
+    def __init__(self, eos_name, max_length, beam_size,
+                 num_results_per_sample):
+        self.eos_name = eos_name
+        self.max_length = max_length
+        self.beam_size = beam_size
+        self.num_results_per_sample = num_results_per_sample
+        super(RecurrentLayerGroupSetGeneratorV2, self).__init__(
+            name=eos_name, parent_layers={})
+
+    def to_proto_impl(self, **kwargs):
+        RecurrentLayerGroupSetGenerator(
+            Generator(
+                eos_layer_name=self.eos_name,
+                max_num_frames=self.max_length,
+                beam_size=self.beam_size,
+                num_results_per_sample=self.num_results_per_sample))
+        return self
+
+    def context_name(self):
+        return self.eos_name + ".fake"
+
+    def use_context_name(self):
+        return True
 
 
 class MixedLayerV2(Layer):
     """
     This class is use to support `with` grammar. If not, the following code
@@ -254,7 +327,7 @@ class MixedLayerV2(Layer):
     def __exit__(self, *args, **kwargs):
         self.finalized = True
 
-    def to_proto_impl(self, context=None, **kwargs):
+    def to_proto_impl(self, **kwargs):
         args = dict()
         for each in kwargs:
             args[each] = kwargs[each]
@@ -300,7 +373,7 @@ class RecurrentLayerInput(Layer):
     def context_name(self):
         return self.__recurrent_name__ + ".begin"
 
-    def to_proto_impl(self, context=None, **kwargs):
+    def to_proto_impl(self, **kwargs):
         model_type('recurrent_nn')
         RecurrentLayerGroupWithoutOutLinksBegin(
             name=self.__recurrent_name__,
@@ -319,7 +392,7 @@ class RecurrentLayerOutput(Layer):
     def context_name(self):
         return self.__recurrent_name__ + ".end"
 
-    def to_proto_impl(self, context=None, **kwargs):
+    def to_proto_impl(self, **kwargs):
         for l in self.__parents__:
             RecurrentLayerGroupSetOutLink(l.name)
         RecurrentLayerGroupEnd(name=self.__recurrent_name__)
@@ -418,8 +491,7 @@ def recurrent_group(step, input, name=None):
                     size=static_input.input.calculate_size,
                     act=activation.Identity()) as mix:
                 mix += identity_projection(input=mem)
-            mem.append_child(layer=mix, parent_names=[mem.context_name()])
-            rnn_input.insert(input.index(static_input), mem)
+            rnn_input.insert(input.index(static_input), mix)
         return step(*rnn_input)
 
     actual_output = __real_step__(*actual_input)
@@ -440,6 +512,74 @@ def recurrent_group(step, input, name=None):
 
     return retv
 
+
+@wrap_name_default()
+def beam_search(step,
+                input,
+                bos_id,
+                eos_id,
+                beam_size,
+                max_length=500,
+                name=None,
+                num_results_per_sample=None):
+    if num_results_per_sample is None:
+        num_results_per_sample = beam_size
+    assert num_results_per_sample <= beam_size
+    # logger.warning("num_results_per_sample should be less than beam_size")
+
+    if isinstance(input, StaticInputV2) or isinstance(
+            input, BaseGeneratedInputV2):
+        input = [input]
+
+    generated_input_index = -1
+
+    real_input = []
+    for i, each_input in enumerate(input):
+        assert isinstance(each_input, StaticInputV2) or isinstance(
+            each_input, BaseGeneratedInputV2)
+        if isinstance(each_input, BaseGeneratedInputV2):
+            assert generated_input_index == -1
+            generated_input_index = i
+        else:
+            real_input.append(each_input)
+
+    assert generated_input_index != -1
+
+    gipt = input[generated_input_index]
+    assert isinstance(gipt, BaseGeneratedInputV2)
+
+    gipt.bos_id = bos_id
+    gipt.eos_id = eos_id
+
+    def __real_step__(*args):
+        eos_name = "__%s_eos_layer__" % name
+        generator = RecurrentLayerGroupSetGeneratorV2(
+            eos_name, max_length, beam_size, num_results_per_sample)
+
+        args = list(args)
+        before_step_layer = gipt.before_real_step()
+        before_step_layer.append_child(
+            layer=generator, parent_names=[before_step_layer.name])
+        args.insert(generated_input_index, before_step_layer)
+
+        predict = gipt.after_real_step(step(*args))
+
+        eos_layer = eos(input=predict, eos_id=eos_id, name=eos_name)
+        predict.append_child(layer=eos_layer, parent_names=[predict.name])
+
+        return predict
+
+    # tmp = paddle.layer.recurrent_group(
+    #     step=__real_step__,
+    #     input=real_input,
+    #     reverse=False,
+    #     name=name,
+    #     is_generating=True)
+    tmp = recurrent_group(
+        step=__real_step__, input=real_input, name=name)
+
+    return tmp
+
+
 __projection_names__ = filter(lambda x: x.endswith('_projection'),
                               dir(conf_helps))
......
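The control flow of the newly added `beam_search` is easiest to see in isolation. Below is a small, self-contained rehearsal (plain Python, no PaddlePaddle calls) of the argument bookkeeping it performs: the single generated input is singled out, the remaining static inputs drive the recurrent group, and inside the wrapped step the feedback layer produced by `before_real_step()` is re-inserted at the generated input's original position before the user step runs. All names in the example are illustrative.

    # Standalone rehearsal of the input splitting and re-insertion done by
    # beam_search() and its __real_step__ closure above.
    def split_inputs(inputs, is_generated):
        """Return (generated_index, real_inputs), mirroring beam_search."""
        generated_index = -1
        real_inputs = []
        for i, each in enumerate(inputs):
            if is_generated(each):
                assert generated_index == -1, "only one generated input allowed"
                generated_index = i
            else:
                real_inputs.append(each)
        assert generated_index != -1, "exactly one generated input required"
        return generated_index, real_inputs


    if __name__ == '__main__':
        inputs = ['encoder_vector', 'generated_trg_embedding', 'encoder_proj']
        idx, real = split_inputs(inputs, lambda x: x.startswith('generated'))
        # Inside __real_step__, the feedback built by before_real_step() is
        # inserted back at the original position before calling the user step.
        step_args = list(real)
        step_args.insert(idx, '<feedback from before_real_step()>')
        print(step_args)

In the real layer, `after_real_step()` then turns the step output into a `max_id` prediction, and an `eos` layer keyed to `eos_id` terminates finished beams, as shown in the added code above.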
-import paddle.v2 as paddle
-from paddle.v2.config_base import Layer
-from paddle.trainer_config_helpers.default_decorators import wrap_name_default
-from paddle.trainer_config_helpers.layers import RecurrentLayerGroupSetGenerator, Generator
-
-
-class BaseGeneratedInputV2(object):
-    def __init__(self):
-        self.bos_id = None
-        self.eos_id = None
-
-    def before_real_step(self):
-        raise NotImplementedError()
-
-    def after_real_step(self, *args):
-        raise NotImplementedError()
-
-
-class GeneratedInputV2(BaseGeneratedInputV2):
-    def __init__(self, size, embedding_name, embedding_size):
-        super(GeneratedInputV2, self).__init__()
-        self.size = size
-        self.embedding_name = embedding_name
-        self.embedding_size = embedding_size
-
-    def after_real_step(self, input):
-        return paddle.layer.max_id(input=input, name='__beam_search_predict__')
-
-    def before_real_step(self):
-        predict_id = paddle.layer.memory(
-            name='__beam_search_predict__',
-            size=self.size,
-            boot_with_const_id=self.bos_id)
-
-        trg_emb = paddle.layer.embedding(
-            input=predict_id,
-            size=self.embedding_size,
-            param_attr=paddle.attr.ParamAttr(name=self.embedding_name))
-        return trg_emb
-
-
-class RecurrentLayerGroupSetGeneratorV2(Layer):
-    def __init__(self, eos_name, max_length, beam_size,
-                 num_results_per_sample):
-        self.eos_name = eos_name
-        self.max_length = max_length
-        self.beam_size = beam_size
-        self.num_results_per_sample = num_results_per_sample
-        super(RecurrentLayerGroupSetGeneratorV2, self).__init__(
-            name=eos_name, parent_layers={})
-
-    def to_proto_impl(self, context=None, **kwargs):
-        RecurrentLayerGroupSetGenerator(
-            Generator(
-                eos_layer_name=self.eos_name,
-                max_num_frames=self.max_length,
-                beam_size=self.beam_size,
-                num_results_per_sample=self.num_results_per_sample))
-        return self
-
-    def context_name(self):
-        return self.eos_name + ".fake"
-
-    def use_context_name(self):
-        return True
-
-
-@wrap_name_default()
-def beam_search(step,
-                input,
-                bos_id,
-                eos_id,
-                beam_size,
-                max_length=500,
-                name=None,
-                num_results_per_sample=None):
-    if num_results_per_sample is None:
-        num_results_per_sample = beam_size
-    assert num_results_per_sample <= beam_size
-    # logger.warning("num_results_per_sample should be less than beam_size")
-
-    if isinstance(input, paddle.layer.StaticInputV2) or isinstance(
-            input, BaseGeneratedInputV2):
-        input = [input]
-
-    generated_input_index = -1
-
-    real_input = []
-    for i, each_input in enumerate(input):
-        assert isinstance(each_input, paddle.layer.StaticInputV2) or isinstance(
-            each_input, BaseGeneratedInputV2)
-        if isinstance(each_input, BaseGeneratedInputV2):
-            assert generated_input_index == -1
-            generated_input_index = i
-        else:
-            real_input.append(each_input)
-
-    assert generated_input_index != -1
-
-    gipt = input[generated_input_index]
-    assert isinstance(gipt, BaseGeneratedInputV2)
-
-    gipt.bos_id = bos_id
-    gipt.eos_id = eos_id
-
-    def __real_step__(*args):
-        eos_name = "__%s_eos_layer__" % name
-        generator = RecurrentLayerGroupSetGeneratorV2(
-            eos_name, max_length, beam_size, num_results_per_sample)
-
-        args = list(args)
-        before_step_layer = gipt.before_real_step()
-        before_step_layer.append_child(
-            layer=generator, parent_names=[before_step_layer.name])
-        args.insert(generated_input_index, before_step_layer)
-
-        predict = gipt.after_real_step(step(*args))
-
-        eos = paddle.layer.eos(input=predict, eos_id=eos_id, name=eos_name)
-        predict.append_child(layer=eos, parent_names=[predict.name])
-
-        return predict
-
-    # tmp = paddle.layer.recurrent_group(
-    #     step=__real_step__,
-    #     input=real_input,
-    #     reverse=False,
-    #     name=name,
-    #     is_generating=True)
-    tmp = paddle.layer.recurrent_group(
-        step=__real_step__, input=real_input, name=name)
-
-    return tmp
@@ -8,8 +8,7 @@ packages=['paddle',
           'paddle.v2',
           'paddle.v2.dataset',
           'paddle.v2.reader',
-          'paddle.v2.plot',
-          'paddle.v2.layers']
+          'paddle.v2.plot']
 
 setup(name='paddle',
       version='${PADDLE_VERSION}',
......