From 2dd55b873fcad8fb7e06963d6ea08ba17e7ce1b7 Mon Sep 17 00:00:00 2001
From: shippingwang
Date: Mon, 17 Dec 2018 13:08:02 +0000
Subject: [PATCH] Add shuffle_channel_op

---
 paddle/fluid/operators/shuffle_channel_op.cc  | 126 +++++++++++
 paddle/fluid/operators/shuffle_channel_op.cu  |  24 ++
 paddle/fluid/operators/shuffle_channel_op.h   | 101 +++++++++
 python/paddle/fluid/layers/nn.py              | 213 ++++++------------
 .../fluid/tests/unittests/test_layers.py      |   9 +
 .../unittests/test_shuffle_channel_op.py      |  54 +++++
 6 files changed, 385 insertions(+), 142 deletions(-)
 create mode 100644 paddle/fluid/operators/shuffle_channel_op.cc
 create mode 100644 paddle/fluid/operators/shuffle_channel_op.cu
 create mode 100644 paddle/fluid/operators/shuffle_channel_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_shuffle_channel_op.py

diff --git a/paddle/fluid/operators/shuffle_channel_op.cc b/paddle/fluid/operators/shuffle_channel_op.cc
new file mode 100644
index 00000000000..ec1255af168
--- /dev/null
+++ b/paddle/fluid/operators/shuffle_channel_op.cc
@@ -0,0 +1,126 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/shuffle_channel_op.h"
+
+namespace paddle {
+namespace operators {
+
+class ShuffleChannelOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of ShuffleChannelOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of ShuffleChannelOp should not be null.");
+
+    auto input_dims = ctx->GetInputDim("X");
+    PADDLE_ENFORCE(input_dims.size() == 4,
+                   "Input(X) should be a 4-D tensor with NCHW layout.");
+
+    auto group = ctx->Attrs().Get<int>("group");
+    PADDLE_ENFORCE(input_dims[1] % group == 0,
+                   "The channel number of Input(X) should be divisible "
+                   "by group.");
+    ctx->SetOutputDim("Out", input_dims);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+        ctx.GetPlace());
+  }
+};
+
+class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(Tensor, default Tensor<float>), "
+             "the input feature data of ShuffleChannelOp, the layout is NCHW.");
+    AddOutput("Out",
+              "(Tensor, default Tensor<float>), the output of "
+              "ShuffleChannelOp. The layout is NCHW.");
+    AddAttr<int>("group", "the number of groups.")
+        .SetDefault(1)
+        .AddCustomChecker([](const int& group) {
+          PADDLE_ENFORCE_GE(group, 1, "group should be larger than 0.");
+        });
+
+    AddComment(R"DOC(
+    Shuffle Channel Operator.
+
+    This operator shuffles the channels of the input feature map. It first
+    divides the input channels in each group into several subgroups, and then
+    feeds each group in the next layer with different subgroups.
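+
+    For example, with group = 2 and six input channels indexed [0, 1, 2, 3, 4, 5],
+    the channel dimension is reshaped to [[0, 1, 2], [3, 4, 5]], transposed to
+    [[0, 3], [1, 4], [2, 5]], and flattened back, so the output channel order
+    is [0, 3, 1, 4, 2, 5].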
+
+    According to the paper, "Suppose a convolutional layer with g groups
+    whose output has g x n channels: first reshape the output channel
+    dimension into (g, n), transposing and then flattening it back as the
+    input of the next layer."
+
+    Shuffle channel operation makes it possible to build more powerful
+    structures with multiple group convolutional layers.
+
+    Please refer to the following paper for more details:
+    https://arxiv.org/pdf/1707.01083.pdf
+    )DOC");
+  }
+};
+
+class ShuffleChannelGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
+                   "Input(Out@Grad) should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
+                   "Output(X@Grad) should not be null.");
+
+    auto input_dims = ctx->GetInputDim("X");
+    ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        framework::ToDataType(
+            ctx.Input<framework::Tensor>(framework::GradVarName("Out"))
+                ->type()),
+        ctx.device_context());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(shuffle_channel, ops::ShuffleChannelOp,
+                  ops::ShuffleChannelOpMaker,
+                  paddle::framework::DefaultGradOpDescMaker<true>);
+
+REGISTER_OPERATOR(shuffle_channel_grad, ops::ShuffleChannelGradOp);
+
+REGISTER_OP_CPU_KERNEL(
+    shuffle_channel,
+    ops::ShuffleChannelOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::ShuffleChannelOpKernel<paddle::platform::CPUDeviceContext, double>);
+
+REGISTER_OP_CPU_KERNEL(
+    shuffle_channel_grad,
+    ops::ShuffleChannelGradOpKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::ShuffleChannelGradOpKernel<paddle::platform::CPUDeviceContext,
+                                    double>);
diff --git a/paddle/fluid/operators/shuffle_channel_op.cu b/paddle/fluid/operators/shuffle_channel_op.cu
new file mode 100644
index 00000000000..b1eacd0cbe4
--- /dev/null
+++ b/paddle/fluid/operators/shuffle_channel_op.cu
@@ -0,0 +1,24 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/shuffle_channel_op.h"
+
+// TODO(shippingwang): write a dedicated GPU kernel; the memcpy-based
+// kernel in shuffle_channel_op.h assumes host memory.
+namespace ops = paddle::operators;
+REGISTER_OP_CUDA_KERNEL(
+    shuffle_channel,
+    ops::ShuffleChannelOpKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::ShuffleChannelOpKernel<paddle::platform::CUDADeviceContext, double>);
+REGISTER_OP_CUDA_KERNEL(
+    shuffle_channel_grad,
+    ops::ShuffleChannelGradOpKernel<paddle::platform::CUDADeviceContext,
+                                    float>,
+    ops::ShuffleChannelGradOpKernel<paddle::platform::CUDADeviceContext,
+                                    double>);
diff --git a/paddle/fluid/operators/shuffle_channel_op.h b/paddle/fluid/operators/shuffle_channel_op.h
new file mode 100644
index 00000000000..f923babf5b8
--- /dev/null
+++ b/paddle/fluid/operators/shuffle_channel_op.h
@@ -0,0 +1,101 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <algorithm>
+#include <cstring>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class ShuffleChannelOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<framework::Tensor>("X");
+    auto* output = ctx.Output<framework::Tensor>("Out");
+    int group = ctx.Attr<int>("group");
+
+    auto input_dims = input->dims();
+    auto num = input_dims[0];
+    auto channel = input_dims[1];
+    auto height = input_dims[2];
+    auto width = input_dims[3];
+
+    auto feature_map_size = channel * height * width;
+    auto sp_sz = height * width;
+
+    int group_row = group;
+    int group_column = channel / group_row;
+
+    const T* input_data = input->data<T>();
+    T* output_data = output->mutable_data<T>(ctx.GetPlace());
+
+    for (int n = 0; n < num; ++n) {
+      T* output_data_temp = output_data + n * feature_map_size;
+      const T* input_data_temp = input_data + n * feature_map_size;
+      for (int i = 0; i < group_row; ++i) {
+        for (int j = 0; j < group_column; ++j) {
+          // Channel (i, j) of the (group_row, group_column) grid moves to
+          // position (j, i): the channel dimension is transposed.
+          const T* p_i = input_data_temp + (i * group_column + j) * sp_sz;
+          T* p_o = output_data_temp + (j * group_row + i) * sp_sz;
+          memcpy(p_o, p_i, sizeof(T) * sp_sz);
+        }
+      }
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<framework::Tensor>("X");
+    int group = ctx.Attr<int>("group");
+
+    auto input_dims = input->dims();
+    auto num = input_dims[0];
+    auto channel = input_dims[1];
+    auto height = input_dims[2];
+    auto width = input_dims[3];
+    auto feature_map_size = channel * height * width;
+    auto sp_sz = height * width;
+
+    int group_row = group;
+    int group_column = channel / group_row;
+
+    auto* output_grad =
+        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* input_grad =
+        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+
+    T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
+    const T* output_grad_data = output_grad->data<T>();
+
+    for (int n = 0; n < num; ++n) {
+      const T* output_grad_temp = output_grad_data + n * feature_map_size;
+      T* input_grad_temp = input_grad_data + n * feature_map_size;
+      for (int i = 0; i < group_row; ++i) {
+        for (int j = 0; j < group_column; ++j) {
+          // The gradient flows back through the inverse permutation.
+          const T* p_i = output_grad_temp + (i * group_column + j) * sp_sz;
+          T* p_o = input_grad_temp + (j * group_row + i) * sp_sz;
+          memcpy(p_o, p_i, sizeof(T) * sp_sz);
+        }
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index e25eaaa9fda..5e1b6c999bc 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -31,148 +31,37 @@ from functools import reduce
 from .. import core
 
 __all__ = [
-    'fc',
-    'embedding',
-    'dynamic_lstm',
-    'dynamic_lstmp',
-    'dynamic_gru',
-    'gru_unit',
-    'linear_chain_crf',
-    'crf_decoding',
-    'cos_sim',
-    'cross_entropy',
-    'bpr_loss',
-    'square_error_cost',
-    'chunk_eval',
-    'sequence_conv',
-    'conv2d',
-    'conv3d',
-    'sequence_pool',
-    'sequence_softmax',
-    'softmax',
-    'pool2d',
-    'pool3d',
-    'batch_norm',
-    'beam_search_decode',
-    'conv2d_transpose',
-    'conv3d_transpose',
-    'sequence_expand',
-    'sequence_expand_as',
-    'sequence_pad',
-    'sequence_unpad',
-    'lstm_unit',
-    'reduce_sum',
-    'reduce_mean',
-    'reduce_max',
-    'reduce_min',
-    'reduce_prod',
-    'sequence_first_step',
-    'sequence_last_step',
-    'sequence_slice',
-    'dropout',
-    'split',
-    'ctc_greedy_decoder',
-    'edit_distance',
-    'l2_normalize',
-    'matmul',
-    'topk',
-    'warpctc',
-    'sequence_reshape',
-    'transpose',
-    'im2sequence',
-    'nce',
-    'hsigmoid',
-    'beam_search',
-    'row_conv',
-    'multiplex',
-    'layer_norm',
-    'group_norm',
-    'softmax_with_cross_entropy',
-    'smooth_l1',
-    'one_hot',
-    'autoincreased_step_counter',
-    'reshape',
-    'squeeze',
-    'unsqueeze',
-    'lod_reset',
-    'lrn',
-    'pad',
-    'pad_constant_like',
-    'label_smooth',
-    'roi_pool',
-    'roi_align',
-    'dice_loss',
-    'image_resize',
-    'image_resize_short',
-    'resize_bilinear',
-    'resize_nearest',
-    'gather',
-    'scatter',
-    'sequence_scatter',
-    'random_crop',
-    'mean_iou',
-    'relu',
-    'selu',
-    'log',
-    'crop',
-    'rank_loss',
-    'margin_rank_loss',
-    'elu',
-    'relu6',
-    'pow',
-    'stanh',
-    'hard_sigmoid',
-    'swish',
-    'prelu',
-    'brelu',
-    'leaky_relu',
-    'soft_relu',
-    'flatten',
-    'sequence_mask',
-    'stack',
-    'pad2d',
-    'unstack',
-    'sequence_enumerate',
-    'expand',
-    'sequence_concat',
-    'scale',
-    'elementwise_add',
-    'elementwise_div',
-    'elementwise_sub',
-    'elementwise_mul',
-    'elementwise_max',
-    'elementwise_min',
-    'elementwise_pow',
-    'uniform_random_batch_size_like',
-    'gaussian_random',
-    'sampling_id',
-    'gaussian_random_batch_size_like',
-    'sum',
-    'slice',
-    'shape',
-    'logical_and',
-    'logical_or',
-    'logical_xor',
-    'logical_not',
-    'clip',
-    'clip_by_norm',
-    'mean',
-    'mul',
-    'sigmoid_cross_entropy_with_logits',
-    'maxout',
-    'space_to_depth',
-    'affine_grid',
-    'sequence_reverse',
-    'affine_channel',
-    'similarity_focus',
-    'hash',
-    'grid_sampler',
-    'log_loss',
-    'add_position_encoding',
-    'bilinear_tensor_product',
-    'merge_selected_rows',
-    'get_tensor_from_selected_rows',
-    'lstm',
+    'fc', 'embedding', 'dynamic_lstm', 'dynamic_lstmp', 'dynamic_gru',
+    'gru_unit', 'linear_chain_crf', 'crf_decoding', 'cos_sim', 'cross_entropy',
+    'bpr_loss', 'square_error_cost', 'chunk_eval', 'sequence_conv', 'conv2d',
+    'conv3d', 'sequence_pool', 'sequence_softmax', 'softmax', 'pool2d',
+    'pool3d', 'batch_norm', 'beam_search_decode', 'conv2d_transpose',
+    'conv3d_transpose', 'sequence_expand', 'sequence_expand_as', 'sequence_pad',
+    'sequence_unpad', 'lstm_unit', 'reduce_sum', 'reduce_mean', 'reduce_max',
+    'reduce_min', 'reduce_prod', 'sequence_first_step', 'sequence_last_step',
+    'sequence_slice', 'dropout', 'split', 'ctc_greedy_decoder', 'edit_distance',
+    'l2_normalize', 'matmul', 'topk', 'warpctc', 'sequence_reshape',
+    'transpose', 'im2sequence', 'nce', 'hsigmoid', 'beam_search', 'row_conv',
+    'multiplex', 'layer_norm', 'group_norm', 'softmax_with_cross_entropy',
+    'smooth_l1', 'one_hot', 'autoincreased_step_counter', 'reshape', 'squeeze',
+    'unsqueeze', 'lod_reset', 'lrn', 'pad', 'pad_constant_like', 'label_smooth',
+    'roi_pool', 'roi_align', 'dice_loss', 'image_resize', 'image_resize_short',
+    'resize_bilinear', 'resize_nearest', 'gather', 'scatter',
+    'sequence_scatter', 'random_crop', 'mean_iou', 'relu', 'selu', 'log',
+    'crop', 'rank_loss', 'margin_rank_loss', 'elu', 'relu6', 'pow', 'stanh',
+    'hard_sigmoid', 'swish', 'prelu', 'brelu', 'leaky_relu', 'soft_relu',
+    'flatten', 'sequence_mask', 'stack', 'pad2d', 'unstack',
+    'sequence_enumerate', 'expand', 'sequence_concat', 'scale',
+    'elementwise_add', 'elementwise_div', 'elementwise_sub', 'elementwise_mul',
+    'elementwise_max', 'elementwise_min', 'elementwise_pow',
+    'uniform_random_batch_size_like', 'gaussian_random', 'sampling_id',
+    'gaussian_random_batch_size_like', 'sum', 'slice', 'shape', 'logical_and',
+    'logical_or', 'logical_xor', 'logical_not', 'clip', 'clip_by_norm', 'mean',
+    'mul', 'sigmoid_cross_entropy_with_logits', 'maxout', 'space_to_depth',
+    'affine_grid', 'sequence_reverse', 'affine_channel', 'similarity_focus',
+    'hash', 'grid_sampler', 'log_loss', 'add_position_encoding',
+    'bilinear_tensor_product', 'merge_selected_rows',
+    'get_tensor_from_selected_rows', 'lstm', 'shuffle_channel'
 ]
 
 kIgnoreIndex = -100
@@ -9122,3 +9011,43 @@ def get_tensor_from_selected_rows(x, name=None):
         outputs={'Out': out},
         attrs={})
     return out
+
+
+def shuffle_channel(x, group=1, name=None):
+    """
+    **Shuffle Channel Operator**
+
+    This operator shuffles the channels of the input feature map. It first
+    divides the input channels in each group into several subgroups, and
+    then feeds each group in the next layer with different subgroups. The
+    shuffle channel operation makes it possible to build more powerful
+    structures with multiple group convolutional layers.
+
+    Args:
+        x (Variable): The input tensor variable with NCHW layout.
+        group (int): The number of groups to divide channels into. The
+            channel number must be divisible by it. Default: 1.
+        name (str, optional): The name of this layer.
+
+    Returns:
+        Variable: The channel shuffled tensor variable.
+
+    Raises:
+        TypeError: If group is not an int.
+
+    Examples:
+        .. code-block:: python
+
+            input = fluid.layers.data(name='input', shape=[4, 2, 2], dtype='float32')
+            out = fluid.layers.shuffle_channel(x=input, group=2)
+    """
+    helper = LayerHelper("shuffle_channel", **locals())
+
+    if not isinstance(group, int):
+        raise TypeError("group must be int type")
+
+    out = helper.create_variable_for_type_inference(
+        dtype=helper.input_dtype('x'))
+
+    helper.append_op(
+        type="shuffle_channel",
+        inputs={"X": x},
+        outputs={"Out": out},
+        attrs={"group": group})
+    return out
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index 10e8bb5a866..155f59f6fea 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -982,6 +982,15 @@ class TestBook(unittest.TestCase):
 
         print(str(program))
 
+    def test_shuffle_channel(self):
+        program = Program()
+        with program_guard(program):
+            x = layers.data(name="x", shape=[10, 32, 16, 16], dtype="float32")
+            out = layers.shuffle_channel(x, group=4)
+            self.assertIsNotNone(out)
+        print(str(program))
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_shuffle_channel_op.py b/python/paddle/fluid/tests/unittests/test_shuffle_channel_op.py
new file mode 100644
index 00000000000..25df22193ca
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_shuffle_channel_op.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+import math
+from op_test import OpTest
+import paddle.fluid.core as core
+
+
+class TestShuffleChannelOp(OpTest):
+    def setUp(self):
+        self.op_type = "shuffle_channel"
+        self.batch_size = 10
+        self.input_channels = 16
+        self.layer_h = 32
+        self.layer_w = 32
+        self.group = 4
+
+        self.x = np.random.random(
+            (self.batch_size, self.input_channels, self.layer_h,
+             self.layer_w)).astype('float32')
+        self.inputs = {'X': self.x}
+        self.attrs = {'group': self.group}
+
+        # Reference implementation: reshape the channel dimension to
+        # (group, channels_per_group), swap the two group axes, and
+        # flatten back.
+        n, c, h, w = self.x.shape
+        input_reshaped = np.reshape(self.x,
+                                    (-1, self.group, c // self.group, h, w))
+        input_transposed = np.transpose(input_reshaped, (0, 2, 1, 3, 4))
+        self.outputs = {'Out': np.reshape(input_transposed, (-1, c, h, w))}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab
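Not part of the patch: a minimal sketch of how the new layer could be
exercised end-to-end, assuming the Fluid 1.x executor API; the input shape
and group value are chosen only for illustration.

    import numpy as np
    import paddle.fluid as fluid

    # Build a program with a 4-channel NCHW input and group = 2.
    x = fluid.layers.data(name='x', shape=[4, 2, 2], dtype='float32')
    out = fluid.layers.shuffle_channel(x, group=2)

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # Channels [0, 1, 2, 3] should come back in the order [0, 2, 1, 3]:
    # reshape (2, 2) -> transpose -> flatten.
    feed_x = np.arange(16, dtype='float32').reshape(1, 4, 2, 2)
    result = exe.run(feed={'x': feed_x}, fetch_list=[out])
    print(result[0])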