提交 ba51e6ea 编写于 作者: Q qiaolongfei

fix style problem

上级 d6e8d5cd
......@@ -131,8 +131,9 @@ def gru_encoder_decoder(data_conf,
decoder_group_name = "decoder_group"
group_inputs = [
StaticInput(input=encoded_vector, is_seq=True),
StaticInput(input=encoded_proj, is_seq=True)
StaticInput(
input=encoded_vector, is_seq=True), StaticInput(
input=encoded_proj, is_seq=True)
]
if not is_generating:
......
......@@ -114,7 +114,8 @@ class Layer(object):
# 4. parse myself and add myself into context.
ret_val = self.to_proto_impl(context=context, **kwargs)
if self.context_name() is not None and self.context_name() not in context:
if self.context_name() is not None and self.context_name(
) not in context:
context[self.context_name()] = ret_val
# 5. parse children that should be pased after this layer.
......
......@@ -292,7 +292,8 @@ class RecurrentLayerInput(Layer):
else:
self.__parents__ = parent_layers.values()[0]
self.__recurrent_name__ = recurrent_name
name = self.__parents__[index].name if index >= 0 else self.context_name()
name = self.__parents__[
index].name if index >= 0 else self.context_name()
super(RecurrentLayerInput, self).__init__(
name=name, parent_layers=parent_layers)
......@@ -402,9 +403,7 @@ def recurrent_group(step, input, name=None):
extra_input = None
if len(non_static_inputs) == 0:
extra_input = RecurrentLayerInput(
recurrent_name=name,
index=-1,
parent_layers={})
recurrent_name=name, index=-1, parent_layers={})
def __real_step__(*args):
rnn_input = list(args)
......
import beam_search
\ No newline at end of file
import beam_search
......@@ -63,6 +63,7 @@ class RecurrentLayerGroupSetGeneratorV2(Layer):
def use_context_name(self):
return True
@wrap_name_default()
def beam_search(step,
input,
......@@ -75,9 +76,10 @@ def beam_search(step,
if num_results_per_sample is None:
num_results_per_sample = beam_size
assert num_results_per_sample <= beam_size
# logger.warning("num_results_per_sample should be less than beam_size")
# logger.warning("num_results_per_sample should be less than beam_size")
if isinstance(input, paddle.layer.StaticInputV2) or isinstance(input, BaseGeneratedInputV2):
if isinstance(input, paddle.layer.StaticInputV2) or isinstance(
input, BaseGeneratedInputV2):
input = [input]
generated_input_index = -1
......@@ -107,8 +109,8 @@ def beam_search(step,
args = list(args)
before_step_layer = gipt.before_real_step()
before_step_layer.append_child(layer=generator,
parent_names=[before_step_layer.name])
before_step_layer.append_child(
layer=generator, parent_names=[before_step_layer.name])
args.insert(generated_input_index, before_step_layer)
predict = gipt.after_real_step(step(*args))
......@@ -125,8 +127,6 @@ def beam_search(step,
# name=name,
# is_generating=True)
tmp = paddle.layer.recurrent_group(
step=__real_step__,
input=real_input,
name=name)
step=__real_step__, input=real_input, name=name)
return tmp
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册