From 5a9ea9a73d51841790940ffb36790d8424adacba Mon Sep 17 00:00:00 2001 From: shippingwang Date: Tue, 25 Dec 2018 02:25:45 +0000 Subject: [PATCH] Add ShuffleChannel Op --- paddle/fluid/operators/shuffle_channel_op.cc | 6 +++++- paddle/fluid/operators/shuffle_channel_op.cu | 2 ++ python/paddle/fluid/layers/nn.py | 10 +++++----- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/shuffle_channel_op.cc b/paddle/fluid/operators/shuffle_channel_op.cc index 0ede3922ea..1ab8b42d8d 100644 --- a/paddle/fluid/operators/shuffle_channel_op.cc +++ b/paddle/fluid/operators/shuffle_channel_op.cc @@ -28,7 +28,7 @@ class ShuffleChannelOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW."); // ENFORCE group - // auto group = ctx->Attrs().Get("group"); + ctx->SetOutputDim("Out", input_dims); } /* @@ -87,6 +87,10 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel { "Output(X@Grad) should not be null"); auto input_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW."); + + // ENFORCE group + ctx->SetOutputDim(framework::GradVarName("X"), input_dims); } /* diff --git a/paddle/fluid/operators/shuffle_channel_op.cu b/paddle/fluid/operators/shuffle_channel_op.cu index 77418ac7e3..e8badc40cd 100644 --- a/paddle/fluid/operators/shuffle_channel_op.cu +++ b/paddle/fluid/operators/shuffle_channel_op.cu @@ -81,6 +81,7 @@ class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("X"); int group = ctx.Attr("group"); + auto input_dims = input->dims(); auto num = input_dims[0]; auto channel = input_dims[1]; @@ -101,6 +102,7 @@ class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel { int blocks = NumBlocks(output_grad->numel()); int threads = kNumCUDAThreads; int count = num * group_column * group_row * sp_sz; + ShuffleChannel< T><<>>( count, feature_map_size, input_grad_data, output_grad_data, group_row, diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 3e3eea084e..e654047df6 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -55,6 +55,8 @@ __all__ = [ 'softmax', 'pool2d', 'pool3d', + 'adaptive_pool2d', + 'adaptive_pool3d', 'batch_norm', 'beam_search_decode', 'conv2d_transpose', @@ -9342,24 +9344,22 @@ def shuffle_channel(x, group=1, name=None): x: The input tensor variable.. group: The num of group - Returns: - Variable: channel shuffled tensor variable. + Variable: channels shuffled tensor variable. Raises: - ValueError: If group in not an int type variable. + ValueError: If group is not an int type variable. Examples: .. code-block:: python out = fluid.layers.shuffle_channel(x=group_conv,group=4) - """ helper = LayerHelper("shuffle_channel", **locals()) out = helper.create_variable_for_type_inference( - dtype=helper.intput_dtype('x')) + dtype=helper.input_dtype('x')) if not isinstance(group, int): raise TypeError("group must be int type") -- GitLab