提交 dc9f31b3 编写于 作者: H hedaoyuan

Add SliceProjection and slice_projection.

上级 b9767aea
......@@ -225,7 +225,8 @@ message ProjectionConfig {
optional PoolConfig pool_conf = 12;
// For slice
repeated SliceConfig slice = 13;
// Each slice output is the input[start, end)
repeated SliceConfig slices = 13;
}
message OperatorConfig {
......
......@@ -565,6 +565,35 @@ class IdentityOffsetProjection(Projection):
return []
@config_class
class SliceProjection(Projection):
type = 'slice'
def __init__(self, input_layer_name, slices, **xargs):
super(SliceProjection, self).__init__(input_layer_name, **xargs)
input = g_layer_map[input_layer_name]
if input.type in ["exconv", "cudnn_conv"]:
# the slice operator is for the channel dimension
assert input.num_filters is not None
channels = input.num_filters
image_size = input.size / channels
assert slices[len(slices) - 1][1] <= channels
for i in xrange(len(slices)):
slice = self.proj_conf.slices.add()
slice.start = slices[i][0] * image_size
slice.end = slices[i][1] * image_size
self.size += slice.end - slice.start
else:
config_assert(False,
'Currently the input should be convolution layer')
def calc_parameter_size(self, input_size, output_size):
return 0
def calc_parameter_dims(self, input_size, output_size):
return []
# DotMulProjection performs element-wise multiplication with weight
@config_class
class DotMulProjection(Projection):
......
......@@ -128,6 +128,7 @@ __all__ = [
'prelu_layer',
'gated_unit_layer',
'crop_layer',
'slice_projection',
]
......@@ -536,6 +537,45 @@ def identity_projection(input, offset=None, size=None):
return proj
def slice_projection(input, slices):
"""
slice_projection can get multiple outputs, and each output is a slice
of the input.
.. math::
output[i] = input.slice(slices[i])
The example usage is:
.. code-block:: python
proj = slice_projection(input=layer, slices=[(0, 10), (20, 30)])
Note that slice_projection should not have any parameter.
:param input: Input Layer.
:type input: LayerOutput
:param slices: An array of slice parameters.
Each slice contains the start and end offsets based
on the input.
:type offset: pair of int
:return: A SliceProjection object
:rtype: SliceProjection
"""
assert len(slices) >= 1
start = 0
for i in xrange(len(slices)):
assert len(slices[i]) == 2
# The start position of the next slice needs to be greater than
# or equal to the end position of the previous slice.
assert slices[i][0] >= start
assert slices[i][1] >= slices[i][0]
start = slices[i][1]
proj = SliceProjection(input_layer_name=input.name, slices=slices)
proj.origin = input
return proj
@wrap_param_attr_default()
def scaling_projection(input, param_attr=None):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册