From 24f577262179d0057705d2466d402e5ad625d106 Mon Sep 17 00:00:00 2001 From: Unknown Date: Wed, 21 Mar 2018 18:53:26 +0800 Subject: [PATCH] add space to depth and pass test and benchmark add test and benchmark to depth_to_space. add space_to_depth ops add test and benchmark to space_to_depth and pass the test and bm --- mace/core/operator.cc | 2 + mace/kernels/depth_to_space.h | 27 +------ mace/kernels/opencl/cl/depth_to_space.cl | 67 +++++------------ mace/kernels/opencl/cl/space_to_depth.cl | 25 +++++++ mace/kernels/opencl/depth_to_space_opencl.cc | 77 ++++++++++++++++++++ mace/kernels/opencl/space_to_depth_opencl.cc | 77 ++++++++++++++++++++ mace/kernels/space_to_depth.h | 76 +++++++++++++++++++ mace/ops/depth_to_space.cc | 4 +- mace/ops/depth_to_space.h | 2 - mace/ops/depth_to_space_benchmark.cc | 6 +- mace/ops/depth_to_space_test.cc | 17 +++-- mace/ops/space_to_depth.cc | 32 ++++++++ mace/ops/space_to_depth.h | 53 ++++++++++++++ mace/ops/space_to_depth_benchmark.cc | 74 +++++++++++++++++++ mace/ops/space_to_depth_test.cc | 77 ++++++++++++++++++++ 15 files changed, 532 insertions(+), 84 deletions(-) create mode 100644 mace/kernels/opencl/cl/space_to_depth.cl create mode 100644 mace/kernels/opencl/depth_to_space_opencl.cc create mode 100644 mace/kernels/opencl/space_to_depth_opencl.cc create mode 100644 mace/kernels/space_to_depth.h create mode 100644 mace/ops/space_to_depth.cc create mode 100644 mace/ops/space_to_depth.h create mode 100644 mace/ops/space_to_depth_benchmark.cc create mode 100644 mace/ops/space_to_depth_test.cc diff --git a/mace/core/operator.cc b/mace/core/operator.cc index 710806f8..45d29cce 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -83,6 +83,7 @@ extern void Register_Pooling(OperatorRegistry *op_registry); extern void Register_ResizeBilinear(OperatorRegistry *op_registry); extern void Register_Softmax(OperatorRegistry *op_registry); extern void Register_SpaceToBatchND(OperatorRegistry *op_registry); +extern void Register_SpaceToDepth(OperatorRegistry *op_registry); extern void Register_MatMul(OperatorRegistry *op_registry); extern void Register_WinogradTransform(OperatorRegistry *op_registry); extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry); @@ -113,6 +114,7 @@ OperatorRegistry::OperatorRegistry() { ops::Register_ResizeBilinear(this); ops::Register_Softmax(this); ops::Register_SpaceToBatchND(this); + ops::Register_SpaceToDepth(this); ops::Register_MatMul(this); ops::Register_WinogradTransform(this); ops::Register_WinogradInverseTransform(this); diff --git a/mace/kernels/depth_to_space.h b/mace/kernels/depth_to_space.h index 8dfdce0b..0161d83f 100644 --- a/mace/kernels/depth_to_space.h +++ b/mace/kernels/depth_to_space.h @@ -23,16 +23,7 @@ struct DepthToSpaceOpFunctor { const int input_width = input->dim(2); const int input_depth = input->dim(3); - std::cout << "input shape: {" << batch_size <<", "; - std::cout << input_height << ", "; - std::cout << input_width << ", "; - std::cout << input_depth << ", "; - - std::cout << "block size= " << block_size_<Resize(output_shape); - // Tensor::MappingGuard logits_guard(input); - // Tensor::MappingGuard output_guard(output); + Tensor::MappingGuard logits_guard(input); + Tensor::MappingGuard output_guard(output); const T *input_ptr = input->data(); T *output_ptr = output->mutable_data(); @@ -74,12 +60,7 @@ struct DepthToSpaceOpFunctor { } const int block_size_; }; -/* -template <> -void DepthToSpaceOpFunctor::operator()(const Tensor *input, - Tensor *output, - StatsFuture *future); -*/ + template struct DepthToSpaceOpFunctor { diff --git a/mace/kernels/opencl/cl/depth_to_space.cl b/mace/kernels/opencl/cl/depth_to_space.cl index 69ddfdba..238526c9 100644 --- a/mace/kernels/opencl/cl/depth_to_space.cl +++ b/mace/kernels/opencl/cl/depth_to_space.cl @@ -1,55 +1,28 @@ #include -__kernel void depth_to-space(__read_only image2d_t input, +__kernel void depth_to_space(__read_only image2d_t input, __private const int block_size, - __private const int batch_size, - __private const int input_height, - __private const int input_width, - __private const int input_depth, - __private const int output_height, - __private const int output_width, __private const int output_depth, __write_only image2d_t output) { - const int ch_blk = get_global_id(0); - const int w = get_global_id(1); - const int hb = get_global_id(2); - const int width = get_global_size(1); - - const int out_idx = mad24(ch_blk, width, w); + const int out_d = get_global_id(0); + const int out_w = get_global_id(1); + const int out_h = get_global_id(2); + const int output_width = get_global_size(1); - const int d = out_idx % output_depth; - const int out_idx2 = out_idx / output_depth; - const int w = out_idx2 % output_width + const int out_pos = mad24(out_d, output_width, out_w); - for (short g_blk = 0; g_blk < group_blks; ++g_blk) { - // fetch 4 groups, for each group fetch 4 channels - in_chan_data0 = READ_IMAGET(input, SAMPLER, (int2)(in_x, hb_idx)); - in_x += channels_per_group_blks_width; - - in_chan_data1 = READ_IMAGET(input, SAMPLER, (int2)(in_x, hb_idx)); - in_x += channels_per_group_blks_width; - - in_chan_data2 = READ_IMAGET(input, SAMPLER, (int2)(in_x, hb_idx)); - in_x += channels_per_group_blks_width; - - in_chan_data3 = READ_IMAGET(input, SAMPLER, (int2)(in_x, hb_idx)); - in_x += channels_per_group_blks_width; - - out_chan_data0 = (DATA_TYPE4)(in_chan_data0.x, in_chan_data1.x, in_chan_data2.x, in_chan_data3.x); - out_chan_data1 = (DATA_TYPE4)(in_chan_data0.y, in_chan_data1.y, in_chan_data2.y, in_chan_data3.y); - out_chan_data2 = (DATA_TYPE4)(in_chan_data0.z, in_chan_data1.z, in_chan_data2.z, in_chan_data3.z); - out_chan_data3 = (DATA_TYPE4)(in_chan_data0.w, in_chan_data1.w, in_chan_data2.w, in_chan_data3.w); - - int out_x = mad24(mad24(group_chan_blk_idx, groups, g_blk), width, width_idx); - WRITE_IMAGET(output, (int2)(out_x, hb_idx), out_chan_data0); - out_x += groups_blks_width; - - WRITE_IMAGET(output, (int2)(out_x, hb_idx), out_chan_data1); - out_x += groups_blks_width; - - WRITE_IMAGET(output, (int2)(out_x, hb_idx), out_chan_data2); - out_x += groups_blks_width; - - WRITE_IMAGET(output, (int2)(out_x, hb_idx), out_chan_data3); - } + const int input_width = output_width / block_size; + + const int in_h = out_h / block_size; + const int offset_h = out_h % block_size; + const int in_w = out_w / block_size; + const int offset_w = out_w % block_size; + + const int offset_d = (offset_h * block_size + offset_w) * output_depth; + const int in_d = out_d + offset_d; + + const int in_pos = mad24(in_d, input_width, in_w); + + DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, in_h)); + WRITE_IMAGET(output, (int2)(out_pos, out_h), in_data); } diff --git a/mace/kernels/opencl/cl/space_to_depth.cl b/mace/kernels/opencl/cl/space_to_depth.cl new file mode 100644 index 00000000..b54ee295 --- /dev/null +++ b/mace/kernels/opencl/cl/space_to_depth.cl @@ -0,0 +1,25 @@ +#include + +__kernel void space_to_depth(__read_only image2d_t input, + __private const int block_size, + __private const int input_depth, + __write_only image2d_t output) { + const int d = get_global_id(0); + const int w = get_global_id(1); + const int h = get_global_id(2); + const int input_width = get_global_size(1); + const int in_pos = mad24(d, input_width, w); + const int output_width = input_width / block_size; + + const int out_h = h / block_size; + const int offset_h = h % block_size; + const int out_w = w / block_size; + const int offset_w = w % block_size; + const int offset_d = (offset_h * block_size + offset_w) * input_depth; + const int out_d = d + offset_d; + const int out_pos = mad24(out_d, output_width, out_w); + + DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, h)); + + WRITE_IMAGET(output, (int2)(out_pos, out_h), in_data); +} diff --git a/mace/kernels/opencl/depth_to_space_opencl.cc b/mace/kernels/opencl/depth_to_space_opencl.cc new file mode 100644 index 00000000..322a7d80 --- /dev/null +++ b/mace/kernels/opencl/depth_to_space_opencl.cc @@ -0,0 +1,77 @@ +// +// Copyright (c) 2018 XiaoMi All rights reserved. +// + +#include "mace/kernels/depth_to_space.h" +#include "mace/core/runtime/opencl/cl2_header.h" +#include "mace/core/runtime/opencl/opencl_runtime.h" +#include "mace/kernels/opencl/helper.h" +#include "mace/utils/utils.h" +#include "mace/utils/tuner.h" + +namespace mace { +namespace kernels { + +template +void DepthToSpaceOpFunctor::operator()( + const Tensor *input, + Tensor *output, + StatsFuture *future) { + const index_t batch = input->dim(0); + const index_t input_h = input->dim(1); + const index_t input_w = input->dim(2); + const index_t input_d = input->dim(3); + + const index_t output_h = input_h * block_size_; + const index_t output_w = input_w * block_size_; + const index_t output_d = input_d / (block_size_ * block_size_); + + std::vector output_shape = {batch, output_h, output_w, output_d}; + + std::vector image_shape; + CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape); + output->ResizeImage(output_shape, image_shape); + + const int output_depth_blocks = RoundUpDiv4(output_d); + + if (kernel_.get() == nullptr) { + auto runtime = OpenCLRuntime::Global(); + std::set built_options; + std::string kernel_name = MACE_OBFUSCATE_SYMBOL("depth_to_space"); + built_options.emplace("-Ddepth_to_space=" + kernel_name); + auto dt = DataTypeToEnum::value; + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); + built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); + kernel_ = runtime->BuildKernel("depth_to_space", kernel_name, + built_options); + } + if (!IsVecEqual(input_shape_, input->shape())) { + uint32_t idx = 0; + kernel_.setArg(idx++, *(input->opencl_image())); + kernel_.setArg(idx++, block_size_); + kernel_.setArg(idx++, output_depth_blocks); + kernel_.setArg(idx++, *(output->opencl_image())); + + input_shape_ = input->shape(); + } + + const uint32_t gws[3] = {static_cast(output_depth_blocks), + static_cast(output_w), + static_cast(output_h * batch)}; + const std::vector lws = {8, 16, 8, 1}; + std::stringstream ss; + ss << "depth_to_space_opencl_kernel_" + << output->dim(0) << "_" + << output->dim(1) << "_" + << output->dim(2) << "_" + << output->dim(3); + TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future); +} + +template +struct DepthToSpaceOpFunctor; +template +struct DepthToSpaceOpFunctor; + +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/opencl/space_to_depth_opencl.cc b/mace/kernels/opencl/space_to_depth_opencl.cc new file mode 100644 index 00000000..e5023104 --- /dev/null +++ b/mace/kernels/opencl/space_to_depth_opencl.cc @@ -0,0 +1,77 @@ +// +// Copyright (c) 2018 XiaoMi All rights reserved. +// + +#include "mace/kernels/space_to_depth.h" +#include "mace/core/runtime/opencl/cl2_header.h" +#include "mace/core/runtime/opencl/opencl_runtime.h" +#include "mace/kernels/opencl/helper.h" +#include "mace/utils/utils.h" +#include "mace/utils/tuner.h" + +namespace mace { +namespace kernels { + +template +void SpaceToDepthOpFunctor::operator()( + const Tensor *input, + Tensor *output, + StatsFuture *future) { + const index_t batch_size = input->dim(0); + const index_t input_height = input->dim(1); + const index_t input_width = input->dim(2); + const index_t input_depth = input->dim(3); + + const index_t output_height = input_height / block_size_; + const index_t output_width = input_width / block_size_; + const index_t output_depth = input_depth * block_size_ * block_size_; + + std::vector output_shape = {batch_size, output_height, output_width, output_depth}; + + std::vector image_shape; + CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape); + output->ResizeImage(output_shape, image_shape); + + const int input_depth_blocks = RoundUpDiv4(input_depth); + + if (kernel_.get() == nullptr) { + auto runtime = OpenCLRuntime::Global(); + std::set built_options; + std::string kernel_name = MACE_OBFUSCATE_SYMBOL("space_to_depth"); + built_options.emplace("-Dspace_to_depth=" + kernel_name); + auto dt = DataTypeToEnum::value; + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); + built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); + kernel_ = runtime->BuildKernel("space_to_depth", kernel_name, + built_options); + } + if (!IsVecEqual(input_shape_, input->shape())) { + uint32_t idx = 0; + kernel_.setArg(idx++, *(input->opencl_image())); + kernel_.setArg(idx++, block_size_); + kernel_.setArg(idx++, input_depth_blocks); + kernel_.setArg(idx++, *(output->opencl_image())); + + input_shape_ = input->shape(); + } + + const uint32_t gws[3] = {static_cast(input_depth_blocks), + static_cast(input_width), + static_cast(input_height * batch_size)}; + const std::vector lws = {8, 16, 8, 1}; + std::stringstream ss; + ss << "space_to_depth_opencl_kernel_" + << input->dim(0) << "_" + << input->dim(1) << "_" + << input->dim(2) << "_" + << input->dim(3); + TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future); +} + +template +struct SpaceToDepthOpFunctor; +template +struct SpaceToDepthOpFunctor; + +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/space_to_depth.h b/mace/kernels/space_to_depth.h new file mode 100644 index 00000000..b3125901 --- /dev/null +++ b/mace/kernels/space_to_depth.h @@ -0,0 +1,76 @@ +// +// Created by liutuo on 18-3-20. +// + +#ifndef MACE_KERNELS_SPACE_TO_DEPTH_H +#define MACE_KERNELS_SPACE_TO_DEPTH_H + +#include "mace/core/future.h" +#include "mace/core/tensor.h" + +namespace mace { +namespace kernels { + +template +struct SpaceToDepthOpFunctor { + explicit SpaceToDepthOpFunctor(const int block_size) : block_size_(block_size) {} + void operator()(const Tensor *input, + Tensor *output, + StatsFuture *future) { + + const int batch_size = input->dim(0); + const int input_height = input->dim(1); + const int input_width = input->dim(2); + const int input_depth = input->dim(3); + + const index_t output_depth = input_depth * block_size_ * block_size_; + const index_t output_width = input_width / block_size_; + const index_t output_height = input_height / block_size_; + + std::vector output_shape = {batch_size, output_height, output_width, output_depth}; + + output->Resize(output_shape); + + Tensor::MappingGuard logits_guard(input); + Tensor::MappingGuard output_guard(output); + const T *input_ptr = input->data(); + T *output_ptr = output->mutable_data(); + +#pragma omp parallel for + for (int b = 0; b < batch_size; ++b) { + for (int h = 0; h < input_height; ++h) { + const int out_h = h / block_size_; + const int offset_h = (h % block_size_); + for (int w = 0; w < input_width; ++w) { + const int out_w = w/ block_size_; + const int offset_w = (w % block_size_); + const int offset_d = (offset_h * block_size_ + offset_w) * input_depth; + for (int d = 0; d < input_depth; ++d) { + const int out_d = d + offset_d; + const int o_index = ((b * output_height + out_h) * output_width + out_w) * output_depth + out_d; + const int i_index = ((b * input_height + h) * input_width + w) * input_depth + d; + output_ptr[o_index] = input_ptr[i_index]; + } + } + } + } + + } + const int block_size_; +}; + +template +struct SpaceToDepthOpFunctor { + + SpaceToDepthOpFunctor(const int block_size) : block_size_(block_size) {} + void operator()(const Tensor *input, Tensor *output, StatsFuture *future); + + cl::Kernel kernel_; + const int block_size_; + std::vector input_shape_; +}; + +} // namespace kernels +} // namespace mace + +#endif //MACE_KERNELS_SPACE_TO_DEPTH_H diff --git a/mace/ops/depth_to_space.cc b/mace/ops/depth_to_space.cc index e390a812..cfea8e8d 100644 --- a/mace/ops/depth_to_space.cc +++ b/mace/ops/depth_to_space.cc @@ -13,7 +13,7 @@ void Register_DepthToSpace(OperatorRegistry *op_registry) { .TypeConstraint("T") .Build(), DepthToSpaceOp); -/* + REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace") .Device(DeviceType::OPENCL) .TypeConstraint("T") @@ -25,7 +25,7 @@ void Register_DepthToSpace(OperatorRegistry *op_registry) { .TypeConstraint("T") .Build(), DepthToSpaceOp); -*/ + } } // namespace ops diff --git a/mace/ops/depth_to_space.h b/mace/ops/depth_to_space.h index 808cb715..979a08ed 100644 --- a/mace/ops/depth_to_space.h +++ b/mace/ops/depth_to_space.h @@ -32,11 +32,9 @@ class DepthToSpaceOp : public Operator { MACE_CHECK(input_depth % (block_size * block_size) == 0, "input depth should be dividable by block_size * block_size", input->dim(3)); - std::cout << "arg block_size: " << block_size << std::endl; functor_(input, output, future); return true; } - protected: OP_INPUT_TAGS(INPUT); diff --git a/mace/ops/depth_to_space_benchmark.cc b/mace/ops/depth_to_space_benchmark.cc index e33349a6..beb1cc60 100644 --- a/mace/ops/depth_to_space_benchmark.cc +++ b/mace/ops/depth_to_space_benchmark.cc @@ -65,9 +65,9 @@ static void DepthToSpace( BM_DEPTH_TO_SPACE_MACRO(N, C, H, W, G, float, OPENCL); \ BM_DEPTH_TO_SPACE_MACRO(N, C, H, W, G, half, OPENCL); -BM_DEPTH_TO_SPACE(1, 64, 64, 64, 8); -BM_DEPTH_TO_SPACE(1, 64, 128, 128, 8); -BM_DEPTH_TO_SPACE(1, 64, 256, 256, 8); +BM_DEPTH_TO_SPACE(1, 64, 64, 64, 4); +BM_DEPTH_TO_SPACE(1, 64, 128, 128, 4); +BM_DEPTH_TO_SPACE(1, 64, 256, 256, 4); } // namespace test } // namespace ops diff --git a/mace/ops/depth_to_space_test.cc b/mace/ops/depth_to_space_test.cc index bbbb39b5..6ad830e6 100644 --- a/mace/ops/depth_to_space_test.cc +++ b/mace/ops/depth_to_space_test.cc @@ -22,15 +22,18 @@ TEST_F(DepthToSpaceOpTest, C8G4_CPU) { // Add input data net.AddInputFromArray( - "Input", {1, 2, 2, 4}, - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + "Input", {1, 1, 2, 16}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); // Run net.RunOp(); // Check auto expected = CreateTensor( - {1, 4, 4, 1}, {1, 2, 5, 6, 3, 4, 7, 8, 9, 10, 13, 14, 11, 12, 15, 16}); + {1, 2, 4, 4}, + {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, + 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}); ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); } @@ -50,7 +53,7 @@ TEST_F(DepthToSpaceOpTest, C16G4_OPENCL) { OpDefBuilder("DepthToSpace", "DepthToSpaceTest") .Input("InputImage") .Output("OutputImage") - .AddIntArg("block_size", 1) + .AddIntArg("block_size", 2) .Finalize(net.NewOperatorDef()); // Run @@ -62,9 +65,9 @@ TEST_F(DepthToSpaceOpTest, C16G4_OPENCL) { // Check auto expected = CreateTensor( - {1, 1, 2, 16}, - {0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15, - 16, 20, 24, 28, 17, 21, 25, 29, 18, 22, 26, 30, 19, 23, 27, 31}); + {1, 2, 4, 4}, + {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, + 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}); ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); } diff --git a/mace/ops/space_to_depth.cc b/mace/ops/space_to_depth.cc new file mode 100644 index 00000000..69630075 --- /dev/null +++ b/mace/ops/space_to_depth.cc @@ -0,0 +1,32 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/ops/space_to_depth.h" + +namespace mace { +namespace ops { + +void Register_SpaceToDepth(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("SpaceToDepth") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + SpaceToDepthOp); + + REGISTER_OPERATOR(op_registry, OpKeyBuilder("SpaceToDepth") + .Device(DeviceType::OPENCL) + .TypeConstraint("T") + .Build(), + SpaceToDepthOp); + + REGISTER_OPERATOR(op_registry, OpKeyBuilder("SpaceToDepth") + .Device(DeviceType::OPENCL) + .TypeConstraint("T") + .Build(), + SpaceToDepthOp); + +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/space_to_depth.h b/mace/ops/space_to_depth.h new file mode 100644 index 00000000..b21eeb49 --- /dev/null +++ b/mace/ops/space_to_depth.h @@ -0,0 +1,53 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_OPS_SPACE_TO_DEPTH_H_ +#define MACE_OPS_SPACE_TO_DEPTH_H_ + +#include +#include + +#include "mace/core/operator.h" +#include "mace/kernels/space_to_depth.h" + +namespace mace { +namespace ops { + +template +class SpaceToDepthOp : public Operator { + public: + SpaceToDepthOp(const OperatorDef &op_def, Workspace *ws) + : Operator(op_def, ws), + functor_(OperatorBase::GetSingleArgument("block_size", 1)) {} + + bool Run(StatsFuture *future) override { + const Tensor *input = this->Input(INPUT); + Tensor *output = this->Output(OUTPUT); + MACE_CHECK(input->dim_size() == 4, "input dim should be 4"); + + const int block_size = OperatorBase::GetSingleArgument("block_size", 1); + + const int input_height = input->dim(1); + const int input_width = input->dim(2); + const int input_depth = input->dim(3); + MACE_CHECK((input_width % block_size == 0) && (input_height % block_size == 0), + "input width and height should be dividable by block_size", + input->dim(3)); + functor_(input, output, future); + return true; + } + + protected: + OP_INPUT_TAGS(INPUT); + OP_OUTPUT_TAGS(OUTPUT); + + private: + kernels::SpaceToDepthOpFunctor functor_; + +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_SPACE_TO_DEPTH_H_ diff --git a/mace/ops/space_to_depth_benchmark.cc b/mace/ops/space_to_depth_benchmark.cc new file mode 100644 index 00000000..f4d78898 --- /dev/null +++ b/mace/ops/space_to_depth_benchmark.cc @@ -0,0 +1,74 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/operator.h" +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +template +static void SpaceToDepth( + int iters, int batch, int channels, int height, int width, int block_size) { + mace::testing::StopTiming(); + + OpsTestNet net; + + // Add input data + net.AddRandomInput("Input", {batch, height, width, channels}); + + if (D == DeviceType::OPENCL) { + BufferToImage(&net, "Input", "InputImage", + kernels::BufferType::IN_OUT_CHANNEL); + + OpDefBuilder("SpaceToDepth", "SpaceToDepthBM") + .Input("InputImage") + .Output("Output") + .AddIntArg("block_size", block_size) + .Finalize(net.NewOperatorDef()); + } else { + OpDefBuilder("SpaceToDepth", "SpaceToDepthBM") + .Input("Input") + .Output("Output") + .Finalize(net.NewOperatorDef()); + } + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + } + net.Sync(); +} + +#define BM_SPACE_TO_DEPTH_MACRO(N, C, H, W, G, TYPE, DEVICE) \ + static void \ + BM_SPACE_TO_DEPTH_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * C * H * W; \ + mace::testing::MaccProcessed(tot); \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + SpaceToDepth(iters, N, C, H, W, G); \ + } \ + BENCHMARK(BM_SPACE_TO_DEPTH_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE) + +#define BM_SPACE_TO_DEPTH(N, C, H, W, G) \ + BM_SPACE_TO_DEPTH_MACRO(N, C, H, W, G, float, CPU); \ + BM_SPACE_TO_DEPTH_MACRO(N, C, H, W, G, float, OPENCL); \ + BM_SPACE_TO_DEPTH_MACRO(N, C, H, W, G, half, OPENCL); + +BM_SPACE_TO_DEPTH(1, 64, 64, 64, 4); +BM_SPACE_TO_DEPTH(1, 64, 128, 128, 4); +BM_SPACE_TO_DEPTH(1, 64, 256, 256, 4); + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/space_to_depth_test.cc b/mace/ops/space_to_depth_test.cc new file mode 100644 index 00000000..37d020a9 --- /dev/null +++ b/mace/ops/space_to_depth_test.cc @@ -0,0 +1,77 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class SpaceToDepthOpTest : public OpsTestBase {}; + +TEST_F(SpaceToDepthOpTest, C8G4_CPU) { + // Construct graph + OpsTestNet net; + OpDefBuilder("SpaceToDepth", "SpaceToDepthTest") + .Input("Input") + .Output("Output") + .AddIntArg("block_size", 2) + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddInputFromArray( + "Input", {1, 2, 4, 4}, + {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, + 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}); + + // Run + net.RunOp(); + + // Check + auto expected = CreateTensor( + {1, 1, 2, 16}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); + + ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); +} + +TEST_F(SpaceToDepthOpTest, C16G4_OPENCL) { + // Construct graph + OpsTestNet net; + + // Add input data + net.AddInputFromArray( + "Input", {1, 2, 4, 4}, + {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, + 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}); + BufferToImage(&net, "Input", "InputImage", + kernels::BufferType::IN_OUT_CHANNEL); + + OpDefBuilder("SpaceToDepth", "SpaceToDepthTest") + .Input("InputImage") + .Output("OutputImage") + .AddIntArg("block_size", 2) + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(DeviceType::OPENCL); + + // Transfer output + ImageToBuffer(&net, "OutputImage", "Output", + kernels::BufferType::IN_OUT_CHANNEL); + + // Check + auto expected = CreateTensor( + {1, 1, 2, 16}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); + + ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); +} + +} // namespace test +} // namespace ops +} // namespace mace -- GitLab