提交 d7b80f03 编写于 作者: X xuwei06

Correctly handle width and height for some layers

上级 0d1715da
......@@ -338,7 +338,8 @@ def RecurrentLayerGroupWithoutOutLinksBegin(name,
in_links_count += 1
layer_name = MakeLayerNameInParentSubmodel(name)
layer = g_layer_map[layer_name]
ScatterAgentLayer(name=name, size=layer.size)
ScatterAgentLayer(
name=name, size=layer.size, width=layer.width, height=layer.height)
pair = g_current_submodel.in_links.add()
pair.layer_name = layer_name
......@@ -2197,8 +2198,8 @@ class MaxOutLayer(LayerBase):
maxout_conf = self.config.inputs[0].maxout_conf
parse_maxout(self.inputs[0].maxout, input_layer.name, maxout_conf)
out_channels = maxout_conf.image_conf.channels / maxout_conf.groups
self.set_cnn_layer(name, g_layer_map[input_layer.name].height,
g_layer_map[input_layer.name].width, out_channels)
self.set_cnn_layer(name, maxout_conf.image_conf.img_size_y,
maxout_conf.image_conf.img_size, out_channels)
@config_layer('row_conv')
......@@ -2405,9 +2406,11 @@ class GatherAgentLayer(LayerBase):
@config_layer('scatter_agent')
class ScatterAgentLayer(LayerBase):
def __init__(self, name, size, device=None):
def __init__(self, name, size, width=None, height=None, device=None):
super(ScatterAgentLayer, self).__init__(
name, 'scatter_agent', size, inputs=[], device=device)
if height and width:
self.set_layer_height_width(height, width)
@config_layer('multiplex')
......
......@@ -16,11 +16,13 @@ import functools
import collections
import inspect
import paddle.trainer.config_parser as cp
from paddle.trainer.config_parser import *
from .activations import LinearActivation, SigmoidActivation, TanhActivation, \
ReluActivation, IdentityActivation, SoftmaxActivation, BaseActivation
from .evaluators import *
from .poolings import MaxPooling, AvgPooling, BasePoolingType
from .poolings import MaxPooling, AvgPooling, BasePoolingType, \
CudnnAvgPooling, CudnnMaxPooling
from .attrs import *
from .default_decorators import *
......@@ -330,6 +332,14 @@ class LayerOutput(object):
self.outputs = outputs
self.reverse = reverse
@property
def width(self):
return cp.g_layer_map[self.full_name].width
@property
def height(self):
return cp.g_layer_map[self.full_name].height
def set_input(self, input):
"""
Set the input for a memory layer. Can only be used for memory layer
......@@ -911,7 +921,13 @@ def data_layer(name, size, height=None, width=None, layer_attr=None):
width=width,
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(name, LayerType.DATA, size=size)
num_filters = None
if height is not None and width is not None:
num_filters = size / (width * height)
assert num_filters * width * height == size, \
"size=%s width=%s height=%s" % (size, width, height)
return LayerOutput(name, LayerType.DATA, size=size, num_filters=num_filters)
@wrap_name_default("embedding")
......@@ -2571,6 +2587,10 @@ def img_pool_layer(input,
assert input.num_filters is not None
num_channels = input.num_filters
assert type(pool_type) in [AvgPooling, MaxPooling, CudnnAvgPooling,
CudnnMaxPooling], \
"only AvgPooling and MaxPooling are supported"
if pool_type is None:
pool_type = MaxPooling()
elif isinstance(pool_type, AvgPooling):
......@@ -2580,7 +2600,6 @@ def img_pool_layer(input,
if (
isinstance(pool_type, AvgPooling) or isinstance(pool_type, MaxPooling)) \
else pool_type.name
pool_size_y = pool_size if pool_size_y is None else pool_size_y
stride_y = stride if stride_y is None else stride_y
padding_y = padding if padding_y is None else padding_y
......@@ -4204,8 +4223,7 @@ def conv_operator(img,
num_channels = img.num_filters
assert isinstance(filter, LayerOutput)
if filter.size is not None:
filter.size = filter_size * filter_size_y * num_filters * num_channels
assert filter.size is not None
opCls = ConvTransOperator if trans else ConvOperator
......@@ -4916,7 +4934,6 @@ def maxout_layer(input, groups, num_channels=None, name=None, layer_attr=None):
:return: LayerOutput object.
:rtype: LayerOutput
"""
assert input.layer_type == LayerType.CONV_LAYER
assert isinstance(input.activation, LinearActivation)
assert groups > 1
if num_channels is None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册