diff --git a/paddle/function/CMakeLists.txt b/paddle/function/CMakeLists.txt
index f43f15e5cacb70b625d7791e1e02ce7780286200..4fd72d64a90ae6f16dd1499ceb7fba6e40fe4cea 100644
--- a/paddle/function/CMakeLists.txt
+++ b/paddle/function/CMakeLists.txt
@@ -44,6 +44,7 @@ if(WITH_GPU)
     add_simple_unittest(RowConvOpTest)
     add_simple_unittest(BlockExpandOpTest)
     add_simple_unittest(CropOpTest)
+    add_simple_unittest(SwitchOpTest)
 endif()
 
 add_simple_unittest(Im2ColTest)
diff --git a/paddle/function/SwitchOp.cpp b/paddle/function/SwitchOp.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..01e252a8dc0cd5fa1e964efa01d04cf282b3dfe7
--- /dev/null
+++ b/paddle/function/SwitchOp.cpp
@@ -0,0 +1,140 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "SwitchOp.h"
+#include "paddle/math/Vector.h"
+
+namespace paddle {
+
+template <>
+void NCHW2NHWC<DEVICE_TYPE_CPU>(real* outputs,
+                                const real* inputs,
+                                const int num,
+                                const int inC,
+                                const int inH,
+                                const int inW,
+                                const int argType) {
+  for (int n = 0; n < num; ++n) {
+    for (int c = 0; c < inC; ++c) {
+      for (int h = 0; h < inH; ++h) {
+        for (int w = 0; w < inW; ++w) {
+          if (argType == ADD_TO) {
+            outputs[((n * inH + h) * inW + w) * inC + c] += *(inputs++);
+          } else {
+            outputs[((n * inH + h) * inW + w) * inC + c] = *(inputs++);
+          }
+        }
+      }
+    }
+  }
+}
+
+template <>
+void NHWC2NCHW<DEVICE_TYPE_CPU>(real* outputs,
+                                const real* inputs,
+                                const int num,
+                                const int inH,
+                                const int inW,
+                                const int inC,
+                                const int argType) {
+  for (int n = 0; n < num; ++n) {
+    for (int h = 0; h < inH; ++h) {
+      for (int w = 0; w < inW; ++w) {
+        for (int c = 0; c < inC; ++c) {
+          if (argType == ADD_TO) {
+            outputs[((n * inC + c) * inH + h) * inW + w] += *(inputs++);
+          } else {
+            outputs[((n * inC + c) * inH + h) * inW + w] = *(inputs++);
+          }
+        }
+      }
+    }
+  }
+}
+
+/**
+ * \brief Switch the dimension order of an image input.
+ *        The input and output are 4D tensors. The order
+ *        'batch_size, channels, height, width' is switched to
+ *        'batch_size, height, width, channels'.
+ *
+ * Arguments in this function:
+ * \param inputs  input data with order 'batch_size, channels, height, width'.
+ * \param outputs output data with order 'batch_size, height, width, channels'.
+ */
+template <DeviceType Device>
+class NCHW2NHWCFunc : public FunctionBase {
+public:
+  void init(const FuncConfig& config) override {}
+
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(1UL, inputs.size());
+    CHECK_EQ(1UL, outputs.size());
+
+    size_t num = inputs[0].shape()[0];
+    size_t inC = inputs[0].shape()[1];
+    size_t inH = inputs[0].shape()[2];
+    size_t inW = inputs[0].shape()[3];
+    NCHW2NHWC<Device>(outputs[0].data<real>(),
+                      inputs[0].data<real>(),
+                      num,
+                      inC,
+                      inH,
+                      inW,
+                      outputs[0].getArgType());
+  }
+};
+
+/**
+ * \brief Switch the dimension order of an image input.
+ *        The input and output are 4D tensors. The order
+ *        'batch_size, height, width, channels' is switched to
+ *        'batch_size, channels, height, width'.
+ *
+ * Arguments in this function:
+ * \param inputs  input data with order 'batch_size, height, width, channels'.
+ * \param outputs output data with order 'batch_size, channels, height, width'.
+ */
+template <DeviceType Device>
+class NHWC2NCHWFunc : public FunctionBase {
+public:
+  void init(const FuncConfig& config) override {}
+
+  void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
+    CHECK_EQ(1UL, inputs.size());
+    CHECK_EQ(1UL, outputs.size());
+
+    size_t num = inputs[0].shape()[0];
+    size_t inH = inputs[0].shape()[1];
+    size_t inW = inputs[0].shape()[2];
+    size_t inC = inputs[0].shape()[3];
+
+    NHWC2NCHW<Device>(outputs[0].data<real>(),
+                      inputs[0].data<real>(),
+                      num,
+                      inH,
+                      inW,
+                      inC,
+                      outputs[0].getArgType());
+  }
+};
+
+REGISTER_TYPED_FUNC(NCHW2NHWC, CPU, NCHW2NHWCFunc);
+REGISTER_TYPED_FUNC(NHWC2NCHW, CPU, NHWC2NCHWFunc);
+#ifndef PADDLE_ONLY_CPU
+REGISTER_TYPED_FUNC(NCHW2NHWC, GPU, NCHW2NHWCFunc);
+REGISTER_TYPED_FUNC(NHWC2NCHW, GPU, NHWC2NCHWFunc);
+#endif
+
+}  // namespace paddle
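Note on the index arithmetic above: the CPU loop consumes `inputs` in NCHW order and writes each element to offset `((n * inH + h) * inW + w) * inC + c`, which is exactly the flat index of `(n, h, w, c)` in an NHWC layout, i.e. a transpose with axes `(0, 2, 3, 1)`. A minimal NumPy sketch (illustrative only, not part of the patch) checks this equivalence:

    import numpy as np

    # Mirror of the CPU loop in SwitchOp.cpp: element (n, c, h, w) goes to
    # flat offset ((n*inH + h)*inW + w)*inC + c, the NHWC position.
    num, inC, inH, inW = 2, 3, 4, 5
    x = np.arange(num * inC * inH * inW, dtype=np.float32)
    x = x.reshape(num, inC, inH, inW)

    out = np.zeros(num * inH * inW * inC, dtype=np.float32)
    for n in range(num):
        for c in range(inC):
            for h in range(inH):
                for w in range(inW):
                    out[((n * inH + h) * inW + w) * inC + c] = x[n, c, h, w]

    assert np.array_equal(out.reshape(num, inH, inW, inC),
                          x.transpose(0, 2, 3, 1))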
diff --git a/paddle/function/SwitchOp.h b/paddle/function/SwitchOp.h
new file mode 100644
index 0000000000000000000000000000000000000000..e4c1c3ac922f88c3e5424b5943082810aabfacdb
--- /dev/null
+++ b/paddle/function/SwitchOp.h
@@ -0,0 +1,66 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "Function.h"
+
+namespace paddle {
+
+/**
+ * \brief This function switches the dimension order of the image input.
+ *        The input and output are 4D tensors. The order
+ *        'batch_size, channels, height, width' is switched to
+ *        'batch_size, height, width, channels'.
+ *
+ * \param[out] outputs save results.
+ * \param[in]  inputs  input data.
+ * \param[in]  num     batch size of input data.
+ * \param[in]  inC     channel number of input data.
+ * \param[in]  inH     height of input data.
+ * \param[in]  inW     width of input data.
+ * \param[in]  argType type of output argument.
+ */
+template <DeviceType Device>
+void NCHW2NHWC(real* outputs,
+               const real* inputs,
+               const int num,
+               const int inC,
+               const int inH,
+               const int inW,
+               const int argType);
+
+/**
+ * \brief This function switches the dimension order of the image input.
+ *        The input and output are 4D tensors. The order
+ *        'batch_size, height, width, channels' is switched to
+ *        'batch_size, channels, height, width'.
+ *
+ * \param[out] inGrad  gradients of previous layer.
+ * \param[in]  outGrad output gradients.
+ * \param[in]  num     batch size of input data.
+ * \param[in]  inH     height of input data.
+ * \param[in]  inW     width of input data.
+ * \param[in]  inC     channel number of input data.
+ * \param[in]  argType type of output argument.
+ */
+template <DeviceType Device>
+void NHWC2NCHW(real* inGrad,
+               const real* outGrad,
+               const int num,
+               const int inH,
+               const int inW,
+               const int inC,
+               const int argType);
+}  // namespace paddle
diff --git a/paddle/function/SwitchOpGpu.cu b/paddle/function/SwitchOpGpu.cu
new file mode 100644
index 0000000000000000000000000000000000000000..45390a56c3f776ec18a65a6ba2f7149a7a6ef6c3
--- /dev/null
+++ b/paddle/function/SwitchOpGpu.cu
@@ -0,0 +1,98 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "SwitchOp.h"
+#include "hl_base.h"
+
+namespace paddle {
+
+__global__ void KeNCHW2NHWC(real* outputs,
+                            const real* inputs,
+                            int inC,
+                            int inH,
+                            int inW,
+                            int nthreads,
+                            int argType) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < nthreads) {
+    const int w = idx % inW;
+    const int h = (idx / inW) % inH;
+    const int c = (idx / inW / inH) % inC;
+    const int n = idx / inW / inH / inC;
+
+    const int off = ((n * inH + h) * inW + w) * inC + c;
+    if (argType == ADD_TO) {
+      outputs[off] += inputs[idx];
+    } else {
+      outputs[off] = inputs[idx];
+    }
+  }
+}
+
+template <>
+void NCHW2NHWC<DEVICE_TYPE_GPU>(real* outputs,
+                                const real* inputs,
+                                const int num,
+                                const int inC,
+                                const int inH,
+                                const int inW,
+                                const int argType) {
+  size_t nth = num * inC * inH * inW;
+  int blockSize = 1024;
+  int gridSize = (nth + 1024 - 1) / 1024;
+  KeNCHW2NHWC<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(
+      outputs, inputs, inC, inH, inW, nth, argType);
+  CHECK_SYNC("NCHW2NHWC");
+}
+
+__global__ void KeNHWC2NCHW(real* outputs,
+                            const real* inputs,
+                            int inH,
+                            int inW,
+                            int inC,
+                            int nthreads,
+                            int argType) {
+  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
+  if (idx < nthreads) {
+    const int c = idx % inC;
+    const int w = (idx / inC) % inW;
+    const int h = (idx / inC / inW) % inH;
+    const int n = idx / inW / inH / inC;
+
+    const int off = ((n * inC + c) * inH + h) * inW + w;
+    if (argType == ADD_TO) {
+      outputs[off] += inputs[idx];
+    } else {
+      outputs[off] = inputs[idx];
+    }
+  }
+}
+
+template <>
+void NHWC2NCHW<DEVICE_TYPE_GPU>(real* outputs,
+                                const real* inputs,
+                                const int num,
+                                const int inH,
+                                const int inW,
+                                const int inC,
+                                const int argType) {
+  int nth = num * inC * inH * inW;
+  int blockSize = 1024;
+  int gridSize = (nth + 1024 - 1) / 1024;
+  KeNHWC2NCHW<<<gridSize, blockSize, 0, STREAM_DEFAULT>>>(
+      outputs, inputs, inH, inW, inC, nth, argType);
+  CHECK_SYNC("NHWC2NCHW");
+}
+
+}  // namespace paddle
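The GPU path launches one thread per element; `gridSize = (nth + 1024 - 1) / 1024` is a ceiling division, so the last partial block is still launched and the `idx < nthreads` guard discards the excess threads. A small sketch (illustrative only, outside the patch) of the launch sizing and the flat-index decomposition used by `KeNCHW2NHWC`:

    # One thread per element; ceil-division grid sizing as in SwitchOpGpu.cu.
    num, inC, inH, inW = 2, 3, 4, 5
    nth = num * inC * inH * inW
    blockSize = 1024
    gridSize = (nth + blockSize - 1) // blockSize  # ceiling of nth / blockSize

    for idx in range(nth):  # stands in for the CUDA threads
        w = idx % inW
        h = (idx // inW) % inH
        c = (idx // inW // inH) % inC
        n = idx // inW // inH // inC
        # (n, c, h, w) are the NCHW coordinates of flat index idx
        assert idx == ((n * inC + c) * inH + h) * inW + w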
diff --git a/paddle/function/SwitchOpTest.cpp b/paddle/function/SwitchOpTest.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..03b0dd66ddcbab713969ed747601ecb1b2eb7955
--- /dev/null
+++ b/paddle/function/SwitchOpTest.cpp
@@ -0,0 +1,44 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <gtest/gtest.h>
+#include "FunctionTest.h"
+
+namespace paddle {
+
+TEST(SwitchOp, real) {
+  for (size_t numSamples : {1, 4, 8, 16}) {
+    for (size_t channels : {1, 4, 8, 16}) {
+      for (size_t imgSizeH : {1, 4, 8, 16}) {
+        for (size_t imgSizeW : {1, 4, 8, 16}) {
+          VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
+                  << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
+          for (bool test_grad : {true, false}) {
+            CpuGpuFuncCompare compare(test_grad ? "NHWC2NCHW" : "NCHW2NHWC",
+                                      FuncConfig());
+            TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
+            TensorShape outDims{numSamples, imgSizeH, imgSizeW, channels};
+            compare.addInputs(
+                BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims));
+            compare.addOutputs(BufferArg(
+                VALUE_TYPE_FLOAT, test_grad ? inDims : outDims, ASSIGN_TO));
+            compare.run();
+          }
+        }
+      }
+    }
+  }
+}
+
+}  // namespace paddle
diff --git a/paddle/gserver/layers/SwitchOrderLayer.cpp b/paddle/gserver/layers/SwitchOrderLayer.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6a91042f628920a9986763531fb4c633307b43b8
--- /dev/null
+++ b/paddle/gserver/layers/SwitchOrderLayer.cpp
@@ -0,0 +1,107 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "SwitchOrderLayer.h"
+#include "paddle/utils/Stat.h"
+
+namespace paddle {
+
+REGISTER_LAYER(switch_order, SwitchOrderLayer);
+
+bool SwitchOrderLayer::init(const LayerMap& layerMap,
+                            const ParameterMap& parameterMap) {
+  /* Initialize the basic parent class */
+  Layer::init(layerMap, parameterMap);
+  auto& img_conf = config_.inputs(0).image_conf();
+  size_t inH =
+      img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
+  size_t inW = img_conf.img_size();
+  size_t inC = img_conf.channels();
+  inDims_ = TensorShape({0, inC, inH, inW});
+  outDims_ = TensorShape(4);
+
+  auto& reshape_conf = config_.reshape_conf();
+  for (size_t i = 0; i < reshape_conf.heightaxis_size(); i++) {
+    heightAxis_.push_back(reshape_conf.heightaxis(i));
+  }
+  for (size_t i = 0; i < reshape_conf.widthaxis_size(); i++) {
+    widthAxis_.push_back(reshape_conf.widthaxis(i));
+  }
+  createFunction(nchw2nhwc_, "NCHW2NHWC", FuncConfig());
+  createFunction(nhwc2nchw_, "NHWC2NCHW", FuncConfig());
+  return true;
+}
+
+void SwitchOrderLayer::setOutDims() {
+  outDims_.setDim(0, inDims_[0]);
+  outDims_.setDim(1, inDims_[2]);
+  outDims_.setDim(2, inDims_[3]);
+  outDims_.setDim(3, inDims_[1]);
+  reshapeHeight_ = 1;
+  for (size_t i = 0; i < heightAxis_.size(); i++) {
+    reshapeHeight_ *= outDims_[heightAxis_[i]];
+  }
+  output_.setFrameHeight(reshapeHeight_);
+  reshapeWidth_ = 1;
+  for (size_t i = 0; i < widthAxis_.size(); i++) {
+    reshapeWidth_ *= outDims_[widthAxis_[i]];
+  }
+  output_.setFrameWidth(reshapeWidth_);
+}
+
+void SwitchOrderLayer::setInDims() {
+  MatrixPtr input = inputLayers_[0]->getOutputValue();
+  size_t batchSize = input->getHeight();
+  inDims_.setDim(0, batchSize);
+
+  int h = inputLayers_[0]->getOutput().getFrameHeight();
+  if (h != 0) inDims_.setDim(2, h);
+  int w = inputLayers_[0]->getOutput().getFrameWidth();
+  if (w != 0) inDims_.setDim(3, w);
+  int totalCount = input->getElementCnt();
+  int channels = totalCount / (inDims_[0] * inDims_[2] * inDims_[3]);
+  if (channels != 0) inDims_.setDim(1, channels);
+}
+
+void SwitchOrderLayer::forward(PassType passType) {
+  Layer::forward(passType);
+  setInDims();
+  setOutDims();
+  resetOutput(outDims_[0], outDims_[1] * outDims_[2] * outDims_[3]);
+  if (heightAxis_.size() > 0) {
+    getOutputValue()->reshape(reshapeHeight_, reshapeWidth_);
+    getOutputGrad()->reshape(reshapeHeight_, reshapeWidth_);
+  }
+
+  // switch NCHW to NHWC
+  BufferArgs inputs;
+  BufferArgs outputs;
+  inputs.addArg(*getInputValue(0), inDims_);
+  outputs.addArg(*getOutputValue(), outDims_);
+  nchw2nhwc_[0]->calc(inputs, outputs);
+  forwardActivation();
+}
+
+void SwitchOrderLayer::backward(const UpdateCallback& callback) {
+  (void)callback;
+  backwardActivation();
+
+  // switch NHWC to NCHW
+  BufferArgs inputs;
+  BufferArgs outputs;
+  inputs.addArg(*getOutputGrad(), outDims_);
+  outputs.addArg(*getInputGrad(0), inDims_, ADD_TO);
+  nhwc2nchw_[0]->calc(inputs, outputs);
+}
+}  // namespace paddle
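How `setOutDims()` exposes the switched tensor downstream: the output dims are the NHWC permutation of `inDims_`, and the axes listed in `reshape_conf` are folded into the output matrix's frame height and width. A quick sketch (illustrative values, not part of the patch) for the axis configuration used in the test below:

    # NHWC output dims for a batch of four 4-channel 16x16 images, with
    # heightAxis = [0, 1, 2] and widthAxis = [3] as in test_LayerGrad.cpp.
    out_dims = [4, 16, 16, 4]  # batch, height, width, channels
    height_axis = [0, 1, 2]
    width_axis = [3]

    reshape_height = 1
    for axis in height_axis:
        reshape_height *= out_dims[axis]  # 4 * 16 * 16 = 1024

    reshape_width = 1
    for axis in width_axis:
        reshape_width *= out_dims[axis]  # 4

    print(reshape_height, reshape_width)  # 1024 4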
diff --git a/paddle/gserver/layers/SwitchOrderLayer.h b/paddle/gserver/layers/SwitchOrderLayer.h
new file mode 100644
index 0000000000000000000000000000000000000000..47b1f7f73ee783b3eae3c9cfe08b1459cef16a71
--- /dev/null
+++ b/paddle/gserver/layers/SwitchOrderLayer.h
@@ -0,0 +1,47 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "Layer.h"
+
+namespace paddle {
+
+/**
+ * \brief This layer switches the dimension order of the image input
+ *        from NCHW to NHWC.
+ */
+class SwitchOrderLayer : public Layer {
+public:
+  explicit SwitchOrderLayer(const LayerConfig& config) : Layer(config) {}
+
+  ~SwitchOrderLayer() {}
+
+  bool init(const LayerMap& layerMap,
+            const ParameterMap& parameterMap) override;
+  void forward(PassType passType) override;
+  void backward(const UpdateCallback& callback = nullptr) override;
+  void setInDims();
+  void setOutDims();
+
+protected:
+  std::vector<std::shared_ptr<FunctionBase>> nchw2nhwc_;
+  std::vector<std::shared_ptr<FunctionBase>> nhwc2nchw_;
+  TensorShape inDims_;
+  TensorShape outDims_;
+  std::vector<int> heightAxis_;
+  std::vector<int> widthAxis_;
+  size_t reshapeHeight_;
+  size_t reshapeWidth_;
+};
+}  // namespace paddle
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index a831ffbc73fbd6ad42fa31b2d6d583718474e59b..e0c14ad5b512c7329062a5426ef34844ec268020 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -2008,6 +2008,31 @@ TEST(Layer, CropLayer) {
   }
 }
 
+TEST(Layer, SwitchOrderLayer) {
+  TestConfig config;
+  // config input_0
+  config.inputDefs.push_back({INPUT_DATA, "layer_0", 1024, 0});
+  LayerInputConfig* input = config.layerConfig.add_inputs();
+  ImageConfig* img = input->mutable_image_conf();
+  img->set_channels(4);
+  img->set_img_size(16);
+  img->set_img_size_y(16);
+
+  ReshapeConfig* reshape = config.layerConfig.mutable_reshape_conf();
+  reshape->add_heightaxis(0);
+  reshape->add_heightaxis(1);
+  reshape->add_heightaxis(2);
+  reshape->add_widthaxis(3);
+
+  // config the switch_order layer
+  config.layerConfig.set_type("switch_order");
+  config.layerConfig.set_name("switchOrderLayer");
+
+  for (auto useGpu : {false, true}) {
+    testLayerGrad(config, "switch_order", 100, false, useGpu, true);
+  }
+}
+
 vector<real> randSampling(real range, int n) {
   CHECK_GE(range, n);
   vector<real> num(range);
diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h
index 431d4e071072317c8fdfdc4f0d13e7cd4e3d062b..44180bca8bca53e74d71ce7bed3516399c01c81d 100644
--- a/paddle/math/Matrix.h
+++ b/paddle/math/Matrix.h
@@ -1616,6 +1616,10 @@ public:
 };
 
 class CpuMatrix : public Matrix {
+private:
+  MatrixPtr sftmaxSum_;
+  MatrixPtr sftmaxDot_;
+
 public:
   CpuMatrix(size_t height, size_t width, bool trans = false);
   CpuMatrix(real* data, size_t height, size_t width, bool trans = false)
diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto
index 4ddf023780c704cb10c51ee9e5d7cb63420f9d73..0f44d8cb8d78ed23cc1105ac7aff37de5faeffa1 100644
--- a/proto/ModelConfig.proto
+++ b/proto/ModelConfig.proto
@@ -287,6 +287,11 @@ message PadConfig {
   repeated uint32 pad_w = 4;
 }
 
+message ReshapeConfig {
+  repeated uint32 heightAxis = 1;
+  repeated uint32 widthAxis = 2;
+}
+
 message MultiBoxLossConfig {
   required uint32 num_classes = 1;
   required float overlap_threshold = 2;
@@ -339,7 +344,6 @@ message LayerInputConfig {
 }
 
 message LayerConfig {
-  required string name = 1;
   required string type = 2;
   optional uint64 size = 3;
@@ -516,6 +520,9 @@
   optional double delta = 57 [ default = 1.0 ];
 
   optional uint64 depth = 58 [ default = 1 ];
+
+  // for the switch_order layer
+  optional ReshapeConfig reshape_conf = 59;
 }
 
 message EvaluatorConfig {
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index 152a56190c1ffddbf9590ed8f71308ceb88403f4..11dc84ae20679bb73735f9119739fca5ea7fa673 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -3670,6 +3670,15 @@ class RecurrentLayerGroup(LayerBase):
             name, 'recurrent_layer_group', 0, inputs=[], device=device)
 
 
+@config_layer('switch_order')
+class SwitchOrderLayer(LayerBase):
+    def __init__(self, name, inputs, reshape, **xargs):
+        super(SwitchOrderLayer, self).__init__(
+            name, 'switch_order', 0, inputs=inputs, **xargs)
+        self.config.reshape_conf.heightAxis.extend(reshape['height'])
+        self.config.reshape_conf.widthAxis.extend(reshape['width'])
+
+
 # Deprecated, use a new layer specific class instead
 @config_func
 def Layer(name, type, **xargs):
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 47ac601e678013aceb62005d6f25595f49673d2c..cba45bd3afa178ab4dd3a50f0947b144e7466e53 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -131,6 +131,7 @@ __all__ = [
     'row_conv_layer',
     'dropout_layer',
     'prelu_layer',
+    'switch_order_layer',
    'gated_unit_layer',
     'crop_layer',
     'sub_nested_seq_layer',
@@ -239,6 +240,7 @@ class LayerType(object):
     SMOOTH_L1 = 'smooth_l1'
     PRELU = 'prelu'
+    SWITCH_ORDER_LAYER = 'switch_order'
     CROP_LAYER = 'crop'
     SUB_NESTED_SEQ = 'sub_nested_seq'
     CLIP_LAYER = 'clip'
@@ -6404,6 +6406,50 @@ def gated_unit_layer(input,
         layer_attr=layer_attr)
 
 
+@layer_support()
+@wrap_act_default(act=LinearActivation())
+@wrap_name_default('switch_order')
+def switch_order_layer(input,
+                       name=None,
+                       reshape=None,
+                       act=None,
+                       layer_attr=None):
+    """
+    This layer switches the dimension order of the image input from
+    'batchSize, channels, height, width' to
+    'batchSize, height, width, channels'.
+
+    The example usage is:
+
+    .. code-block:: python
+
+       reshape = {'height': [0, 1, 2], 'width': [3]}
+       switch = switch_order_layer(input=layer, name='switch', reshape=reshape)
+
+    :param input: The input layer.
+    :type input: LayerOutput
+    :param name: Name of this layer.
+    :type name: basestring
+    :param reshape: dict mapping 'height' and 'width' to the merged output axes.
+    :type reshape: dict
+    :return: LayerOutput object.
+    :rtype: LayerOutput
+    """
+    assert isinstance(input, LayerOutput)
+    l = Layer(
+        name=name,
+        inputs=input.name,
+        reshape=reshape,
+        type=LayerType.SWITCH_ORDER_LAYER,
+        active_type=act.name,
+        **ExtraLayerAttribute.to_kwargs(layer_attr))
+    return LayerOutput(
+        name=name,
+        layer_type=LayerType.SWITCH_ORDER_LAYER,
+        parents=input,
+        size=l.config.size)
+
+
 @wrap_name_default()
 @layer_support()
 def crop_layer(input, offset, axis=2, shape=None, name=None, layer_attr=None):
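A minimal end-to-end configuration sketch for the new Python wrapper (the layer names and sizes here are made up for illustration, and the exact data-layer setup may differ in practice):

    from paddle.trainer_config_helpers import *

    # A 4-channel 16x16 image in NCHW order, switched to NHWC; output axes
    # {0, 1, 2} are folded into the frame height, matching the
    # SwitchOrderLayer test above.
    data = data_layer(name='image', size=4 * 16 * 16, height=16, width=16)
    switch = switch_order_layer(
        input=data,
        name='switch',
        reshape={'height': [0, 1, 2], 'width': [3]})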