diff --git a/mace/kernels/channel_shuffle.h b/mace/kernels/channel_shuffle.h new file mode 100644 index 0000000000000000000000000000000000000000..49b12661de68c216a207bc0c6ee59961dba185da --- /dev/null +++ b/mace/kernels/channel_shuffle.h @@ -0,0 +1,48 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_KERNELS_CHANNEL_SHUFFLE_H_ +#define MACE_KERNELS_CHANNEL_SHUFFLE_H_ + +#include "mace/core/tensor.h" + +namespace mace { +namespace kernels { + +template +class ChannelShuffleFunctor { + public: + ChannelShuffleFunctor(const int group) + : group_(group) {} + + void operator()(const T *input, const index_t *input_shape, T *output) { + index_t batch = input_shape[0]; + index_t channels = input_shape[1]; + index_t height = input_shape[2]; + index_t width = input_shape[3]; + + index_t image_size = height * width; + int channels_of_group = channels / group_; + + for (int b = 0; b < batch; ++b) { + for (int c = 0; c < channels_of_group; ++c) { + for (int g = 0; g < group_; ++g) { + index_t input_offset = (b * channels + g * channels_of_group + c) * + image_size; + index_t output_offset = (b * channels + c * group_ + g) * image_size; + memcpy(output + output_offset, input + input_offset, + image_size * sizeof(T)); + } + } + } + } + + private: + const int group_; +}; + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_CHANNEL_SHUFFLE_H_ \ No newline at end of file diff --git a/mace/ops/channel_shuffle.cc b/mace/ops/channel_shuffle.cc new file mode 100644 index 0000000000000000000000000000000000000000..e76a091c251d01699fe9cc3b9bbdde1791541d82 --- /dev/null +++ b/mace/ops/channel_shuffle.cc @@ -0,0 +1,11 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/ops/channel_shuffle.h" + +namespace mace { + +REGISTER_CPU_OPERATOR(ChannelShuffle, ChannelShuffleOp); + +} // namespace mace diff --git a/mace/ops/channel_shuffle.h b/mace/ops/channel_shuffle.h new file mode 100644 index 0000000000000000000000000000000000000000..3393efdbe28e03509067f9eac94821a683e694e6 --- /dev/null +++ b/mace/ops/channel_shuffle.h @@ -0,0 +1,48 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_OPS_CHANNEL_SHUFFLE_H_ +#define MACE_OPS_CHANNEL_SHUFFLE_H_ + +#include + +#include "mace/core/operator.h" +#include "mace/kernels/channel_shuffle.h" + +namespace mace { + +template +class ChannelShuffleOp : public Operator { + public: + ChannelShuffleOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + group_(OperatorBase::GetSingleArgument("group", 1)), + functor_(this->group_) {} + + bool Run() override { + const Tensor *input = this->Input(INPUT); + Tensor *output = this->Output(OUTPUT); + MACE_CHECK(input->shape()[1] % group_ == 0, + "input channels must be an integral multiple of group. ", + input->shape()[1]); + + output->ResizeLike(input); + functor_(input->data(), input->shape().data(), + output->mutable_data()); + + return true; + } + + protected: + const int group_; + OP_INPUT_TAGS(INPUT); + OP_OUTPUT_TAGS(OUTPUT); + + private: + kernels::ChannelShuffleFunctor functor_; +}; + +} // namespace mace + +#endif // MACE_OPS_CHANNEL_SHUFFLE_H_ diff --git a/mace/ops/channel_shuffle_benchmark.cc b/mace/ops/channel_shuffle_benchmark.cc new file mode 100644 index 0000000000000000000000000000000000000000..13d426f8874ee426c90c88e632b3f32e5f94acfd --- /dev/null +++ b/mace/ops/channel_shuffle_benchmark.cc @@ -0,0 +1,59 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/kernels/channel_shuffle.h" +#include "mace/core/operator.h" +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +using namespace mace; +using namespace mace::kernels; + +template +static void ChannelShuffle(int iters, + int batch, + int channels, + int height, + int width, + int group) { + mace::testing::StopTiming(); + + OpsTestNet net; + OpDefBuilder("GlobalAvgPooling", "GlobalAvgPoolingTest") + .Input("Input") + .Output("Output") + .Finalize(net.operator_def()); + + // Add input data + net.AddIntArg("group", group); + net.AddRandomInput("Input", {batch, channels, height, width}); + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + } +} + +#define BM_CHANNEL_SHUFFLE_MACRO(N, C, H, W, G, DEVICE) \ + static void \ + BM_CHANNEL_SHUFFLE_##N##_##C##_##H##_##W##_##G##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * C * H * W; \ + mace::testing::ItemsProcessed(tot); \ + mace::testing::BytesProcessed(tot*(sizeof(float))); \ + ChannelShuffle(iters, N, C, H, W, G); \ + } \ + BENCHMARK(BM_CHANNEL_SHUFFLE_##N##_##C##_##H##_##W##_##G##_##DEVICE) + +#define BM_CHANNEL_SHUFFLE(N, C, H, W, G) \ + BM_CHANNEL_SHUFFLE_MACRO(N, C, H, W, G, CPU); + +BM_CHANNEL_SHUFFLE(1, 64, 64, 64, 8); +BM_CHANNEL_SHUFFLE(1, 64, 128, 128, 8); +BM_CHANNEL_SHUFFLE(1, 64, 256, 256, 8); diff --git a/mace/ops/channel_shuffle_test.cc b/mace/ops/channel_shuffle_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9722ab2dd40f9a8be2a293d62843db6b26ea5106 --- /dev/null +++ b/mace/ops/channel_shuffle_test.cc @@ -0,0 +1,37 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +using namespace mace; + +class ChannelShuffleOpTest : public OpsTestBase {}; + +TEST_F(ChannelShuffleOpTest, C8G4) { + // Construct graph + auto& net = test_net(); + OpDefBuilder("ChannelShuffle", "ChannelShuffleTest") + .Input("Input") + .Output("Output") + .Finalize(net.operator_def()); + + net.AddIntArg("group", 4); + + // Add input data + net.AddInputFromArray( + "Input", {1, 8, 1, 2}, + {0, 1, 2, 3, 4, 5, 6, 7, + 8, 9, 10, 11, 12, 13, 14, 15}); + + // Run + net.RunOp(); + + // Check + auto expected = + CreateTensor({1, 8, 1, 2}, + {0, 1, 4, 5, 8, 9, 12, 13, + 2, 3, 6, 7, 10, 11, 14, 15}); + + ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); +}