提交 8d4c453b 编写于 作者: H Haonan 提交者: emailweixu

set mixedlayer output size according to input operator (#414)

* set mixedlayer output size according to input operator
* rename num_channel to num_channels for conv_operator (the old name was
really misleading because all the others use num_channels)

* also changed the arg name in projections.py
上级 5ccf84ab
...@@ -590,7 +590,7 @@ class MixedLayerType(LayerOutput): ...@@ -590,7 +590,7 @@ class MixedLayerType(LayerOutput):
def __exit__(self, *args, **kwargs): def __exit__(self, *args, **kwargs):
del args, kwargs # unused parameter to suppress warning del args, kwargs # unused parameter to suppress warning
assert len(self.inputs) != 0 assert len(self.inputs) != 0
MixedLayer( ml = MixedLayer(
name=self.name, name=self.name,
size=self.size, size=self.size,
active_type=self.activation.name, active_type=self.activation.name,
...@@ -598,6 +598,9 @@ class MixedLayerType(LayerOutput): ...@@ -598,6 +598,9 @@ class MixedLayerType(LayerOutput):
inputs=self.inputs, inputs=self.inputs,
**ExtraLayerAttribute.to_kwargs(self.layer_attr) **ExtraLayerAttribute.to_kwargs(self.layer_attr)
) )
# update the size which might be computed inside MixedLayer
# according to the operator's output size
self.size = ml.config.size
@wrap_name_default("mixed") @wrap_name_default("mixed")
...@@ -2623,7 +2626,7 @@ def out_prod_layer(input1, input2, name=None, layer_attr=None): ...@@ -2623,7 +2626,7 @@ def out_prod_layer(input1, input2, name=None, layer_attr=None):
assert isinstance(input1, LayerOutput) assert isinstance(input1, LayerOutput)
assert isinstance(input2, LayerOutput) assert isinstance(input2, LayerOutput)
Layer(name=name, Layer(name=name,
type="out_prod", type=LayerType.OUT_PROD_LAYER,
inputs=[input1.name, input2.name], inputs=[input1.name, input2.name],
**ExtraLayerAttribute.to_kwargs(layer_attr)) **ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(name=name, return LayerOutput(name=name,
...@@ -2884,7 +2887,7 @@ def classification_cost(input, label, weight=None, name=None, ...@@ -2884,7 +2887,7 @@ def classification_cost(input, label, weight=None, name=None,
def conv_operator(img, filter, filter_size, num_filters, def conv_operator(img, filter, filter_size, num_filters,
num_channel=None, stride=1, padding=0, num_channels=None, stride=1, padding=0,
filter_size_y=None, stride_y=None, padding_y=None): filter_size_y=None, stride_y=None, padding_y=None):
""" """
Different from img_conv_layer, conv_op is an Operator, which can be used Different from img_conv_layer, conv_op is an Operator, which can be used
...@@ -2914,8 +2917,8 @@ def conv_operator(img, filter, filter_size, num_filters, ...@@ -2914,8 +2917,8 @@ def conv_operator(img, filter, filter_size, num_filters,
:type filter_size_y: int :type filter_size_y: int
:param num_filters: channel of output data. :param num_filters: channel of output data.
:type num_filters: int :type num_filters: int
:param num_channel: channel of input data. :param num_channels: channel of input data.
:type num_channel: int :type num_channels: int
:param stride: The x dimension of the stride. :param stride: The x dimension of the stride.
:type stride: int :type stride: int
:param stride_y: The y dimension of the stride. :param stride_y: The y dimension of the stride.
...@@ -2934,19 +2937,19 @@ def conv_operator(img, filter, filter_size, num_filters, ...@@ -2934,19 +2937,19 @@ def conv_operator(img, filter, filter_size, num_filters,
if padding_y is None: if padding_y is None:
padding_y = padding padding_y = padding
if num_channel is None: if num_channels is None:
num_channel = img.num_filters num_channels = img.num_filters
assert isinstance(filter, LayerOutput) assert isinstance(filter, LayerOutput)
if filter.size is not None: if filter.size is not None:
filter.size = filter_size * filter_size_y * num_filters * num_channel filter.size = filter_size * filter_size_y * num_filters * num_channels
op = ConvOperator(input_layer_names=[img.name, filter.name], op = ConvOperator(input_layer_names=[img.name, filter.name],
num_filters=num_filters, num_filters=num_filters,
conv_conf=Conv(filter_size=filter_size, conv_conf=Conv(filter_size=filter_size,
padding=padding, padding=padding,
stride=stride, stride=stride,
channels=num_channel, channels=num_channels,
filter_size_y=filter_size_y, filter_size_y=filter_size_y,
padding_y=padding_y, padding_y=padding_y,
stride_y=stride_y, stride_y=stride_y,
...@@ -2986,8 +2989,8 @@ def conv_projection(input, filter_size, num_filters, ...@@ -2986,8 +2989,8 @@ def conv_projection(input, filter_size, num_filters,
:type filter_size_y: int :type filter_size_y: int
:param num_filters: channel of output data. :param num_filters: channel of output data.
:type num_filters: int :type num_filters: int
:param num_channel: channel of input data. :param num_channels: channel of input data.
:type num_channel: int :type num_channels: int
:param stride: The x dimension of the stride. :param stride: The x dimension of the stride.
:type stride: int :type stride: int
:param stride_y: The y dimension of the stride. :param stride_y: The y dimension of the stride.
......
...@@ -35,7 +35,7 @@ flt = data_layer(name='filter', size=3*3*1*64) ...@@ -35,7 +35,7 @@ flt = data_layer(name='filter', size=3*3*1*64)
with mixed_layer() as m7: with mixed_layer() as m7:
m7 += conv_operator(img=img, filter=flt, num_filters=64, m7 += conv_operator(img=img, filter=flt, num_filters=64,
num_channel=1, filter_size=3) num_channels=1, filter_size=3)
end = mixed_layer(input=[full_matrix_projection(input=m5), end = mixed_layer(input=[full_matrix_projection(input=m5),
trans_full_matrix_projection(input=m6), trans_full_matrix_projection(input=m6),
......
...@@ -29,9 +29,11 @@ z1 = mixed_layer(act=LinearActivation(), ...@@ -29,9 +29,11 @@ z1 = mixed_layer(act=LinearActivation(),
filter=y1, filter=y1,
filter_size=1, filter_size=1,
num_filters=5, num_filters=5,
num_channel=5, num_channels=5,
stride=1)]) stride=1)])
assert z1.size > 0
y2 = fc_layer(input=y, size=15) y2 = fc_layer(input=y, size=15)
cos1 = cos_sim(a=x1, b=y1) cos1 = cos_sim(a=x1, b=y1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册