提交 98bc889c 编写于 作者: H Haonan 提交者: emailweixu

split the input list of conv_operator into two inputs: image and filter (#104)

上级 b130ba73
...@@ -2667,7 +2667,7 @@ def classification_cost(input, label, name=None, ...@@ -2667,7 +2667,7 @@ def classification_cost(input, label, name=None,
return LayerOutput(name, LayerType.COST, parents=[input, label]) return LayerOutput(name, LayerType.COST, parents=[input, label])
def conv_operator(input, filter_size, num_filters, def conv_operator(img, filter, filter_size, num_filters,
num_channel=None, stride=1, padding=0, groups=1, num_channel=None, stride=1, padding=0, groups=1,
filter_size_y=None, stride_y=None, padding_y=None): filter_size_y=None, stride_y=None, padding_y=None):
""" """
...@@ -2680,13 +2680,16 @@ def conv_operator(input, filter_size, num_filters, ...@@ -2680,13 +2680,16 @@ def conv_operator(input, filter_size, num_filters,
.. code-block:: python .. code-block:: python
op = conv_operator(input=[layer1, layer2], op = conv_operator(img=input1,
filter=input2,
filter_size=3.0, filter_size=3.0,
num_filters=64, num_filters=64,
num_channels=64) num_channels=64)
:param input: Input layer. :param img: input image
:type input: LayerOutput|list|tuple :type img: LayerOutput
:param filter: input filter
:type filter: LayerOutput
:param filter_size: The x dimension of a filter kernel. :param filter_size: The x dimension of a filter kernel.
:type filter_size: int :type filter_size: int
:param filter_size_y: The y dimension of a filter kernel. Since :param filter_size_y: The y dimension of a filter kernel. Since
...@@ -2708,14 +2711,13 @@ def conv_operator(input, filter_size, num_filters, ...@@ -2708,14 +2711,13 @@ def conv_operator(input, filter_size, num_filters,
:return: A ConvOperator Object. :return: A ConvOperator Object.
:rtype: ConvOperator :rtype: ConvOperator
""" """
assert isinstance(input, list) or isinstance(input, tuple)
if filter_size_y is None: if filter_size_y is None:
filter_size_y = filter_size filter_size_y = filter_size
if stride_y is None: if stride_y is None:
stride_y = stride stride_y = stride
if padding_y is None: if padding_y is None:
padding_y = padding padding_y = padding
op = ConvOperator(input_layer_names=[x.name for x in input], op = ConvOperator(input_layer_names=[img.name, filter.name],
num_filters = num_filter, num_filters = num_filter,
conv_conf=Conv(filter_size=filter_size, conv_conf=Conv(filter_size=filter_size,
padding=padding, padding=padding,
...@@ -2725,7 +2727,7 @@ def conv_operator(input, filter_size, num_filters, ...@@ -2725,7 +2727,7 @@ def conv_operator(input, filter_size, num_filters,
padding_y=padding_y, padding_y=padding_y,
stride_y=stride_y, stride_y=stride_y,
groups=groups)) groups=groups))
op.origin = input op.origin = [img, filter]
op.origin.operator = "conv_op" op.origin.operator = "conv_op"
return op return op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册