提交 a065adb8 编写于 作者: U Unknown 提交者: liutuo

add depth_to_space kernel / test/benchmark

上级 37487e1f
//
// Created by liutuo on 18-3-20.
//
#ifndef MACE_KERNELS_DEPTH_TO_SPACE_H
#define MACE_KERNELS_DEPTH_TO_SPACE_H
#include "mace/core/future.h"
#include "mace/core/tensor.h"
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct DepthToSpaceOpFunctor {
DepthToSpaceOpFunctor(const int block_size) : block_size_(block_size) {}
void operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
std::vector<index_t> output_shape(input->shape());
const int batch_size = input->dim(0);
const int input_height = input->dim(1);
const int input_width = input->dim(2);
const int input_depth = input->dim(3);
const int block_size_sq = block_size_ * block_size_;
const index_t output_depth = input_depth / block_size_sq;
const index_t output_width = input_width * block_size_;
const index_t output_height = input_height * block_size_;
output_shape[0] = batch_size;
output_shape[1] = output_height;
output_shape[2] = output_width;
output_shape[3] = output_depth;
output->Resize(output_shape);
Tensor::MappingGuard logits_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_ptr = input->data<T>();
T *output_ptr = output->mutable_data<T>();
#pragma omp parallel for
for (int b = 0; b < batch_size; ++b) {
for (int h = 0; h < output_height; ++h) {
const int in_h = h / block_size_;
const int offset_h = (h % block_size_);
for (int w = 0; w < output_width; ++w) {
const int in_w = w / block_size_;
const int offset_w = w % block_size_;
const int offset_d = (offset_h * block_size_ + offset_w) * output_depth;
for (int d = 0; d < output_depth; ++d) {
const int in_d = d + offset_d;
const int o_index = ((b * output_height + h) * output_width + w) * output_depth + d;
const int i_index = ((b * input_height + in_h) * input_width + in_w) * input_depth + in_d;
output_ptr[o_index] = input[i_index];
}
}
}
}
}
const int block_size_;
};
template <typename T>
struct DepthToSpaceOpFunctor<DeviceType::OPENCL, T> {
DepthToSpaceOpFunctor(const int block_size) : block_size_(block_size) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future);
cl::Kernel kernel_;
const int block_size_;
std::vector<index_t> input_shape_;
};
} // namespace kernels
} // namespace mace
#endif //MACE_KERNELS_DEPTH_TO_SPACE_H
#include <common.h>
// assume channes_per_group mod 4 = 0 && groups mod 4 == 0
__kernel void channel_shuffle(__read_only image2d_t input,
__private const int groups,
__private const int channels_per_group,
__write_only image2d_t output) {
const int group_chan_blk_idx = get_global_id(0);
const int width_idx = get_global_id(1);
const int width = get_global_size(1);
const int hb_idx = get_global_id(2);
const int group_blks = groups / 4;
const int groups_blks_width = group_blks * width;
const int channels_per_group_blks = channels_per_group / 4;
const int channels_per_group_blks_width = channels_per_group_blks * width;
DATA_TYPE4 in_chan_data0, in_chan_data1, in_chan_data2, in_chan_data3;
DATA_TYPE4 out_chan_data0, out_chan_data1, out_chan_data2, out_chan_data3;
int in_x = mad24(group_chan_blk_idx, width, width_idx);
for (short g_blk = 0; g_blk < group_blks; ++g_blk) {
// fetch 4 groups, for each group fetch 4 channels
in_chan_data0 = READ_IMAGET(input, SAMPLER, (int2)(in_x, hb_idx));
in_x += channels_per_group_blks_width;
in_chan_data1 = READ_IMAGET(input, SAMPLER, (int2)(in_x, hb_idx));
in_x += channels_per_group_blks_width;
in_chan_data2 = READ_IMAGET(input, SAMPLER, (int2)(in_x, hb_idx));
in_x += channels_per_group_blks_width;
in_chan_data3 = READ_IMAGET(input, SAMPLER, (int2)(in_x, hb_idx));
in_x += channels_per_group_blks_width;
out_chan_data0 = (DATA_TYPE4)(in_chan_data0.x, in_chan_data1.x, in_chan_data2.x, in_chan_data3.x);
out_chan_data1 = (DATA_TYPE4)(in_chan_data0.y, in_chan_data1.y, in_chan_data2.y, in_chan_data3.y);
out_chan_data2 = (DATA_TYPE4)(in_chan_data0.z, in_chan_data1.z, in_chan_data2.z, in_chan_data3.z);
out_chan_data3 = (DATA_TYPE4)(in_chan_data0.w, in_chan_data1.w, in_chan_data2.w, in_chan_data3.w);
int out_x = mad24(mad24(group_chan_blk_idx, groups, g_blk), width, width_idx);
WRITE_IMAGET(output, (int2)(out_x, hb_idx), out_chan_data0);
out_x += groups_blks_width;
WRITE_IMAGET(output, (int2)(out_x, hb_idx), out_chan_data1);
out_x += groups_blks_width;
WRITE_IMAGET(output, (int2)(out_x, hb_idx), out_chan_data2);
out_x += groups_blks_width;
WRITE_IMAGET(output, (int2)(out_x, hb_idx), out_chan_data3);
}
}
......@@ -16,58 +16,34 @@ namespace ops {
template <DeviceType D, typename T>
class DepthToSpaceOp : public Operator<D, T> {
public:
public:
DepthToSpaceOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetRepeatedArgument<int>("crops", {0, 0, 0, 0}),
OperatorBase::GetSingleArgument<int>("block_size", 1),
true) {}
block_size_(OperatorBase::GetSingleArgument<int>("block_size", 1)),
functor_(this->block_size_) {}
bool Run(StatsFuture *future) override {
const Tensor *batch_tensor = this->Input(INPUT);
Tensor *space_tensor = this->Output(OUTPUT);
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
MACE_CHECK(input->dim_size() == 4, "input dim should be 4");
std::vector<index_t> output_shape(4, 0);
CalculateOutputShape(batch_tensor, space_tensor, output_shape.data());
functor_(space_tensor, output_shape, const_cast<Tensor *>(batch_tensor),
future);
return true;
}
private:
inline void CalculateOutputShape(const Tensor *input_tensor,
Tensor *output,
index_t *output_shape) {
auto crops = OperatorBase::GetRepeatedArgument<int>("crops", {0, 0, 0, 0});
auto block_shape =
OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1});
MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D");
MACE_CHECK(block_shape.size() == 2, "Block's shape should be 1D");
MACE_CHECK(crops.size() == 4, "Crops' shape should be 2D");
int input_depth = input->dim(3);
MACE_CHECK(input_depth % (block_size_ * block_size_) == 0,
"input depth should be dividable by block_size * block_size",
input->dim(3));
const index_t block_dims = block_shape.size();
index_t block_shape_product = 1;
for (uint32_t block_dim = 0; block_dim < block_dims; ++block_dim) {
MACE_CHECK(block_shape[block_dim] > 1,
"block_shape's value should be great to 1");
const index_t block_shape_value = block_shape[block_dim];
const index_t cropped_input_size =
input_tensor->dim(block_dim + 1) * block_shape_value -
crops[block_dim * 2] - crops[block_dim * 2 + 1];
MACE_CHECK(cropped_input_size >= 0, "cropped size must be non-negative");
block_shape_product *= block_shape_value;
output_shape[block_dim + 1] = cropped_input_size;
}
output_shape[0] = input_tensor->dim(0) / block_shape_product;
output_shape[3] = input_tensor->dim(3);
functor_(input, output, future);
return true;
}
private:
kernels::DepthToSpaceOpFunctor<D, T> functor_;
private:
kernels::DepthToSpaceOpFunctor<D, T> functor_;
protected:
const int block_size_;
OP_INPUT_TAGS(INPUT);
OP_OUTPUT_TAGS(OUTPUT);
protected:
OP_INPUT_TAGS(INPUT);
OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
template <DeviceType D, typename T>
static void DepthToSpace(
int iters, int batch, int channels, int height, int width, int block_size) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("DepthToSpace", "DepthToSpaceBM")
.Input("InputImage")
.Output("Output")
.AddIntArg("block_size", block_size)
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("DepthToSpace", "DepthToSpaceBM")
.Input("Input")
.Output("Output")
.Finalize(net.NewOperatorDef());
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
#define BM_DEPTH_TO_SPACE_MACRO(N, C, H, W, G, TYPE, DEVICE) \
static void \
BM_DEPTH_TO_SPACE_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
DepthToSpace<DEVICE, TYPE>(iters, N, C, H, W, G); \
} \
BENCHMARK(BM_DEPTH_TO_SPACE_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE)
#define BM_DEPTH_TO_SPACE(N, C, H, W, G) \
BM_DEPTH_TO_SPACE_MACRO(N, C, H, W, G, float, CPU); \
BM_DEPTH_TO_SPACE_MACRO(N, C, H, W, G, float, OPENCL); \
BM_DEPTH_TO_SPACE_MACRO(N, C, H, W, G, half, OPENCL);
BM_DEPTH_TO_SPACE(1, 64, 64, 64, 8);
BM_DEPTH_TO_SPACE(1, 64, 128, 128, 8);
BM_DEPTH_TO_SPACE(1, 64, 256, 256, 8);
} // namespace test
} // namespace ops
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class DepthToSpaceOpTest : public OpsTestBase {};
TEST_F(DepthToSpaceOpTest, C8G4_CPU) {
// Construct graph
OpsTestNet net;
OpDefBuilder("DepthToSpace", "DepthToSpaceTest")
.Input("Input")
.Output("Output")
.AddIntArg("block_size", 1)
.Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>(
"Input", {1, 1, 2, 8},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
// Run
net.RunOp();
// Check
auto expected = CreateTensor<float>(
{1, 1, 2, 8}, {0, 2, 4, 6, 1, 3, 5, 7, 8, 10, 12, 14, 9, 11, 13, 15});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
TEST_F(DepthToSpaceOpTest, C16G4_OPENCL) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<DeviceType::OPENCL, float>(
"Input", {1, 1, 2, 16},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31});
BufferToImage<DeviceType::OPENCL, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("DepthToSpace", "DepthToSpaceTest")
.Input("InputImage")
.Output("OutputImage")
.AddIntArg("block_size", 1)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(DeviceType::OPENCL);
// Transfer output
ImageToBuffer<DeviceType::OPENCL, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL);
// Check
auto expected = CreateTensor<float>(
{1, 1, 2, 16},
{0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15,
16, 20, 24, 28, 17, 21, 25, 29, 18, 22, 26, 30, 19, 23, 27, 31});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
} // namespace test
} // namespace ops
} // namespace mace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册