From 98bc889cb5d97215c62300bf30da45165d34b175 Mon Sep 17 00:00:00 2001 From: Haonan Date: Wed, 21 Sep 2016 14:00:57 -0700 Subject: [PATCH] split the input list of conv_operator into two inputs: image and filter (#104) --- python/paddle/trainer_config_helpers/layers.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index a5bacaf07..9963b3813 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -2667,7 +2667,7 @@ def classification_cost(input, label, name=None, return LayerOutput(name, LayerType.COST, parents=[input, label]) -def conv_operator(input, filter_size, num_filters, +def conv_operator(img, filter, filter_size, num_filters, num_channel=None, stride=1, padding=0, groups=1, filter_size_y=None, stride_y=None, padding_y=None): """ @@ -2680,13 +2680,16 @@ def conv_operator(input, filter_size, num_filters, .. code-block:: python - op = conv_operator(input=[layer1, layer2], + op = conv_operator(img=input1, + filter=input2, filter_size=3.0, num_filters=64, num_channels=64) - :param input: Input layer. - :type input: LayerOutput|list|tuple + :param img: input image + :type img: LayerOutput + :param filter: input filter + :type filter: LayerOutput :param filter_size: The x dimension of a filter kernel. :type filter_size: int :param filter_size_y: The y dimension of a filter kernel. Since @@ -2708,14 +2711,13 @@ def conv_operator(input, filter_size, num_filters, :return: A ConvOperator Object. :rtype: ConvOperator """ - assert isinstance(input, list) or isinstance(input, tuple) if filter_size_y is None: filter_size_y = filter_size if stride_y is None: stride_y = stride if padding_y is None: padding_y = padding - op = ConvOperator(input_layer_names=[x.name for x in input], + op = ConvOperator(input_layer_names=[img.name, filter.name], num_filters = num_filter, conv_conf=Conv(filter_size=filter_size, padding=padding, @@ -2725,7 +2727,7 @@ def conv_operator(input, filter_size, num_filters, padding_y=padding_y, stride_y=stride_y, groups=groups)) - op.origin = input + op.origin = [img, filter] op.origin.operator = "conv_op" return op -- GitLab