diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 02d68b5ee0483e397e014776be8cc80b35edc198..afd3342768701adba4ff0040bd1c762b1cd8739d 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -213,6 +213,7 @@ paddle.fluid.layers.bilinear_tensor_product ArgSpec(args=['x', 'y', 'size', 'act paddle.fluid.layers.merge_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.get_tensor_from_selected_rows ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.lstm ArgSpec(args=['input', 'init_h', 'init_c', 'max_len', 'hidden_size', 'num_layers', 'dropout_prob', 'is_bidirec', 'is_test', 'name', 'default_initializer', 'seed'], varargs=None, keywords=None, defaults=(0.0, False, False, None, None, -1)) +paddle.fluid.layers.shuffle_channel ArgSpec(args=['x', 'group', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.py_func ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None)) paddle.fluid.layers.psroi_pool ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.teacher_student_sigmoid_loss ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0)) diff --git a/paddle/fluid/operators/shuffle_channel_op.cc b/paddle/fluid/operators/shuffle_channel_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..9349912e090f2ad3248923c87b50c8d72b0d84d1 --- /dev/null +++ b/paddle/fluid/operators/shuffle_channel_op.cc @@ -0,0 +1,113 @@ +/*Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/shuffle_channel_op.h" + +namespace paddle { +namespace operators { + +class ShuffleChannelOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of ShuffleChannelOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of ShuffleChannelOp should not be null."); + + auto input_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW."); + + ctx->SetOutputDim("Out", input_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); + } +}; + +class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor, default Tensor), " + "the input feature data of ShuffleChannelOp, the layout is NCHW."); + AddOutput("Out", + "(Tensor, default Tensor), the output of " + "ShuffleChannelOp. The layout is NCHW."); + AddAttr("group", "the number of groups.") + .SetDefault(1) + .AddCustomChecker([](const int& group) { + PADDLE_ENFORCE_GE(group, 1, "group should be larger than 0."); + }); + + AddComment(R"DOC( + Shuffle Channel operator + This opearator shuffles the channels of input x. 
+ It divide the input channels in each group into several subgroups, + and obtain a new order by selecting element from every subgroup one by one. + + Shuffle channel operation makes it possible to build more powerful structures + with multiple group convolutional layers. + please get more information from the following paper: + https://arxiv.org/pdf/1707.01083.pdf + )DOC"); + } +}; + +class ShuffleChannelGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@Grad) should not be null"); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Output(X@Grad) should not be null"); + + auto input_dims = ctx->GetInputDim("X"); + PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW."); + + ctx->SetOutputDim(framework::GradVarName("X"), input_dims); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(ctx.Input("X")->type(), + ctx.device_context()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(shuffle_channel, ops::ShuffleChannelOp, + ops::ShuffleChannelOpMaker, + paddle::framework::DefaultGradOpDescMaker); + +REGISTER_OPERATOR(shuffle_channel_grad, ops::ShuffleChannelGradOp); + +REGISTER_OP_CPU_KERNEL( + shuffle_channel, + ops::ShuffleChannelOpKernel, + ops::ShuffleChannelOpKernel); + +REGISTER_OP_CPU_KERNEL( + shuffle_channel_grad, + ops::ShuffleChannelGradOpKernel, + ops::ShuffleChannelGradOpKernel); diff --git a/paddle/fluid/operators/shuffle_channel_op.cu b/paddle/fluid/operators/shuffle_channel_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..9506343b3d508459c6e10dc68eba13504b07338f --- /dev/null +++ 
using Tensor = framework::Tensor;
static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaximumNumBlocks = 4096;

// Number of blocks of kNumCUDAThreads threads needed to cover N elements,
// capped at kNumMaximumNumBlocks.
static inline int NumBlocks(const int N) {
  return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
                  kNumMaximumNumBlocks);
}

// Copies `input` to `output` while permuting channels: the channel at
// position (i * group_column + j) of every sample is written to position
// (j * group_row + i). `len` is the number of elements per channel and
// `feature_map_size` the number of elements per sample.
//
// Each thread starts at its global index and advances by the total number of
// launched threads until `nthreads` elements have been handled, so a launch
// whose grid was capped by NumBlocks still covers every element.
template <typename T>
__global__ void ShuffleChannel(const int nthreads, const int feature_map_size,
                               T* output, const T* input, int group_row,
                               int group_column, int len) {
  const int first = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = blockDim.x * gridDim.x;
  for (size_t ii = first; ii < nthreads; ii += stride) {
    // All index arithmetic must use the element handled in this iteration,
    // not the thread's starting index; otherwise every pass after the first
    // rewrites the same element and the remainder is never produced.
    const int idx = static_cast<int>(ii);
    const int n = idx / group_row / group_column / len;
    const int i = (idx / group_column / len) % group_row;
    const int j = idx / len % group_column;
    // Offset of the element inside its channel.
    const int k = idx - (n * feature_map_size + (i * group_column + j) * len);
    T* p_o = output + n * feature_map_size + (j * group_row + i) * len;
    p_o[k] = input[idx];
  }
}

// NOTE(review): the template parameter lists, explicit template arguments and
// the kernel launch configuration were absent from the reviewed text. They
// are restored below from the `blocks` / `threads` locals; the stream is
// assumed to be the op's CUDA device context stream -- confirm.

// Forward CUDA kernel wrapper: Out is X with its channels shuffled into
// `group` interleaved groups.
template <typename DeviceContext, typename T>
class ShuffleChannelOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<Tensor>("X");
    auto* output = ctx.Output<Tensor>("Out");
    const int group = ctx.Attr<int>("group");

    const auto input_dims = input->dims();
    const auto num = input_dims[0];
    const auto channel = input_dims[1];
    const auto height = input_dims[2];
    const auto width = input_dims[3];

    const auto feature_map_size = channel * height * width;
    const auto sp_sz = height * width;
    const int group_row = group;
    const int group_column = channel / group_row;
    // count is the product of NCHW same as numel()
    const int count = num * group_column * group_row * sp_sz;

    const int blocks = NumBlocks(output->numel());
    const int threads = kNumCUDAThreads;

    const T* input_data = input->data<T>();
    T* output_data = output->mutable_data<T>(ctx.GetPlace());

    ShuffleChannel<
        T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
        count, feature_map_size, output_data, input_data, group_row,
        group_column, sp_sz);
  }
};

