Commit 1906e63f authored by xzl

fix prelu(add filter_num output_x output_y) and add channel_shared param

Parent 93c6e52a
@@ -2052,9 +2052,15 @@ class ParameterReluLayer(LayerBase):
         config_assert(len(self.inputs) == 1, "prelu layer has only one input.")
         config_assert(input_layer.size % partial_sum == 0,
                       "a wrong setting for partial_sum")
+
+        dims = [1, input_layer.size / partial_sum]
         self.set_layer_size(input_layer.size)
         self.config.partial_sum = partial_sum
-        self.create_input_parameter(0, input_layer.size / partial_sum)
+        self.create_input_parameter(0, input_layer.size / partial_sum, dims)
+
+        self.set_layer_height_width(self.get_input_layer(0).height, \
+                self.get_input_layer(0).width)
+        self.set_layer_depth(self.get_input_layer(0).depth)
 
 @config_layer('conv')
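For context on the hunk above: `dims = [1, input_layer.size / partial_sum]` is the shape of the slope parameter, so every `partial_sum` consecutive elements of the flattened input share one learned slope. A minimal NumPy sketch of that sharing scheme (illustrative only, not code from this commit; it assumes the CHW layout Paddle uses, where the `height * width` elements of one channel are consecutive):

```python
import numpy as np

def prelu_forward(x, w, partial_sum):
    """PReLU with partial weight sharing: each group of `partial_sum`
    consecutive elements of the flattened input x shares one slope w[i]."""
    assert x.size % partial_sum == 0, "a wrong setting for partial_sum"
    slopes = np.repeat(w, partial_sum)     # broadcast each slope over its group
    return np.where(x > 0, x, slopes * x)

# 3 channels of 4x5 feature maps, flattened to 60 elements (CHW order).
x = np.random.randn(60)
w = np.full(60 // 20, 0.25)   # partial_sum = 20 = h*w -> one slope per channel
y = prelu_forward(x, w, 20)
```

The new `set_layer_height_width`/`set_layer_depth` calls propagate the input's spatial shape to the output, which is what lets `prelu_layer` below derive `partial_sum` from `height * width`.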
@@ -6393,10 +6393,11 @@ def row_conv_layer(input,
 @layer_support()
 @wrap_name_default()
 @wrap_param_attr_default()
 def prelu_layer(input,
                 name=None,
-                partial_sum=1,
+                channel_shared=None,
+                num_channels=None,
                 param_attr=None,
                 layer_attr=None):
     """
@@ -6427,6 +6428,10 @@ def prelu_layer(input,
         - partial_sum = number of outputs, indicates all elements share the same weight.
     :type partial_sum: int
+    :param channel_shared: whether or not the parameters are shared across channels.
+        - channel_shared = True, we set the partial_sum to the number of outputs.
+        - channel_shared = False, we set the partial_sum to the number of elements in one channel.
+    :type channel_shared: bool
     :param param_attr: The parameter attribute. See ParameterAttribute for details.
     :type param_attr: ParameterAttribute
     :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
@@ -6437,7 +6442,22 @@ def prelu_layer(input,
     """
     assert isinstance(input, LayerOutput), 'prelu_layer accepts only one input.'
-    assert isinstance(param_attr, ParameterAttribute)
+    if not param_attr:
+        param_attr = ParamAttr(initial_mean=0.25,
+                               initial_std=0.0)
+    else:
+        assert isinstance(param_attr, ParameterAttribute)
+
+    if num_channels is None:
+        assert input.num_filters is not None
+        num_channels = input.num_filters
+
+    if channel_shared is not None:
+        assert isinstance(channel_shared, bool)
+        if channel_shared:
+            partial_sum = input.height * input.width * num_channels
+        else:
+            partial_sum = input.height * input.width
+
     l = Layer(
         name=name,
@@ -6449,6 +6469,7 @@ def prelu_layer(input,
         name=name,
         layer_type=LayerType.PRELU,
         parents=input,
+        num_filters=num_channels,
         size=l.config.size)
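To make the new `channel_shared` switch concrete, here is the arithmetic the added body performs, on a hypothetical 3x4x5 input (values are illustrative, not from the commit):

```python
num_channels, height, width = 3, 4, 5     # hypothetical CHW input shape
size = num_channels * height * width      # 60 outputs

# channel_shared=True: every output shares a single slope.
partial_sum = height * width * num_channels   # 60 -> size / partial_sum = 1 weight

# channel_shared=False: one slope per channel.
partial_sum = height * width                  # 20 -> size / partial_sum = 3 weights
```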
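And a hedged usage sketch of the updated helper (layer names and shapes are illustrative, not from this commit; it assumes the input layer carries `height`, `width`, and `num_filters`, which the first hunk now propagates):

```python
from paddle.trainer_config_helpers import *

data = data_layer(name='image', size=3 * 32 * 32, height=32, width=32)
conv = img_conv_layer(input=data, num_channels=3, num_filters=8,
                      filter_size=3, padding=1, act=LinearActivation())

# channel_shared=False -> partial_sum = height * width, i.e. one learned
# slope per channel (8 weights here); True would give a single shared slope.
prelu = prelu_layer(input=conv, channel_shared=False)
```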