提交 a8d44cc5 编写于 作者: X xuwei06

Fix size for lstm, identity_projection and concat

上级 3070dd56
...@@ -482,7 +482,7 @@ def table_projection(input, size=0, param_attr=None): ...@@ -482,7 +482,7 @@ def table_projection(input, size=0, param_attr=None):
return proj return proj
def identity_projection(input, offset=None): def identity_projection(input, offset=None, size=None):
""" """
1. IdentityProjection if offset=None. It performs: 1. IdentityProjection if offset=None. It performs:
...@@ -523,8 +523,10 @@ def identity_projection(input, offset=None): ...@@ -523,8 +523,10 @@ def identity_projection(input, offset=None):
proj = IdentityProjection(input_layer_name=input.name) proj = IdentityProjection(input_layer_name=input.name)
proj.origin = input proj.origin = input
else: else:
if size is None:
size = input.size - offset
proj = IdentityOffsetProjection( proj = IdentityOffsetProjection(
input_layer_name=input.name, offset=offset) input_layer_name=input.name, offset=offset, size=size)
proj.origin = input proj.origin = input
return proj return proj
...@@ -2797,7 +2799,7 @@ def concat_layer(input, act=None, name=None, layer_attr=None, bias_attr=None): ...@@ -2797,7 +2799,7 @@ def concat_layer(input, act=None, name=None, layer_attr=None, bias_attr=None):
if layer_type == LayerType.CONCAT_LAYER: if layer_type == LayerType.CONCAT_LAYER:
assert not bias_attr assert not bias_attr
Layer( layer = Layer(
name=name, name=name,
type=layer_type, type=layer_type,
inputs=[x.name for x in input] if is_concat_layer else input, inputs=[x.name for x in input] if is_concat_layer else input,
...@@ -2805,13 +2807,7 @@ def concat_layer(input, act=None, name=None, layer_attr=None, bias_attr=None): ...@@ -2805,13 +2807,7 @@ def concat_layer(input, act=None, name=None, layer_attr=None, bias_attr=None):
bias=ParamAttr.to_bias(bias_attr), bias=ParamAttr.to_bias(bias_attr),
**ExtraLayerAttribute.to_kwargs(layer_attr)) **ExtraLayerAttribute.to_kwargs(layer_attr))
sz = 0 sz = layer.config.size
for each_input in input:
if each_input.size is not None:
sz += each_input.size
else:
sz = None
break
return LayerOutput( return LayerOutput(
name, name,
...@@ -2979,7 +2975,7 @@ def memory(name, ...@@ -2979,7 +2975,7 @@ def memory(name,
@layer_support() @layer_support()
def lstm_step_layer(input, def lstm_step_layer(input,
state, state,
size, size=None,
act=None, act=None,
name=None, name=None,
gate_act=None, gate_act=None,
...@@ -3045,6 +3041,9 @@ def lstm_step_layer(input, ...@@ -3045,6 +3041,9 @@ def lstm_step_layer(input,
:return: LayerOutput object. :return: LayerOutput object.
:rtype: LayerOutput :rtype: LayerOutput
""" """
assert size is None or state.size == size
size = state.size
Layer( Layer(
name=name, name=name,
type=LayerType.LSTM_STEP_LAYER, type=LayerType.LSTM_STEP_LAYER,
...@@ -3052,7 +3051,7 @@ def lstm_step_layer(input, ...@@ -3052,7 +3051,7 @@ def lstm_step_layer(input,
active_gate_type=gate_act.name, active_gate_type=gate_act.name,
active_state_type=state_act.name, active_state_type=state_act.name,
bias=ParamAttr.to_bias(bias_attr), bias=ParamAttr.to_bias(bias_attr),
size=size, size=state.size,
inputs=[input.name, state.name], inputs=[input.name, state.name],
**ExtraLayerAttribute.to_kwargs(layer_attr)) **ExtraLayerAttribute.to_kwargs(layer_attr))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册