diff --git a/paddle/fluid/operators/shuffle_channel_op.cc b/paddle/fluid/operators/shuffle_channel_op.cc
index ec1255af168d112de46ef09cb2d8a6286d2e04ea..0ede3922ea4cbf39d3ea4f0704142e558705ae2e 100644
--- a/paddle/fluid/operators/shuffle_channel_op.cc
+++ b/paddle/fluid/operators/shuffle_channel_op.cc
@@ -19,26 +19,27 @@ class ShuffleChannelOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx
-                   > HasInput("X"),
+    PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of ShuffleChannelOp should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Out"),
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of ShuffleChannelOp should not be null.");
     auto input_dims = ctx->GetInputDim("X");
     PADDLE_ENFORCE(input_dims.size() == 4, "The layout of input is NCHW.");
 
     // ENFORCE group
-    auto group = ctx->Attrs().Get<std::vector<int>>("group");
+    // auto group = ctx->Attrs().Get<int>("group");
     ctx->SetOutputDim("Out", input_dims);
   }
-
- protected:
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
-        ctx.GetPlace());
-  }
+  /*
+   protected:
+    framework::OpKernelType GetExpectedKernelType(
+        const framework::ExecutionContext& ctx) const override {
+      return framework::OpKernelType(
+          framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+          ctx.device_context());
+    }
+  */
 };
 
 class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -63,7 +64,7 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
     then, feed each group in the next layer with different subgroups.
 
     According to the paper, "Suppose a convolution layer with g groups
-    whose output has g x n channels, first reshape the output channel dimension into(g,n),
+    whose output has g * n channels, first reshape the output channel dimension into (g, n),
     transposing and then flattening it back as the input of next layer. "
 
     Shuffle channel operation makes it possible to build more powerful structures
@@ -75,52 +76,49 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
   }
 };
 
-// Grad
-
-class ShuffleChannelOpGrad : public framework::OperatorWithKernel {
+class ShuffleChannelGradOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@Grad) should not be null")
+                   "Input(Out@Grad) should not be null");
     PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                    "Output(X@Grad) should not be null");
 
     auto input_dims = ctx->GetInputDim("X");
     ctx->SetOutputDim(framework::GradVarName("X"), input_dims);
   }
-
- protected:
-  framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        framework::ToDataType(
-            ctx.Input<framework::Tensor>(framework::GradVarName("Out"))
-                ->type()),
-        ctx.device_context());
-  }
+  /*
+   protected:
+    framework::OpKernelType GetExpectedKernelType(
+        const framework::ExecutionContext& ctx) const override {
+      return framework::OpKernelType(
+          framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+          ctx.device_context());
+    }
+  */
 };
 
 }  // namespace operators
 }  // namespace paddle
 
-// how to write gpu kernal
 namespace ops = paddle::operators;
-REGISTER_OPERATOR(shufflechannel, ops::ShuffleChannelOp,
+REGISTER_OPERATOR(shuffle_channel, ops::ShuffleChannelOp,
                   ops::ShuffleChannelOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 //                paddle::framework::EmptyGradOpMaker);
-REGISTER_OPERATOR(shufflechannel_grad, ops::ShuffleChannelGradOp);
+REGISTER_OPERATOR(shuffle_channel_grad, ops::ShuffleChannelGradOp);
 
 REGISTER_OP_CPU_KERNEL(
-    shufflechannel,
+    shuffle_channel,
     ops::ShuffleChannelOpKernel<paddle::platform::CPUDeviceContext, float>,
     ops::ShuffleChannelOpKernel<paddle::platform::CPUDeviceContext, double>);
 
 REGISTER_OP_CPU_KERNEL(
-    shufflechannel_grad,
+    shuffle_channel_grad,
     ops::ShuffleChannelGradOpKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ShuffleChannelGradOpKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/shuffle_channel_op.cu b/paddle/fluid/operators/shuffle_channel_op.cu
index b1eacd0cbe4f1acae326b43502ff36443861e564..77418ac7e3c5f8471721208a7eb751359fd0dc62 100644
--- a/paddle/fluid/operators/shuffle_channel_op.cu
+++ b/paddle/fluid/operators/shuffle_channel_op.cu
@@ -10,15 +10,115 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/shuffle_channel_op.h"
+#include "paddle/fluid/platform/cuda_primitives.h"
+#include "paddle/fluid/platform/gpu_info.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+static constexpr int kNumCUDAThreads = 512;
+static constexpr int kNumMaximumNumBlocks = 4096;
+
+static inline int NumBlocks(const int N) {
+  return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
+                  kNumMaximumNumBlocks);
+}
+
+template <typename T>
+__global__ void ShuffleChannel(const int nthreads, const int feature_map_size,
+                               T* output, const T* input, int group_row,
+                               int group_column, int len) {
+  int index = blockIdx.x * blockDim.x + threadIdx.x;
+  int offset = blockDim.x * gridDim.x;
+  // Grid-stride loop: element ii of the input is copied to its shuffled
+  // channel position (i, j) -> (j, i).
+  for (size_t ii = index; ii < nthreads; ii += offset) {
+    const int n = ii / group_row / group_column / len;
+    const int i = (ii / group_column / len) % group_row;
+    const int j = ii / len % group_column;
+    const int k = ii - (n * feature_map_size + (i * group_column + j) * len);
+    T* p_o = output + n * feature_map_size + (j * group_row + i) * len;
+    p_o[k] = input[ii];
+  }
+}
+
+template <typename T>
+class ShuffleChannelOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("X");
+    auto* output = ctx.Output<Tensor>("Out");
+    int group = ctx.Attr<int>("group");
+
+    auto input_dims = input->dims();
+    auto num = input_dims[0];
+    auto channel = input_dims[1];
+    auto height = input_dims[2];
+    auto weight = input_dims[3];
+
+    auto feature_map_size = channel * height * weight;
+    auto sp_sz = height * weight;
+    int group_row = group;
+    int group_column = channel / group_row;
+    // count is the product of NCHW, the same as numel()
+    int count = num * group_column * group_row * sp_sz;
+
+    int blocks = NumBlocks(output->numel());
+    int threads = kNumCUDAThreads;
+
+    const T* input_data = input->data<T>();
+    T* output_data = output->mutable_data<T>(ctx.GetPlace());
+
+    ShuffleChannel<
+        T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
+        count, feature_map_size, output_data, input_data, group_row,
+        group_column, sp_sz);
+  }
+};
+
+template <typename T>
+class ShuffleChannelGradOpCUDAKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* input = ctx.Input<Tensor>("X");
+    int group = ctx.Attr<int>("group");
+
+    auto input_dims = input->dims();
+    auto num = input_dims[0];
+    auto channel = input_dims[1];
+    auto height = input_dims[2];
+    auto weight = input_dims[3];
+    auto feature_map_size = channel * height * weight;
+    auto sp_sz = height * weight;
+
+    int group_row = group;
+    int group_column = channel / group_row;
+    auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
+    auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
+    const T* output_grad_data = output_grad->data<T>();
+
+    int blocks = NumBlocks(output_grad->numel());
+    int threads = kNumCUDAThreads;
+    int count = num * group_column * group_row * sp_sz;
+    // The backward pass applies the inverse shuffle, so group_row and
+    // group_column trade places here.
+    ShuffleChannel<
+        T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
+        count, feature_map_size, input_grad_data, output_grad_data,
+        group_column, group_row, sp_sz);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
 
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
-    shufflechannel,
-    ops::ShuffleChannelOpKernel<paddle::platform::CUDADeviceContext, float>
-    ops::ShuffleChannelOpKernel<paddle::platform::CUDADeviceContext, double>,
+    shuffle_channel,
+    ops::ShuffleChannelOpCUDAKernel<float>,
+    ops::ShuffleChannelOpCUDAKernel<double>);
 
 REGISTER_OP_CUDA_KERNEL(
-    shufflechannel_grad,
-    ops::ShuffleChannelOpGradKernel<paddle::platform::CUDADeviceContext, float>
-    ops::ShuffleChannelOpGradKernel<paddle::platform::CUDADeviceContext, double>,
+    shuffle_channel_grad,
+    ops::ShuffleChannelGradOpCUDAKernel<float>,
+    ops::ShuffleChannelGradOpCUDAKernel<double>);
diff --git a/paddle/fluid/operators/shuffle_channel_op.h b/paddle/fluid/operators/shuffle_channel_op.h
index f923babf5b8f58afaa314114fb8317d07f596e17..5c161c0005da85434634b2b3747b9e295630af02 100644
--- a/paddle/fluid/operators/shuffle_channel_op.h
+++ b/paddle/fluid/operators/shuffle_channel_op.h
@@ -21,10 +21,10 @@ namespace operators {
 template <typename DeviceContext, typename T>
 class ShuffleChannelOpKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext& context) const override {
+  void Compute(const framework::ExecutionContext& ctx) const override {
     auto* input = ctx.Input<framework::Tensor>("X");
     auto* output = ctx.Output<framework::Tensor>("Out");
-    auto group = ctx.Input<framework::Tensor>("group");
+    int group = ctx.Attr<int>("group");
 
     auto input_dims = input->dims();
     auto num = input_dims[0];
@@ -34,21 +34,19 @@ class ShuffleChannelOpKernel : public framework::OpKernel<T> {
     auto feature_map_size = channel * height * weight;
     auto sp_sz = height * weight;
 
     int group_row = group;
-    int group_column = channels / group_row;
+    int group_column = channel / group_row;
 
     const T* input_data = input->data<T>();
-    T* output_data = out->mutable_data<T>(ctx.GetPlace());
-
+    T* output_data = output->mutable_data<T>(ctx.GetPlace());
     for (int n = 0; n < num; ++n) {
-      output_data_temp = output_data + n * feature_map_size;
-      input_data_temp = input_data + n * feature_map_size;
       for (int i = 0; i < group_row; ++i) {
         for (int j = 0; j < group_column; ++j) {
-          const auto* p_i = input_data_temp + (i * group_column + j) * sp_sz;
-          auto* p_o = output_data_temp + (j * group_row + i) * sp_sz;
-          memcpy(p_o, p_i, sizeof(Dtype) * sp_sz);
+          const T* p_i = input_data + n * feature_map_size +
+                         (i * group_column + j) * sp_sz;
+          T* p_o =
+              output_data + n * feature_map_size + (j * group_row + i) * sp_sz;
+          memcpy(p_o, p_i, sizeof(T) * sp_sz);
         }
       }
     }
@@ -61,7 +59,7 @@ class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* input = ctx.Input<framework::Tensor>("X");
-    auto group = ctx.Input<framework::Tensor>("group");
+    int group = ctx.Attr<int>("group");
 
     auto input_dims = input->dims();
     auto num = input_dims[0];
@@ -72,7 +70,7 @@ class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
     auto sp_sz = height * weight;
 
     int group_row = group;
-    int group_column = channels / group_row;
+    int group_column = channel / group_row;
 
     auto* output_grad =
         ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
@@ -81,19 +79,17 @@ class ShuffleChannelGradOpKernel : public framework::OpKernel<T> {
     T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
     const T* output_grad_data = output_grad->data<T>();
-
     for (int n = 0; n < num; ++n) {
-      output_grad_temp = output_grad_data + n * feature_map_size;
-      input_grad_temp = input_grad_data + n * feature_map_size;
       for (int i = 0; i < group_row; ++i) {
         for (int j = 0; j < group_column; ++j) {
-          const auto* p_i = output_grad_temp + (i * group_column + j) * sp_sz;
-          auto* p_o = input_grad_temp + (j * group_row + i) * sp_sz;
-          memcpy(p_o, p_i, sizeof(Dtype) * sp_sz);
+          // The gradient is the inverse shuffle of the output gradient.
+          const T* p_i = output_grad_data + n * feature_map_size +
+                         (j * group_row + i) * sp_sz;
+          T* p_o = input_grad_data + n * feature_map_size +
+                   (i * group_column + j) * sp_sz;
+          memcpy(p_o, p_i, sizeof(T) * sp_sz);
         }
       }
     }
-    return;
   }
 };
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 57d210eab8d7bc5eafe7b893819762647570b2c1..fd7cddeffbf8cb2dcd4a2106483a9d2f1d8faca0 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -173,7 +173,7 @@ __all__ = [
     'merge_selected_rows',
     'get_tensor_from_selected_rows',
     'lstm',
-    'shufflechannel',
+    'shuffle_channel',
     'psroi_pool',
 ]
 
@@ -9334,17 +9334,20 @@ def shuffle_channel(x, group=1, name=None):
     with multiple group convolutional layers.
 
     Args:
         x: The input tensor variable.
+        group (int): The number of groups to divide the channels into.
 
     Returns:
         Variable: channel shuffled tensor variable.
 
     Raises:
-        ValueError: If group in not a int type variable.
+        ValueError: If group is not an int type variable.
 
     Examples:
         .. code-block:: python
+
+            out = fluid.layers.shuffle_channel(x=group_conv, group=4)
     """
@@ -9361,6 +9364,7 @@ def shuffle_channel(x, group=1, name=None):
         inputs={"X": x},
         outputs={"Out": out},
         attrs={"group": group})
+    return out
 
 
 @templatedoc()
diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py
index e2edba030b768080595eb3d01c726235d54de418..7ade135ec3dbb49ba95fc66100c47f7f7c64658b 100644
--- a/python/paddle/fluid/tests/unittests/test_layers.py
+++ b/python/paddle/fluid/tests/unittests/test_layers.py
@@ -1018,7 +1018,7 @@ class TestBook(unittest.TestCase):
     def test_shuffle_channel(self):
         program = Program()
         with program_guard(program):
-            x = layers.data(name="x", shape=[10, 32, 16, 16], dtype="float32")
+            x = layers.data(name="x", shape=[1, 4, 2, 2], dtype="float32")
             group = layers.data(name="group", shape=[1], dtype="int32")
             out = layers.shuffle_channel(x, group)
             self.assertIsNotNone(out)
diff --git a/python/paddle/fluid/tests/unittests/test_shuffle_channel_op.py b/python/paddle/fluid/tests/unittests/test_shuffle_channel_op.py
index 25df22193ca3e7548fe7ab28ab56b869a8de3c17..4fabe424fa77dfb9b82af488ef971ac9edaa09b1 100644
--- a/python/paddle/fluid/tests/unittests/test_shuffle_channel_op.py
+++ b/python/paddle/fluid/tests/unittests/test_shuffle_channel_op.py
@@ -23,31 +23,29 @@ import paddle.fluid.core as core
 
 
 class TestShuffleChannelOp(OpTest):
-    def test_check_output(self):
-        self.check_output()
-
-    def test_check_grad(self):
-        self.check_grad(['X'], 'output')
-
     def setUp(self):
         self.op_type = "shuffle_channel"
-        self.batch_size = 10
-        self.input_channels = 16
-        self.layer_h = 32
-        self.layer_w = 32
-        self.group = 4
-
+        self.batch_size = 1
+        self.input_channels = 4
+        self.layer_h = 2
+        self.layer_w = 2
+        self.group = 2
         self.x = np.random.random(
-            (self.batch_size, self.input_channels, self.layer_h, self,
-             layer_w)).astype('float32')
+            (self.batch_size, self.input_channels, self.layer_h,
+             self.layer_w)).astype('float32')
         self.inputs = {'X': self.x}
         self.attrs = {'group': self.group}
-
         n, c, h, w = self.x.shape
         input_reshaped = np.reshape(self.x,
                                     (-1, self.group, c // self.group, h, w))
         input_transposed = np.transpose(input_reshaped, (0, 2, 1, 3, 4))
-        self.outputs = np.reshape(input_transposed, (-1, c, h, w))
+        self.outputs = {'Out': np.reshape(input_transposed, (-1, c, h, w))}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Out')
 
 
 if __name__ == '__main__':
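
For reference, the channel shuffle that the kernels above implement can be sketched in a few lines of NumPy; this mirrors the reference computation used in test_shuffle_channel_op.py. The helper name shuffle_channel_ref is illustrative only and is not part of the patch.

import numpy as np


def shuffle_channel_ref(x, group):
    # NCHW input: split C into (group, C // group), swap the two factors,
    # then flatten back -- the reshape/transpose described in the op comment.
    n, c, h, w = x.shape
    assert c % group == 0, "channel count must be divisible by group"
    x = x.reshape(n, group, c // group, h, w)
    x = x.transpose(0, 2, 1, 3, 4)
    return x.reshape(n, c, h, w)


# Example: with C = 4 and group = 2, channel order (0, 1, 2, 3) becomes (0, 2, 1, 3).
x = np.arange(4, dtype='float32').reshape(1, 4, 1, 1)
print(shuffle_channel_ref(x, 2).reshape(-1))  # [0. 2. 1. 3.]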