Unverified commit 51e7c26f, authored by Zhaolong Xing, committed by GitHub

Merge pull request #5641 from NHZlX/fix_prelu

Fix prelu python api 
@@ -2037,13 +2037,20 @@ class ParameterReluLayer(LayerBase):
     def __init__(self, name, inputs, partial_sum=1, **args):
         super(ParameterReluLayer, self).__init__(
             name, self.layer_type, 0, inputs=inputs, **args)
         input_layer = self.get_input_layer(0)
         config_assert(len(self.inputs) == 1, "prelu layer has only one input.")
         config_assert(input_layer.size % partial_sum == 0,
                       "a wrong setting for partial_sum")
+        dims = [1, input_layer.size / partial_sum]
         self.set_layer_size(input_layer.size)
         self.config.partial_sum = partial_sum
-        self.create_input_parameter(0, input_layer.size / partial_sum)
+        self.create_input_parameter(0, input_layer.size / partial_sum, dims)
+
+        self.set_layer_height_width(self.get_input_layer(0).height, \
+                                    self.get_input_layer(0).width)
+        self.set_layer_depth(self.get_input_layer(0).depth)


 @config_layer('conv')
......
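The net effect of this change is that the PReLU weight is registered as an explicit 1 x (size / partial_sum) matrix rather than a bare vector, and the layer now carries the input's height, width, and depth through to the output. A minimal sketch of the shape rule (the helper name is illustrative, not part of the patch):

def prelu_param_dims(input_size, partial_sum):
    # Each group of `partial_sum` consecutive outputs shares one weight,
    # so the parameter stores input_size / partial_sum entries.
    assert input_size % partial_sum == 0, "a wrong setting for partial_sum"
    return [1, input_size // partial_sum]

# For the 300-wide input used in the tests below:
# prelu_param_dims(300, 1) == [1, 300]
# prelu_param_dims(300, 5) == [1, 60]
# prelu_param_dims(300, 300) == [1, 1]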
@@ -6604,10 +6604,11 @@ def row_conv_layer(input,

 @layer_support()
 @wrap_name_default()
+@wrap_param_attr_default()
 def prelu_layer(input,
                 name=None,
                 partial_sum=1,
+                channel_shared=None,
+                num_channels=None,
                 param_attr=None,
                 layer_attr=None):
     """
@@ -6638,6 +6639,12 @@ def prelu_layer(input,
         - partial_sum = number of outputs, indicates all elements share the same weight.
     :type partial_sum: int
+    :param channel_shared: whether or not the parameters are shared across channels.
+        - channel_shared = True, we set the partial_sum to the number of outputs.
+        - channel_shared = False, we set the partial_sum to the number of elements in one channel.
+    :type channel_shared: bool
+    :param num_channels: number of input channels.
+    :type num_channels: int
     :param param_attr: The parameter attribute. See ParameterAttribute for details.
     :type param_attr: ParameterAttribute
     :param layer_attr: The extra layer attribute. See ExtraLayerAttribute for
@@ -6648,7 +6655,25 @@ def prelu_layer(input,
     """
     assert isinstance(input, LayerOutput), 'prelu_layer accepts only one input.'
-    assert isinstance(param_attr, ParameterAttribute)
+
+    if not param_attr:
+        param_attr = ParamAttr(initial_mean=0.25, initial_std=0.0)
+    else:
+        assert isinstance(param_attr, ParameterAttribute)
+
+    if num_channels is None:
+        assert input.num_filters is not None, \
+            'the input channel cannot be detected, please specify the num_channels parameter'
+        num_channels = input.num_filters
+
+    if channel_shared is not None:
+        assert isinstance(channel_shared, bool)
+        assert (input.height != 0 and input.width != 0), \
+            'input height and width must be set'
+
+        if channel_shared:
+            partial_sum = input.height * input.width * num_channels
+        else:
+            partial_sum = input.height * input.width

     l = Layer(
         name=name,
@@ -6660,6 +6685,7 @@ def prelu_layer(input,
         name=name,
         layer_type=LayerType.PRELU,
         parents=input,
+        num_filters=num_channels,
         size=l.config.size)
......
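channel_shared is purely a convenience on top of partial_sum: when it is given, it overrides partial_sum with a value derived from the input geometry. A minimal sketch of that mapping, assuming a CHW-shaped input whose height and width are set (the function name is hypothetical):

def resolve_partial_sum(height, width, num_channels, channel_shared):
    # channel_shared=True: one weight shared by the entire output, so the
    # sharing group spans all height * width * num_channels elements.
    # channel_shared=False: one weight per channel, so the group spans the
    # height * width elements of a single channel.
    assert height != 0 and width != 0, 'input height and width must be set'
    if channel_shared:
        return height * width * num_channels
    return height * width

# e.g. height=10, width=10, num_channels=3:
# resolve_partial_sum(10, 10, 3, True) == 300
# resolve_partial_sum(10, 10, 3, False) == 100

The other behavioral change is the default initialization: when no param_attr is passed, the slope is initialized to the constant 0.25 (initial_std=0.0), a common starting value for PReLU.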
@@ -4,6 +4,8 @@ layers {
   type: "data"
   size: 300
   active_type: ""
+  height: 10
+  width: 10
 }
 layers {
   name: "__prelu_layer_0__"
@@ -15,6 +17,9 @@ layers {
     input_parameter_name: "___prelu_layer_0__.w0"
   }
   partial_sum: 1
+  height: 10
+  width: 10
+  depth: 1
 }
 layers {
   name: "__prelu_layer_1__"
@@ -26,6 +31,9 @@ layers {
     input_parameter_name: "___prelu_layer_1__.w0"
   }
   partial_sum: 1
+  height: 10
+  width: 10
+  depth: 1
 }
 layers {
   name: "__prelu_layer_2__"
@@ -37,41 +45,100 @@ layers {
     input_parameter_name: "___prelu_layer_2__.w0"
   }
   partial_sum: 5
+  height: 10
+  width: 10
+  depth: 1
+}
+layers {
+  name: "__prelu_layer_3__"
+  type: "prelu"
+  size: 300
+  active_type: ""
+  inputs {
+    input_layer_name: "input"
+    input_parameter_name: "___prelu_layer_3__.w0"
+  }
+  partial_sum: 300
+  height: 10
+  width: 10
+  depth: 1
+}
+layers {
+  name: "__prelu_layer_4__"
+  type: "prelu"
+  size: 300
+  active_type: ""
+  inputs {
+    input_layer_name: "input"
+    input_parameter_name: "___prelu_layer_4__.w0"
+  }
+  partial_sum: 100
+  height: 10
+  width: 10
+  depth: 1
 }
 parameters {
   name: "___prelu_layer_0__.w0"
   size: 300
-  initial_mean: 0.0
-  initial_std: 0.057735026919
+  initial_mean: 0.25
+  initial_std: 0.0
+  dims: 1
+  dims: 300
   initial_strategy: 0
-  initial_smart: true
+  initial_smart: false
 }
 parameters {
   name: "___prelu_layer_1__.w0"
   size: 300
-  initial_mean: 0.0
-  initial_std: 0.057735026919
+  initial_mean: 0.25
+  initial_std: 0.0
+  dims: 1
+  dims: 300
   initial_strategy: 0
-  initial_smart: true
+  initial_smart: false
 }
 parameters {
   name: "___prelu_layer_2__.w0"
   size: 60
-  initial_mean: 0.0
-  initial_std: 0.129099444874
+  initial_mean: 0.25
+  initial_std: 0.0
+  dims: 1
+  dims: 60
+  initial_strategy: 0
+  initial_smart: false
+}
+parameters {
+  name: "___prelu_layer_3__.w0"
+  size: 1
+  initial_mean: 0.25
+  initial_std: 0.0
+  dims: 1
+  dims: 1
+  initial_strategy: 0
+  initial_smart: false
+}
+parameters {
+  name: "___prelu_layer_4__.w0"
+  size: 3
+  initial_mean: 0.25
+  initial_std: 0.0
+  dims: 1
+  dims: 3
   initial_strategy: 0
-  initial_smart: true
+  initial_smart: false
 }
 input_layer_names: "input"
-output_layer_names: "__prelu_layer_2__"
+output_layer_names: "__prelu_layer_4__"
 sub_models {
   name: "root"
   layer_names: "input"
   layer_names: "__prelu_layer_0__"
   layer_names: "__prelu_layer_1__"
   layer_names: "__prelu_layer_2__"
+  layer_names: "__prelu_layer_3__"
+  layer_names: "__prelu_layer_4__"
   input_layer_names: "input"
-  output_layer_names: "__prelu_layer_2__"
+  output_layer_names: "__prelu_layer_4__"
   is_recurrent_layer_group: false
 }
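The regenerated parameter sizes follow the rule size = output_size / partial_sum. A quick sanity check of the values above (illustrative only, not part of the expectation file):

# The test input is 300 wide (3 channels * 10 * 10).
assert 300 // 5 == 60      # partial_sum=5        -> ___prelu_layer_2__.w0 size 60
assert 300 // 300 == 1     # channel_shared=True  -> ___prelu_layer_3__.w0 size 1
assert 300 // 100 == 3     # channel_shared=False -> ___prelu_layer_4__.w0 size 3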
 from paddle.trainer_config_helpers import *

-data = data_layer(name='input', size=300)
-prelu = prelu_layer(input=data)
-prelu = prelu_layer(input=data, partial_sum=1)
-prelu = prelu_layer(input=data, partial_sum=5)
+data = data_layer(name='input', size=300, height=10, width=10)
+prelu = prelu_layer(input=data, num_channels=3)
+prelu = prelu_layer(input=data, partial_sum=1, num_channels=3)
+prelu = prelu_layer(input=data, partial_sum=5, num_channels=3)
+prelu = prelu_layer(input=data, channel_shared=True, num_channels=3)
+prelu = prelu_layer(input=data, channel_shared=False, num_channels=3)

 outputs(prelu)
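The new defaults can still be overridden by passing param_attr explicitly; a hedged example extending the test config above (the initialization values are illustrative, not part of the test):

# Override the default ParamAttr(initial_mean=0.25, initial_std=0.0);
# the numbers here are only an example.
prelu = prelu_layer(
    input=data,
    num_channels=3,
    param_attr=ParamAttr(initial_mean=0.2, initial_std=0.01))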