Commit 7bfff395 authored by 吴承辉

Merge branch 'master' into 'master'

Add shuffle opencl & cpu NHWC impl.

See merge request !278
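For context, channel shuffle with G groups regroups the C channels at every (batch, height, width) position of an NHWC tensor: output channel c is taken from input channel (c % G) * (C / G) + c / G. A minimal standalone sketch of that mapping, in plain C++; the function name and signature are illustrative and not part of this merge request:

#include <cstdint>
#include <vector>

// Reference NHWC channel shuffle: for every spatial position, output channel c
// reads input channel (c % groups) * (channels / groups) + c / groups.
// Illustrative only; assumes channels % groups == 0.
std::vector<float> ChannelShuffleNHWC(const std::vector<float> &input,
                                      int64_t batch, int64_t height,
                                      int64_t width, int64_t channels,
                                      int groups) {
  const int64_t channels_per_group = channels / groups;
  const int64_t positions = batch * height * width;
  std::vector<float> output(input.size());
  for (int64_t p = 0; p < positions; ++p) {
    const int64_t base = p * channels;
    for (int64_t c = 0; c < channels; ++c) {
      output[base + c] =
          input[base + (c % groups) * channels_per_group + c / groups];
    }
  }
  return output;
}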
@@ -12,37 +12,49 @@ namespace mace {
namespace kernels {
template <DeviceType D, typename T>
-class ChannelShuffleFunctor {
- public:
-  ChannelShuffleFunctor(const int group) : group_(group) {}
+struct ChannelShuffleFunctor {
+  ChannelShuffleFunctor(const int groups) : groups_(groups) {}
-  void operator()(const T *input,
-                  const index_t *input_shape,
-                  T *output,
+  void operator()(const Tensor *input,
+                  Tensor *output,
StatsFuture *future) {
-    index_t batch = input_shape[0];
-    index_t channels = input_shape[1];
-    index_t height = input_shape[2];
-    index_t width = input_shape[3];
-    index_t image_size = height * width;
-    int channels_of_group = channels / group_;
-    for (int b = 0; b < batch; ++b) {
-      for (int c = 0; c < channels_of_group; ++c) {
-        for (int g = 0; g < group_; ++g) {
-          index_t input_offset =
-              (b * channels + g * channels_of_group + c) * image_size;
-          index_t output_offset = (b * channels + c * group_ + g) * image_size;
-          memcpy(output + output_offset, input + input_offset,
-                 image_size * sizeof(T));
-        }
+    output->ResizeLike(input);
+    Tensor::MappingGuard logits_guard(input);
+    Tensor::MappingGuard output_guard(output);
+    const T *input_ptr = input->data<T>();
+    T *output_ptr = output->mutable_data<T>();
+    index_t batch = input->dim(0);
+    index_t height = input->dim(1);
+    index_t width = input->dim(2);
+    index_t channels = input->dim(3);
+    index_t bhw_fuse = batch * height * width;
+    int channels_per_group = channels / groups_;
+#pragma omp parallel for
+    for (int bhw = 0; bhw < bhw_fuse; ++bhw) {
+      for (int c = 0; c < channels; ++c) {
+        index_t channel_base = bhw * channels;
+        output_ptr[channel_base + c] =
+            input_ptr[channel_base + c % groups_ * channels_per_group
+                + c / groups_];
}
}
}
- private:
-  const int group_;
+  const int groups_;
};
template <typename T>
struct ChannelShuffleFunctor<DeviceType::OPENCL, T> {
ChannelShuffleFunctor(const int groups) : groups_(groups) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future);
cl::Kernel kernel_;
const int groups_;
};
} // namespace kernels
//
// Copyright (c) 2018 XiaoMi All rights reserved.
//
#include "mace/kernels/channel_shuffle.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/utils.h"
#include "mace/utils/tuner.h"
namespace mace {
namespace kernels {
template <typename T>
void ChannelShuffleFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
output->ResizeLike(input);
const index_t batch = input->dim(0);
const index_t height = input->dim(1);
const index_t width = input->dim(2);
const index_t channels = input->dim(3);
const index_t channels_per_group = channels / groups_;
MACE_CHECK(channels_per_group % 4 == 0,
"channels per group must be multiple of 4");
MACE_CHECK(groups_ % 4 == 0,
"groups must be multiple of 4");
const index_t group_channel_blocks = RoundUpDiv4(channels_per_group);
if (kernel_.get() == nullptr) {
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("channel_shuffle");
built_options.emplace("-Dchannel_shuffle=" + kernel_name);
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
kernel_ = runtime->BuildKernel("channel_shuffle", kernel_name, built_options);
uint32_t idx = 0;
kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, groups_);
kernel_.setArg(idx++, static_cast<uint32_t>(channels_per_group));
kernel_.setArg(idx++, *(output->opencl_image()));
}
const uint32_t gws[3] = {static_cast<uint32_t>(group_channel_blocks),
static_cast<uint32_t>(width),
static_cast<uint32_t>(height * batch)};
const std::vector<uint32_t> lws = {8, 16, 8, 1};
std::stringstream ss;
ss << "channel_shuffle_opencl_kernel_"
<< output->dim(0) << "_"
<< output->dim(1) << "_"
<< output->dim(2) << "_"
<< output->dim(3);
TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future);
}
template
struct ChannelShuffleFunctor<DeviceType::OPENCL, float>;
template
struct ChannelShuffleFunctor<DeviceType::OPENCL, half>;
} // namespace kernels
} // namespace mace
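The host code above launches a 3D NDRange with one work item per 4-channel block within a group, per output column, and per (batch, height) row. A small sketch of how that global work size follows from an NHWC shape; illustrative helper, assuming RoundUpDiv4(x) == (x + 3) / 4 as declared in mace/kernels/opencl/helper.h:

#include <array>
#include <cstdint>

// Global work size used by the channel shuffle kernel for an NHWC input of
// shape {batch, height, width, channels} shuffled with `groups` groups.
std::array<uint32_t, 3> ChannelShuffleGws(int64_t batch, int64_t height,
                                          int64_t width, int64_t channels,
                                          int groups) {
  const int64_t channels_per_group = channels / groups;
  const int64_t group_channel_blocks = (channels_per_group + 3) / 4;
  return {static_cast<uint32_t>(group_channel_blocks),
          static_cast<uint32_t>(width),
          static_cast<uint32_t>(height * batch)};
}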
#include <common.h>
// assume channels_per_group % 4 == 0 && groups % 4 == 0
__kernel void channel_shuffle(__read_only image2d_t input,
__private const int groups,
__private const int channels_per_group,
__write_only image2d_t output) {
const int group_chan_blk_idx = get_global_id(0);
const int width_idx = get_global_id(1);
const int width = get_global_size(1);
const int hb_idx = get_global_id(2);
const int group_blks = groups / 4;
const int groups_blks_width = group_blks * width;
const int channels_per_group_blks = channels_per_group / 4;
const int channels_per_group_blks_width = channels_per_group_blks * width;
DATA_TYPE4 in_chan_data0, in_chan_data1, in_chan_data2, in_chan_data3;
DATA_TYPE4 out_chan_data0, out_chan_data1, out_chan_data2, out_chan_data3;
int in_x = mad24(group_chan_blk_idx, width, width_idx);
for (short g_blk = 0; g_blk < group_blks; ++g_blk) {
// fetch 4 groups, for each group fetch 4 channels
in_chan_data0 = READ_IMAGET(input, SAMPLER, (int2)(in_x, hb_idx));
in_x += channels_per_group_blks_width;
in_chan_data1 = READ_IMAGET(input, SAMPLER, (int2)(in_x, hb_idx));
in_x += channels_per_group_blks_width;
in_chan_data2 = READ_IMAGET(input, SAMPLER, (int2)(in_x, hb_idx));
in_x += channels_per_group_blks_width;
in_chan_data3 = READ_IMAGET(input, SAMPLER, (int2)(in_x, hb_idx));
in_x += channels_per_group_blks_width;
out_chan_data0 = (DATA_TYPE4)(in_chan_data0.x, in_chan_data1.x, in_chan_data2.x, in_chan_data3.x);
out_chan_data1 = (DATA_TYPE4)(in_chan_data0.y, in_chan_data1.y, in_chan_data2.y, in_chan_data3.y);
out_chan_data2 = (DATA_TYPE4)(in_chan_data0.z, in_chan_data1.z, in_chan_data2.z, in_chan_data3.z);
out_chan_data3 = (DATA_TYPE4)(in_chan_data0.w, in_chan_data1.w, in_chan_data2.w, in_chan_data3.w);
int out_x = mad24(mad24(group_chan_blk_idx, groups, g_blk), width, width_idx);
WRITE_IMAGET(output, (int2)(out_x, hb_idx), out_chan_data0);
out_x += groups_blks_width;
WRITE_IMAGET(output, (int2)(out_x, hb_idx), out_chan_data1);
out_x += groups_blks_width;
WRITE_IMAGET(output, (int2)(out_x, hb_idx), out_chan_data2);
out_x += groups_blks_width;
WRITE_IMAGET(output, (int2)(out_x, hb_idx), out_chan_data3);
}
}
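The kernel above relies on MACE's IN_OUT_CHANNEL image layout (my reading of the code, not stated in the merge request): pixel x = channel_block * width + w, pixel y = b * height + h, with 4 consecutive channels packed into each pixel. Under that assumption, each work item owns one 4-channel block within a group and, per loop iteration, gathers that block from 4 consecutive groups and scatters it to 4 output pixels. A host-side sketch of the same index arithmetic, illustrative only:

#include <cassert>

int main() {
  const int groups = 8, channels_per_group = 8, width = 3;  // both multiples of 4
  const int group_blks = groups / 4;
  const int cpg_blks = channels_per_group / 4;
  for (int gcb = 0; gcb < cpg_blks; ++gcb) {       // get_global_id(0)
    for (int w = 0; w < width; ++w) {              // get_global_id(1)
      for (int g_blk = 0; g_blk < group_blks; ++g_blk) {
        for (int i = 0; i < 4; ++i) {              // the 4 groups fetched per iteration
          const int group = 4 * g_blk + i;
          // i-th READ_IMAGET of this iteration: channel block gcb of `group`
          const int in_x = (gcb + group * cpg_blks) * width + w;
          for (int lane = 0; lane < 4; ++lane) {   // components .x .y .z .w
            const int in_channel = group * channels_per_group + 4 * gcb + lane;
            const int out_channel = (4 * gcb + lane) * groups + group;
            // lane-th WRITE_IMAGET of this iteration
            const int out_x = (gcb * groups + lane * group_blks + g_blk) * width + w;
            assert(in_x == (in_channel / 4) * width + w);    // read hits the right pixel
            assert(out_x == (out_channel / 4) * width + w);  // write hits the right pixel
            assert(out_channel % 4 == i);  // and lands in component i of that pixel
          }
        }
      }
    }
  }
  return 0;
}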
@@ -12,6 +12,16 @@ void Register_ChannelShuffle(OperatorRegistry *op_registry) {
.TypeConstraint<float>("T")
.Build(),
ChannelShuffleOp<DeviceType::CPU, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("ChannelShuffle")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
.Build(),
ChannelShuffleOp<DeviceType::OPENCL, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("ChannelShuffle")
.Device(DeviceType::OPENCL)
.TypeConstraint<half>("T")
.Build(),
ChannelShuffleOp<DeviceType::OPENCL, half>);
}
} // namespace mace
@@ -23,13 +23,12 @@ class ChannelShuffleOp : public Operator<D, T> {
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
-    MACE_CHECK(input->shape()[1] % group_ == 0,
+    int channels = input->dim(3);
+    MACE_CHECK(channels % group_ == 0,
               "input channels must be an integral multiple of group. ",
-               input->shape()[1]);
-    output->ResizeLike(input);
-    functor_(input->data<T>(), input->shape().data(), output->mutable_data<T>(),
-             future);
+               input->dim(3));
+    int channels_per_group = channels / group_;
+    functor_(input, output, future);
return true;
}
@@ -2,54 +2,68 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/channel_shuffle.h"
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
-using namespace mace;
-using namespace mace::kernels;
+namespace mace {
-template <DeviceType D>
+template<DeviceType D, typename T>
static void ChannelShuffle(
-    int iters, int batch, int channels, int height, int width, int group) {
+    int iters, int batch, int channels, int height, int width, int group) {
mace::testing::StopTiming();
OpsTestNet net;
OpDefBuilder("GlobalAvgPooling", "GlobalAvgPoolingTest")
.Input("Input")
// Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("ChannelShuffle", "ChannelShuffleTest")
.Input("InputImage")
.Output("Output")
.AddIntArg("group", group)
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<DeviceType::CPU, float>("Input",
{batch, channels, height, width});
} else {
OpDefBuilder("Softmax", "SoftmaxBM")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
-#define BM_CHANNEL_SHUFFLE_MACRO(N, C, H, W, G, DEVICE)                  \
-  static void BM_CHANNEL_SHUFFLE_##N##_##C##_##H##_##W##_##G##_##DEVICE( \
+#define BM_CHANNEL_SHUFFLE_MACRO(N, C, H, W, G, TYPE, DEVICE)                 \
+  static void BM_CHANNEL_SHUFFLE_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE( \
      int iters) {                                                        \
    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W;      \
-    mace::testing::MaccProcessed(tot);                                    \
-    mace::testing::BytesProcessed(tot *(sizeof(float)));                  \
-    ChannelShuffle<DEVICE>(iters, N, C, H, W, G);                         \
-  }                                                                       \
-  BENCHMARK(BM_CHANNEL_SHUFFLE_##N##_##C##_##H##_##W##_##G##_##DEVICE)
+    mace::testing::MaccProcessed(tot);                                        \
+    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));                       \
+    ChannelShuffle<DEVICE, TYPE>(iters, N, C, H, W, G);                       \
+  }                                                                           \
+  BENCHMARK(BM_CHANNEL_SHUFFLE_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE)
#define BM_CHANNEL_SHUFFLE(N, C, H, W, G)                 \
-  BM_CHANNEL_SHUFFLE_MACRO(N, C, H, W, G, CPU);
+  BM_CHANNEL_SHUFFLE_MACRO(N, C, H, W, G, float, CPU);    \
+  BM_CHANNEL_SHUFFLE_MACRO(N, C, H, W, G, float, OPENCL); \
+  BM_CHANNEL_SHUFFLE_MACRO(N, C, H, W, G, half, OPENCL);
BM_CHANNEL_SHUFFLE(1, 64, 64, 64, 8);
BM_CHANNEL_SHUFFLE(1, 64, 128, 128, 8);
BM_CHANNEL_SHUFFLE(1, 64, 256, 256, 8);
} // namespace mace
@@ -8,7 +8,7 @@ using namespace mace;
class ChannelShuffleOpTest : public OpsTestBase {};
-TEST_F(ChannelShuffleOpTest, C8G4) {
+TEST_F(ChannelShuffleOpTest, C8G4_CPU) {
// Construct graph
OpsTestNet net;
OpDefBuilder("ChannelShuffle", "ChannelShuffleTest")
@@ -19,7 +19,7 @@ TEST_F(ChannelShuffleOpTest, C8G4) {
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>(
"Input", {1, 8, 1, 2},
"Input", {1, 1, 2, 8},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
// Run
@@ -27,7 +27,41 @@
// Check
auto expected = CreateTensor<float>(
-      {1, 8, 1, 2}, {0, 1, 4, 5, 8, 9, 12, 13, 2, 3, 6, 7, 10, 11, 14, 15});
+      {1, 1, 2, 8}, {0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
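As a sanity check on the expected tensor above (not part of the change): with 8 channels and group = 4, output channel c reads input channel (c % 4) * 2 + c / 4, i.e. channels 0, 2, 4, 6, 1, 3, 5, 7 at each pixel, which reproduces the two rows {0, 2, 4, 6, 1, 3, 5, 7} and {8, 10, 12, 14, 9, 11, 13, 15}. A quick standalone verification:

#include <cassert>

int main() {
  const int channels = 8, group = 4, channels_per_group = channels / group;
  const float input[2][8] = {{0, 1, 2, 3, 4, 5, 6, 7},
                             {8, 9, 10, 11, 12, 13, 14, 15}};
  const float expected[2][8] = {{0, 2, 4, 6, 1, 3, 5, 7},
                                {8, 10, 12, 14, 9, 11, 13, 15}};
  for (int p = 0; p < 2; ++p) {
    for (int c = 0; c < channels; ++c) {
      assert(input[p][(c % group) * channels_per_group + c / group] ==
             expected[p][c]);
    }
  }
  return 0;
}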
TEST_F(ChannelShuffleOpTest, C16G4_OPENCL) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<DeviceType::OPENCL, float>(
"Input", {1, 1, 2, 16},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31});
BufferToImage<DeviceType::OPENCL, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("ChannelShuffle", "ChannelShuffleTest")
.Input("InputImage")
.Output("OutputImage")
.AddIntArg("group", 4)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(DeviceType::OPENCL);
// Transfer output
ImageToBuffer<DeviceType::OPENCL, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL);
// Check
auto expected = CreateTensor<float>(
{1, 1, 2, 16}, {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15,
16, 20, 24, 28, 17, 21, 25, 29, 18, 22, 26, 30, 19, 23, 27, 31});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
\ No newline at end of file