prelu layer's output size is not fully specified in the Python API
Created by: NHZlX
The channels, height, and width of the prelu layer's output should be the same as those of its input, but the Python API does not set them on the output layer.
As a result, errors occur in the following two situations:
First case (the layer after prelu is not given an explicit input channel count):
#!/usr/bin/env python
# coding=utf-8
import paddle.v2 as paddle
def conv_layer(input,
               ch_out,
               filter_size,
               stride,
               padding=0,
               ch_in=None):
    """Build an image convolution layer followed by a PReLU layer.

    Args:
        input: Layer fed into ``paddle.layer.img_conv``.
        ch_out: Number of convolution filters (``num_filters``).
        filter_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding (default 0).
        ch_in: Number of input channels (``num_channels``). When left as
            None, ``img_conv`` presumably has to derive it from the input
            layer (see the NOTE below).

    Returns:
        The output layer of ``paddle.layer.prelu``.
    """
    tmp = paddle.layer.img_conv(
        input=input,
        filter_size=filter_size,
        num_channels=ch_in,
        num_filters=ch_out,
        stride=stride,
        padding=padding)
    # NOTE: prelu does not set num_filters on its output layer. Passing the
    # result into another conv_layer call without ch_in therefore fails
    # inside img_conv_layer with `assert input.num_filters is not None`.
    tmp = paddle.layer.prelu(tmp, partial_sum=1)
    return tmp
# Input: a flat dense vector of 3 * 224 * 224 values; conv1 below reads it
# as 3 channels (ch_in = 3).
data_size = 3 * 224 * 224
img = paddle.layer.data(name="image", type=paddle.data_type.dense_vector(data_size))
conv1 = conv_layer(img, 3, 1, 1, 0, ch_in = 3)
# ch_in is omitted here, so img_conv must take the channel count from conv1,
# which is a prelu layer. This is the call that raises the AssertionError.
conv2 = conv_layer(conv1, 13, 3, 2, 1)
error:
Traceback (most recent call last):
File "test1.py", line 33, in <module>
conv2 = conv_layer(conv1, 13, 3, 2, 1)
File "test1.py", line 19, in conv_layer
padding=padding)
File "/home/xingzhaolong/.jumbo/lib/python2.7/site-packages/paddle/v2/config_base.py", line 52, in wrapped
out = f(*args, **xargs)
File "/home/xingzhaolong/.jumbo/lib/python2.7/site-packages/paddle/trainer_config_helpers/default_decorators.py", line 53, in __wrapper__
return func(*args, **kwargs)
File "/home/xingzhaolong/.jumbo/lib/python2.7/site-packages/paddle/trainer_config_helpers/default_decorators.py", line 53, in __wrapper__
return func(*args, **kwargs)
File "/home/xingzhaolong/.jumbo/lib/python2.7/site-packages/paddle/trainer_config_helpers/default_decorators.py", line 53, in __wrapper__
return func(*args, **kwargs)
File "/home/xingzhaolong/.jumbo/lib/python2.7/site-packages/paddle/trainer_config_helpers/default_decorators.py", line 53, in __wrapper__
return func(*args, **kwargs)
File "/home/xingzhaolong/.jumbo/lib/python2.7/site-packages/paddle/trainer_config_helpers/layers.py", line 403, in wrapper
return method(*args, **kwargs)
File "/home/xingzhaolong/.jumbo/lib/python2.7/site-packages/paddle/trainer_config_helpers/layers.py", line 2487, in img_conv_layer
assert input.num_filters is not None
AssertionError
Second case (addto combining the conv output with the prelu output):
#!/usr/bin/env python
# coding=utf-8
import paddle.v2 as paddle
def conv_layer(input,
               ch_out,
               filter_size,
               stride,
               padding=0,
               ch_in=None):
    """Build a conv layer, apply PReLU to it, and combine both with addto.

    Args:
        input: Layer fed into ``paddle.layer.img_conv``.
        ch_out: Number of convolution filters (``num_filters``).
        filter_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding (default 0).
        ch_in: Number of input channels (``num_channels``), or None.

    Returns:
        The output layer of ``paddle.layer.addto`` over the conv output and
        the prelu output.
    """
    conv = paddle.layer.img_conv(
        input=input,
        filter_size=filter_size,
        num_channels=ch_in,
        num_filters=ch_out,
        stride=stride,
        padding=padding)
    tmp = paddle.layer.prelu(conv, partial_sum=1)
    # NOTE: addto appears to check each input's height (config_parser raises
    # an AssertionError on `.height`). Because prelu does not set
    # height/width on its output, this call fails.
    tmp = paddle.layer.addto(input=[conv, tmp])
    return tmp
# Input: a flat dense vector of 3 * 224 * 224 values; conv1 below reads it
# as 3 channels (ch_in = 3).
data_size = 3 * 224 * 224
img = paddle.layer.data(name="image", type=paddle.data_type.dense_vector(data_size))
# A single conv_layer call is enough to reach the failing addto.
conv1 = conv_layer(img, 3, 1, 1, 0, ch_in = 3)
error:
Traceback (most recent call last):
File "test1.py", line 28, in <module>
conv1 = conv_layer(img, 3, 1, 1, 0, ch_in = 3)
File "test1.py", line 22, in conv_layer
tmp = paddle.layer.addto(input=[conv, tmp])
File "/home/xingzhaolong/.jumbo/lib/python2.7/site-packages/paddle/v2/config_base.py", line 52, in wrapped
out = f(*args, **xargs)
File "/home/xingzhaolong/.jumbo/lib/python2.7/site-packages/paddle/trainer_config_helpers/default_decorators.py", line 53, in __wrapper__
return func(*args, **kwargs)
File "/home/xingzhaolong/.jumbo/lib/python2.7/site-packages/paddle/trainer_config_helpers/default_decorators.py", line 53, in __wrapper__
return func(*args, **kwargs)
File "/home/xingzhaolong/.jumbo/lib/python2.7/site-packages/paddle/trainer_config_helpers/default_decorators.py", line 53, in __wrapper__
return func(*args, **kwargs)
File "/home/xingzhaolong/.jumbo/lib/python2.7/site-packages/paddle/trainer_config_helpers/layers.py", line 403, in wrapper
return method(*args, **kwargs)
File "/home/xingzhaolong/.jumbo/lib/python2.7/site-packages/paddle/trainer_config_helpers/layers.py", line 3315, in addto_layer
**ExtraLayerAttribute.to_kwargs(layer_attr))
File "/home/xingzhaolong/.jumbo/lib/python2.7/site-packages/paddle/trainer/config_parser.py", line 3834, in Layer
return layer_func(name, **xargs)
File "/home/xingzhaolong/.jumbo/lib/python2.7/site-packages/paddle/trainer/config_parser.py", line 2819, in __init__
input_index).height
AssertionError