// Backward CUDA kernel wrapper: X@GRAD is Out@GRAD with the forward channel
// permutation undone.
template <typename DeviceContext, typename T>
class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<Tensor>("X");
    const int group = ctx.Attr<int>("group");

    const auto input_dims = input->dims();
    const auto num = input_dims[0];
    const auto channel = input_dims[1];
    const auto height = input_dims[2];
    const auto width = input_dims[3];
    const auto feature_map_size = channel * height * width;
    const auto sp_sz = height * width;

    const int group_row = group;
    const int group_column = channel / group_row;

    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
    T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
    const T* output_grad_data = output_grad->data<T>();

    const int blocks = NumBlocks(output_grad->numel());
    const int threads = kNumCUDAThreads;
    const int count = num * group_column * group_row * sp_sz;

    // The forward pass moves channel (i * group_column + j) to
    // (j * group_row + i). The gradient has to travel the opposite way, which
    // is the same kernel with the two group extents exchanged. Passing them
    // in forward order is only correct when group_row == group_column.
    ShuffleChannel<
        T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
        count, feature_map_size, input_grad_data, output_grad_data,
        group_column, group_row, sp_sz);
  }
};
ops::ShuffleChannelOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL( + shuffle_channel_grad, + ops::ShuffleChannelGradOpCUDAKernel, + ops::ShuffleChannelGradOpCUDAKernel); diff --git a/paddle/fluid/operators/shuffle_channel_op.h b/paddle/fluid/operators/shuffle_channel_op.h new file mode 100644 index 0000000000000000000000000000000000000000..f6af1bc88598870ebccef81bd37f93f376940851 --- /dev/null +++ b/paddle/fluid/operators/shuffle_channel_op.h @@ -0,0 +1,95 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +class ShuffleChannelOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + auto* output = ctx.Output("Out"); + int group = ctx.Attr("group"); + + auto input_dims = input->dims(); + auto num = input_dims[0]; + auto channel = input_dims[1]; + auto height = input_dims[2]; + auto weight = input_dims[3]; + + auto feature_map_size = channel * height * weight; + auto sp_sz = height * weight; + int group_row = group; + int group_column = channel / group_row; + + const T* input_data = input->data(); + T* output_data = output->mutable_data(ctx.GetPlace()); + for (int n = 0; n < num; ++n) { + for (int i = 0; i < group_row; ++i) { + for (int j = 0; j < group_column; ++j) { + const T* p_i = input_data + n * feature_map_size + + (i * group_column + j) * sp_sz; + T* p_o = + output_data + n * feature_map_size + (j * group_row + i) * sp_sz; + memcpy(p_o, p_i, sizeof(int) * sp_sz); + } + } + } + } +}; + +template +class ShuffleChannelGradOpKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* input = ctx.Input("X"); + int group = ctx.Attr("group"); + + auto input_dims = input->dims(); + auto num = input_dims[0]; + auto channel = input_dims[1]; + auto height = input_dims[2]; + auto weight = input_dims[3]; + auto feature_map_size = channel * height * weight; + auto sp_sz = height * weight; + + int group_row = group; + int group_column = channel / group_row; + + auto* output_grad = + ctx.Input(framework::GradVarName("Out")); + auto* input_grad = + ctx.Output(framework::GradVarName("X")); + T* input_grad_data = input_grad->mutable_data(ctx.GetPlace()); + const T* output_grad_data = output_grad->data(); + for (int n = 
0; n < num; ++n) { + for (int i = 0; i < group_row; ++i) { + for (int j = 0; j < group_column; ++j) { + const T* p_i = output_grad_data + n * feature_map_size + + (i * group_column + j) * sp_sz; + T* p_o = input_grad_data + n * feature_map_size + + (j * group_row + i) * sp_sz; + memcpy(p_o, p_i, sizeof(int) * sp_sz); + } + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 339290384398df6d85d2f914f311af2cd0d33aea..beb5e31211c5f9aa6bddfcb1da7e63d6480e99e1 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -179,6 +179,7 @@ __all__ = [ 'merge_selected_rows', 'get_tensor_from_selected_rows', 'lstm', + 'shuffle_channel', 'py_func', 'psroi_pool', 'teacher_student_sigmoid_loss', @@ -9646,6 +9647,79 @@ def get_tensor_from_selected_rows(x, name=None): return out +def shuffle_channel(x, group, name=None): + """ + **Shuffle Channel Operator** + + This operator shuffles the channels of input x. + It divide the input channels in each group into :attr:`group` subgroups, + and obtain a new order by selecting element from every subgroup one by one. + + Please refer to the paper + https://arxiv.org/pdf/1707.01083.pdf + + .. code-block:: text + + Given a 4-D tensor input with the shape (N, C, H, W): + input.shape = (1, 4, 2, 2) + input.data =[[[[0.1, 0.2], + [0.2, 0.3]], + + [[0.3, 0.4], + [0.4, 0.5]], + + [[0.5, 0.6], + [0.6, 0.7]], + + [[0.7, 0.8], + [0.8, 0.9]]]] + Given group: 2 + then we get a 4-D tensor out whth the same shape of input: + out.shape = (1, 4, 2, 2) + out.data = [[[[0.1, 0.2], + [0.2, 0.3]], + + [[0.5, 0.6], + [0.6, 0.7]], + + [[0.3, 0.4], + [0.4, 0.5]], + + [[0.7, 0.8], + [0.8, 0.9]]]] + + Args: + x(Variable): The input tensor variable. It should be a 4-D tensor with shape [N, C, H, W] + group(int): Indicating the conuts of subgroups, It should divide the number of channels. 
def test_shuffle_channel(self):
    """Build a program holding a single shuffle_channel layer and print it."""
    main_program = Program()
    with program_guard(main_program):
        # A 16-channel 4x4 input split into 4 groups.
        data = layers.data(name="X", shape=[16, 4, 4], dtype="float32")
        shuffled = layers.shuffle_channel(data, group=4)
        self.assertIsNotNone(shuffled)
    print(str(main_program))
class TestShuffleChannelOp(OpTest):
    """Compares the shuffle_channel op with a NumPy reshape/transpose reference."""

    def setUp(self):
        self.op_type = "shuffle_channel"
        self.batch_size, self.input_channels = 10, 16
        self.layer_h, self.layer_w = 4, 4
        self.group = 4

        shape = (self.batch_size, self.input_channels, self.layer_h,
                 self.layer_w)
        self.x = np.random.random(shape).astype('float32')
        self.inputs = {'X': self.x}
        self.attrs = {'group': self.group}

        # Reference result: view the channel axis as (group, c // group),
        # swap those two axes, then flatten back to (N, C, H, W).
        _, c, h, w = self.x.shape
        grouped = self.x.reshape((-1, self.group, c // self.group, h, w))
        swapped = grouped.transpose((0, 2, 1, 3, 4))
        self.outputs = {'Out': swapped.reshape((-1, c, h, w))}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')