You need to sign in or sign up before continuing.
提交 91a0c11b 编写于 作者: C chengduoZH

Adaptive data structure for SwitchOrderLayer

上级 544458e0
...@@ -24,10 +24,12 @@ bool SwitchOrderLayer::init(const LayerMap& layerMap, ...@@ -24,10 +24,12 @@ bool SwitchOrderLayer::init(const LayerMap& layerMap,
/* Initialize the basic parent class */ /* Initialize the basic parent class */
Layer::init(layerMap, parameterMap); Layer::init(layerMap, parameterMap);
auto& img_conf = config_.inputs(0).image_conf(); auto& img_conf = config_.inputs(0).image_conf();
size_t inD = img_conf.img_size_z();
size_t inH = size_t inH =
img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size(); img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
size_t inW = img_conf.img_size(); size_t inW = img_conf.img_size();
size_t inC = img_conf.channels(); size_t inC = img_conf.channels();
inH = inH * inD;
inDims_ = TensorShape({0, inC, inH, inW}); inDims_ = TensorShape({0, inC, inH, inW});
outDims_ = TensorShape(4); outDims_ = TensorShape(4);
...@@ -64,9 +66,10 @@ void SwitchOrderLayer::setInDims() { ...@@ -64,9 +66,10 @@ void SwitchOrderLayer::setInDims() {
MatrixPtr input = inputLayers_[0]->getOutputValue(); MatrixPtr input = inputLayers_[0]->getOutputValue();
size_t batchSize = input->getHeight(); size_t batchSize = input->getHeight();
inDims_.setDim(0, batchSize); inDims_.setDim(0, batchSize);
int d = inputLayers_[0]->getOutput().getFrameDepth();
d = (d == 0 ? 1 : d);
int h = inputLayers_[0]->getOutput().getFrameHeight(); int h = inputLayers_[0]->getOutput().getFrameHeight();
if (h != 0) inDims_.setDim(2, h); if (h != 0) inDims_.setDim(2, h * d);
int w = inputLayers_[0]->getOutput().getFrameWidth(); int w = inputLayers_[0]->getOutput().getFrameWidth();
if (w != 0) inDims_.setDim(3, w); if (w != 0) inDims_.setDim(3, w);
int totalCount = input->getElementCnt(); int totalCount = input->getElementCnt();
......
...@@ -271,6 +271,7 @@ message ImageConfig { ...@@ -271,6 +271,7 @@ message ImageConfig {
// The size of input feature map. // The size of input feature map.
required uint32 img_size = 8; required uint32 img_size = 8;
optional uint32 img_size_y = 9; optional uint32 img_size_y = 9;
optional uint32 img_size_z = 10 [ default = 1 ];
} }
message PriorBoxConfig { message PriorBoxConfig {
......
...@@ -6410,7 +6410,7 @@ def gated_unit_layer(input, ...@@ -6410,7 +6410,7 @@ def gated_unit_layer(input,
@wrap_name_default('switch_order') @wrap_name_default('switch_order')
def switch_order_layer(input, def switch_order_layer(input,
name=None, name=None,
reshape=None, reshape_axis=None,
act=None, act=None,
layer_attr=None): layer_attr=None):
""" """
...@@ -6421,8 +6421,9 @@ def switch_order_layer(input, ...@@ -6421,8 +6421,9 @@ def switch_order_layer(input,
The example usage is: The example usage is:
.. code-block:: python .. code-block:: python
reshape_axis = 3
switch = switch_order(input=layer, name='switch', reshape_axis=reshape_axis)
reshape = {'height':[ 0, 1, 2], 'width':[3]} reshape = {'height':[ 0, 1, 2], 'width':[3]}
switch = switch_order(input=layer, name='switch', reshape=reshape)
:param input: The input layer. :param input: The input layer.
:type input: LayerOutput :type input: LayerOutput
...@@ -6434,6 +6435,11 @@ def switch_order_layer(input, ...@@ -6434,6 +6435,11 @@ def switch_order_layer(input,
:rtype: LayerOutput :rtype: LayerOutput
""" """
assert isinstance(input, LayerOutput) assert isinstance(input, LayerOutput)
assert reshape_axis != None and (reshape_axis > 0 and reshape_axis < 4)
height = [ele for ele in xrange(reshape_axis)]
width = [ele for ele in range(reshape_axis, 4)]
reshape = {'height': height, 'width': width}
l = Layer( l = Layer(
name=name, name=name,
inputs=input.name, inputs=input.name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册