提交 1dd91c6d 编写于 作者: 李寅

Merge branch 'feature_wuch' into 'master'

add channel shuffle op

See merge request !57
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_CHANNEL_SHUFFLE_H_
#define MACE_KERNELS_CHANNEL_SHUFFLE_H_
#include "mace/core/tensor.h"
namespace mace {
namespace kernels {
template<DeviceType D, typename T>
class ChannelShuffleFunctor {
public:
ChannelShuffleFunctor(const int group)
: group_(group) {}
void operator()(const T *input, const index_t *input_shape, T *output) {
index_t batch = input_shape[0];
index_t channels = input_shape[1];
index_t height = input_shape[2];
index_t width = input_shape[3];
index_t image_size = height * width;
int channels_of_group = channels / group_;
for (int b = 0; b < batch; ++b) {
for (int c = 0; c < channels_of_group; ++c) {
for (int g = 0; g < group_; ++g) {
index_t input_offset = (b * channels + g * channels_of_group + c) *
image_size;
index_t output_offset = (b * channels + c * group_ + g) * image_size;
memcpy(output + output_offset, input + input_offset,
image_size * sizeof(T));
}
}
}
}
private:
const int group_;
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_CHANNEL_SHUFFLE_H_
\ No newline at end of file
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/channel_shuffle.h"
namespace mace {
REGISTER_CPU_OPERATOR(ChannelShuffle, ChannelShuffleOp<DeviceType::CPU, float>);
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPS_CHANNEL_SHUFFLE_H_
#define MACE_OPS_CHANNEL_SHUFFLE_H_
#include <memory>
#include "mace/core/operator.h"
#include "mace/kernels/channel_shuffle.h"
namespace mace {
template<DeviceType D, typename T>
class ChannelShuffleOp : public Operator<D, T> {
public:
ChannelShuffleOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<D, T>(operator_def, ws),
group_(OperatorBase::GetSingleArgument<int>("group", 1)),
functor_(this->group_) {}
bool Run() override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
MACE_CHECK(input->shape()[1] % group_ == 0,
"input channels must be an integral multiple of group. ",
input->shape()[1]);
output->ResizeLike(input);
functor_(input->data<T>(), input->shape().data(),
output->mutable_data<T>());
return true;
}
protected:
const int group_;
OP_INPUT_TAGS(INPUT);
OP_OUTPUT_TAGS(OUTPUT);
private:
kernels::ChannelShuffleFunctor<D, T> functor_;
};
} // namespace mace
#endif // MACE_OPS_CHANNEL_SHUFFLE_H_
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/channel_shuffle.h"
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
using namespace mace;
using namespace mace::kernels;
template <DeviceType D>
static void ChannelShuffle(int iters,
int batch,
int channels,
int height,
int width,
int group) {
mace::testing::StopTiming();
OpsTestNet net;
OpDefBuilder("GlobalAvgPooling", "GlobalAvgPoolingTest")
.Input("Input")
.Output("Output")
.Finalize(net.operator_def());
// Add input data
net.AddIntArg("group", group);
net.AddRandomInput<float>("Input", {batch, channels, height, width});
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
}
#define BM_CHANNEL_SHUFFLE_MACRO(N, C, H, W, G, DEVICE) \
static void \
BM_CHANNEL_SHUFFLE_##N##_##C##_##H##_##W##_##G##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \
mace::testing::BytesProcessed(tot*(sizeof(float))); \
ChannelShuffle<DEVICE>(iters, N, C, H, W, G); \
} \
BENCHMARK(BM_CHANNEL_SHUFFLE_##N##_##C##_##H##_##W##_##G##_##DEVICE)
#define BM_CHANNEL_SHUFFLE(N, C, H, W, G) \
BM_CHANNEL_SHUFFLE_MACRO(N, C, H, W, G, CPU);
BM_CHANNEL_SHUFFLE(1, 64, 64, 64, 8);
BM_CHANNEL_SHUFFLE(1, 64, 128, 128, 8);
BM_CHANNEL_SHUFFLE(1, 64, 256, 256, 8);
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
using namespace mace;
class ChannelShuffleOpTest : public OpsTestBase {};
TEST_F(ChannelShuffleOpTest, C8G4) {
// Construct graph
auto& net = test_net();
OpDefBuilder("ChannelShuffle", "ChannelShuffleTest")
.Input("Input")
.Output("Output")
.Finalize(net.operator_def());
net.AddIntArg("group", 4);
// Add input data
net.AddInputFromArray<float>(
"Input", {1, 8, 1, 2},
{0, 1, 2, 3, 4, 5, 6, 7,
8, 9, 10, 11, 12, 13, 14, 15});
// Run
net.RunOp();
// Check
auto expected =
CreateTensor<float>({1, 8, 1, 2},
{0, 1, 4, 5, 8, 9, 12, 13,
2, 3, 6, 7, 10, 11, 14, 15});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册