diff --git a/mace/kernels/depth_to_space.h b/mace/kernels/depth_to_space.h index 0161d83f6f9e4739482950fafac702a529bbe62a..3f6577f32159309bba931eaef58011902ecc2045 100644 --- a/mace/kernels/depth_to_space.h +++ b/mace/kernels/depth_to_space.h @@ -1,36 +1,43 @@ // -// Created by liutuo on 18-3-20. +// Copyright (c) 2017 XiaoMi All rights reserved. // -#ifndef MACE_KERNELS_DEPTH_TO_SPACE_H -#define MACE_KERNELS_DEPTH_TO_SPACE_H +#ifndef MACE_KERNELS_DEPTH_TO_SPACE_H_ +#define MACE_KERNELS_DEPTH_TO_SPACE_H_ +#include #include "mace/core/future.h" +#include "mace/core/runtime/opencl/cl2_header.h" #include "mace/core/tensor.h" +#include "mace/public/mace.h" namespace mace { namespace kernels { template struct DepthToSpaceOpFunctor { - explicit DepthToSpaceOpFunctor(const int block_size) : block_size_(block_size) {} - void operator()(const Tensor *input, - Tensor *output, - StatsFuture *future) { - std::vector output_shape(input->shape()); + explicit DepthToSpaceOpFunctor(const int block_size, bool d2s) + : block_size_(block_size), d2s_(d2s) {} + void operator()(const Tensor *input, Tensor *output, StatsFuture *future) { const int batch_size = input->dim(0); const int input_height = input->dim(1); const int input_width = input->dim(2); const int input_depth = input->dim(3); - - const index_t output_depth = input_depth / (block_size_ * block_size_); - const index_t output_width = input_width * block_size_; - const index_t output_height = input_height * block_size_; - output_shape[0] = batch_size; - output_shape[1] = output_height; - output_shape[2] = output_width; - output_shape[3] = output_depth; - + + index_t output_depth, output_width, output_height; + + if (d2s_) { + output_depth = input_depth / (block_size_ * block_size_); + output_width = input_width * block_size_; + output_height = input_height * block_size_; + } else { + output_depth = input_depth * block_size_ * block_size_; + output_width = input_width / block_size_; + output_height = input_height / block_size_; + } + std::vector output_shape = {batch_size, output_height, + output_width, output_depth}; + output->Resize(output_shape); Tensor::MappingGuard logits_guard(input); @@ -38,41 +45,75 @@ struct DepthToSpaceOpFunctor { const T *input_ptr = input->data(); T *output_ptr = output->mutable_data(); + if (d2s_) { #pragma omp parallel for - for (int b = 0; b < batch_size; ++b) { - for (int h = 0; h < output_height; ++h) { - const int in_h = h / block_size_; - const int offset_h = (h % block_size_); - for (int w = 0; w < output_width; ++w) { - const int in_w = w / block_size_; - const int offset_w = w % block_size_; - const int offset_d = (offset_h * block_size_ + offset_w) * output_depth; - for (int d = 0; d < output_depth; ++d) { - const int in_d = d + offset_d; - const int o_index = ((b * output_height + h) * output_width + w) * output_depth + d; - const int i_index = ((b * input_height + in_h) * input_width + in_w) * input_depth + in_d; - output_ptr[o_index] = input_ptr[i_index]; + for (int b = 0; b < batch_size; ++b) { + for (int h = 0; h < output_height; ++h) { + const int in_h = h / block_size_; + const int offset_h = (h % block_size_); + for (int w = 0; w < output_width; ++w) { + const int in_w = w / block_size_; + const int offset_w = w % block_size_; + const int offset_d = + (offset_h * block_size_ + offset_w) * output_depth; + for (int d = 0; d < output_depth; ++d) { + const int in_d = d + offset_d; + const int o_index = + ((b * output_height + h) * output_width + w) * output_depth + + d; + const int i_index = + ((b * input_height + in_h) * input_width + in_w) * + input_depth + + in_d; + output_ptr[o_index] = input_ptr[i_index]; + } + } + } + } + } else { +#pragma omp parallel for + for (int b = 0; b < batch_size; ++b) { + for (int h = 0; h < input_height; ++h) { + const int out_h = h / block_size_; + const int offset_h = (h % block_size_); + for (int w = 0; w < input_width; ++w) { + const int out_w = w / block_size_; + const int offset_w = (w % block_size_); + const int offset_d = + (offset_h * block_size_ + offset_w) * input_depth; + for (int d = 0; d < input_depth; ++d) { + const int out_d = d + offset_d; + const int o_index = + ((b * output_height + out_h) * output_width + out_w) * + output_depth + + out_d; + const int i_index = + ((b * input_height + h) * input_width + w) * input_depth + d; + output_ptr[o_index] = input_ptr[i_index]; + } } } } } - } + const int block_size_; + bool d2s_; }; template struct DepthToSpaceOpFunctor { - - DepthToSpaceOpFunctor(const int block_size) : block_size_(block_size) {} + DepthToSpaceOpFunctor(const int block_size, bool d2s) + : block_size_(block_size), d2s_(d2s) {} void operator()(const Tensor *input, Tensor *output, StatsFuture *future); cl::Kernel kernel_; const int block_size_; + bool d2s_; std::vector input_shape_; }; } // namespace kernels } // namespace mace -#endif //MACE_KERNELS_DEPTH_TO_SPACE_H +#endif // MACE_KERNELS_DEPTH_TO_SPACE_H_ diff --git a/mace/kernels/opencl/cl/depth_to_space.cl b/mace/kernels/opencl/cl/depth_to_space.cl index 238526c98c893e20d9fa957357ee2471a2e46f73..824f82665542975da3b000d2e0b1865ceabf4a3c 100644 --- a/mace/kernels/opencl/cl/depth_to_space.cl +++ b/mace/kernels/opencl/cl/depth_to_space.cl @@ -1,28 +1,52 @@ #include __kernel void depth_to_space(__read_only image2d_t input, - __private const int block_size, - __private const int output_depth, - __write_only image2d_t output) { + __private const int block_size, + __private const int output_depth, + __write_only image2d_t output) { const int out_d = get_global_id(0); const int out_w = get_global_id(1); const int out_h = get_global_id(2); const int output_width = get_global_size(1); - - const int out_pos = mad24(out_d, output_width, out_w); - + + const int out_pos = mad24(out_d, output_width, out_w); + const int input_width = output_width / block_size; - - const int in_h = out_h / block_size; + + const int in_h = out_h / block_size; const int offset_h = out_h % block_size; const int in_w = out_w / block_size; const int offset_w = out_w % block_size; - + const int offset_d = (offset_h * block_size + offset_w) * output_depth; const int in_d = out_d + offset_d; - + const int in_pos = mad24(in_d, input_width, in_w); - + DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, in_h)); WRITE_IMAGET(output, (int2)(out_pos, out_h), in_data); } + +__kernel void space_to_depth(__read_only image2d_t input, + __private const int block_size, + __private const int input_depth, + __write_only image2d_t output) { + const int d = get_global_id(0); + const int w = get_global_id(1); + const int h = get_global_id(2); + const int input_width = get_global_size(1); + const int in_pos = mad24(d, input_width, w); + const int output_width = input_width / block_size; + + const int out_h = h / block_size; + const int offset_h = h % block_size; + const int out_w = w / block_size; + const int offset_w = w % block_size; + const int offset_d = (offset_h * block_size + offset_w) * input_depth; + const int out_d = d + offset_d; + const int out_pos = mad24(out_d, output_width, out_w); + + DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, h)); + + WRITE_IMAGET(output, (int2)(out_pos, out_h), in_data); +} diff --git a/mace/kernels/opencl/cl/space_to_depth.cl b/mace/kernels/opencl/cl/space_to_depth.cl deleted file mode 100644 index b54ee2954546208aa3360e12aed6d49410a79c1a..0000000000000000000000000000000000000000 --- a/mace/kernels/opencl/cl/space_to_depth.cl +++ /dev/null @@ -1,25 +0,0 @@ -#include - -__kernel void space_to_depth(__read_only image2d_t input, - __private const int block_size, - __private const int input_depth, - __write_only image2d_t output) { - const int d = get_global_id(0); - const int w = get_global_id(1); - const int h = get_global_id(2); - const int input_width = get_global_size(1); - const int in_pos = mad24(d, input_width, w); - const int output_width = input_width / block_size; - - const int out_h = h / block_size; - const int offset_h = h % block_size; - const int out_w = w / block_size; - const int offset_w = w % block_size; - const int offset_d = (offset_h * block_size + offset_w) * input_depth; - const int out_d = d + offset_d; - const int out_pos = mad24(out_d, output_width, out_w); - - DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, h)); - - WRITE_IMAGET(output, (int2)(out_pos, out_h), in_data); -} diff --git a/mace/kernels/opencl/depth_to_space_opencl.cc b/mace/kernels/opencl/depth_to_space_opencl.cc index 322a7d80b32ca4bdfcc58913cb11d6b65114fb29..23347c39dde1df961e79ddf5e6581ee29bd54151 100644 --- a/mace/kernels/opencl/depth_to_space_opencl.cc +++ b/mace/kernels/opencl/depth_to_space_opencl.cc @@ -6,72 +6,89 @@ #include "mace/core/runtime/opencl/cl2_header.h" #include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/kernels/opencl/helper.h" -#include "mace/utils/utils.h" #include "mace/utils/tuner.h" +#include "mace/utils/utils.h" namespace mace { namespace kernels { template void DepthToSpaceOpFunctor::operator()( - const Tensor *input, - Tensor *output, - StatsFuture *future) { + const Tensor *input, Tensor *output, StatsFuture *future) { const index_t batch = input->dim(0); - const index_t input_h = input->dim(1); - const index_t input_w = input->dim(2); - const index_t input_d = input->dim(3); - - const index_t output_h = input_h * block_size_; - const index_t output_w = input_w * block_size_; - const index_t output_d = input_d / (block_size_ * block_size_); - - std::vector output_shape = {batch, output_h, output_w, output_d}; - + const index_t input_height = input->dim(1); + const index_t input_width = input->dim(2); + const index_t input_depth = input->dim(3); + + index_t output_height, output_width, output_depth; + if (d2s_) { + output_height = input_height * block_size_; + output_width = input_width * block_size_; + output_depth = input_depth / (block_size_ * block_size_); + } else { + output_height = input_height / block_size_; + output_width = input_width / block_size_; + output_depth = input_depth * block_size_ * block_size_; + } + + std::vector output_shape = {batch, output_height, output_width, + output_depth}; + std::vector image_shape; - CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape); + CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape); output->ResizeImage(output_shape, image_shape); - - const int output_depth_blocks = RoundUpDiv4(output_d); + + const int depth_blocks = + (d2s_) ? RoundUpDiv4(output_depth) : RoundUpDiv4(input_depth); + + const char *kernel_name = (d2s_) ? "depth_to_space" : "space_to_depth"; if (kernel_.get() == nullptr) { auto runtime = OpenCLRuntime::Global(); std::set built_options; - std::string kernel_name = MACE_OBFUSCATE_SYMBOL("depth_to_space"); - built_options.emplace("-Ddepth_to_space=" + kernel_name); + std::string obfuscated_kernel_name = MACE_OBFUSCATE_SYMBOL(kernel_name); + std::stringstream kernel_name_ss; + kernel_name_ss << "-D" << kernel_name << "=" << obfuscated_kernel_name; + built_options.emplace(kernel_name_ss.str()); auto dt = DataTypeToEnum::value; built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); - kernel_ = runtime->BuildKernel("depth_to_space", kernel_name, - built_options); + kernel_ = + runtime->BuildKernel("depth_to_space", kernel_name, built_options); } if (!IsVecEqual(input_shape_, input->shape())) { uint32_t idx = 0; kernel_.setArg(idx++, *(input->opencl_image())); kernel_.setArg(idx++, block_size_); - kernel_.setArg(idx++, output_depth_blocks); + kernel_.setArg(idx++, depth_blocks); kernel_.setArg(idx++, *(output->opencl_image())); - input_shape_ = input->shape(); } - - const uint32_t gws[3] = {static_cast(output_depth_blocks), - static_cast(output_w), - static_cast(output_h * batch)}; - const std::vector lws = {8, 16, 8, 1}; - std::stringstream ss; - ss << "depth_to_space_opencl_kernel_" - << output->dim(0) << "_" - << output->dim(1) << "_" - << output->dim(2) << "_" - << output->dim(3); - TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future); + + if (d2s_) { + const uint32_t gws[3] = {static_cast(depth_blocks), + static_cast(output_width), + static_cast(output_height * batch)}; + const std::vector lws = {8, 16, 8, 1}; + std::stringstream ss; + ss << "depth_to_space_opencl_kernel_" << output->dim(0) << "_" + << output->dim(1) << "_" << output->dim(2) << "_" << output->dim(3); + + TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future); + } else { + const uint32_t gws[3] = {static_cast(depth_blocks), + static_cast(input_width), + static_cast(input_height * batch)}; + const std::vector lws = {8, 16, 8, 1}; + std::stringstream ss; + ss << "space_to_depth_opencl_kernel_" << input->dim(0) << "_" + << input->dim(1) << "_" << input->dim(2) << "_" << input->dim(3); + TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future); + } } -template -struct DepthToSpaceOpFunctor; -template -struct DepthToSpaceOpFunctor; +template struct DepthToSpaceOpFunctor; +template struct DepthToSpaceOpFunctor; } // namespace kernels } // namespace mace diff --git a/mace/kernels/opencl/space_to_depth_opencl.cc b/mace/kernels/opencl/space_to_depth_opencl.cc deleted file mode 100644 index e5023104cb54442189866ceaa0c6fec322846cb9..0000000000000000000000000000000000000000 --- a/mace/kernels/opencl/space_to_depth_opencl.cc +++ /dev/null @@ -1,77 +0,0 @@ -// -// Copyright (c) 2018 XiaoMi All rights reserved. -// - -#include "mace/kernels/space_to_depth.h" -#include "mace/core/runtime/opencl/cl2_header.h" -#include "mace/core/runtime/opencl/opencl_runtime.h" -#include "mace/kernels/opencl/helper.h" -#include "mace/utils/utils.h" -#include "mace/utils/tuner.h" - -namespace mace { -namespace kernels { - -template -void SpaceToDepthOpFunctor::operator()( - const Tensor *input, - Tensor *output, - StatsFuture *future) { - const index_t batch_size = input->dim(0); - const index_t input_height = input->dim(1); - const index_t input_width = input->dim(2); - const index_t input_depth = input->dim(3); - - const index_t output_height = input_height / block_size_; - const index_t output_width = input_width / block_size_; - const index_t output_depth = input_depth * block_size_ * block_size_; - - std::vector output_shape = {batch_size, output_height, output_width, output_depth}; - - std::vector image_shape; - CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape); - output->ResizeImage(output_shape, image_shape); - - const int input_depth_blocks = RoundUpDiv4(input_depth); - - if (kernel_.get() == nullptr) { - auto runtime = OpenCLRuntime::Global(); - std::set built_options; - std::string kernel_name = MACE_OBFUSCATE_SYMBOL("space_to_depth"); - built_options.emplace("-Dspace_to_depth=" + kernel_name); - auto dt = DataTypeToEnum::value; - built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); - built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); - kernel_ = runtime->BuildKernel("space_to_depth", kernel_name, - built_options); - } - if (!IsVecEqual(input_shape_, input->shape())) { - uint32_t idx = 0; - kernel_.setArg(idx++, *(input->opencl_image())); - kernel_.setArg(idx++, block_size_); - kernel_.setArg(idx++, input_depth_blocks); - kernel_.setArg(idx++, *(output->opencl_image())); - - input_shape_ = input->shape(); - } - - const uint32_t gws[3] = {static_cast(input_depth_blocks), - static_cast(input_width), - static_cast(input_height * batch_size)}; - const std::vector lws = {8, 16, 8, 1}; - std::stringstream ss; - ss << "space_to_depth_opencl_kernel_" - << input->dim(0) << "_" - << input->dim(1) << "_" - << input->dim(2) << "_" - << input->dim(3); - TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future); -} - -template -struct SpaceToDepthOpFunctor; -template -struct SpaceToDepthOpFunctor; - -} // namespace kernels -} // namespace mace diff --git a/mace/kernels/space_to_depth.h b/mace/kernels/space_to_depth.h deleted file mode 100644 index b3125901549cf24f35d3d1214848b4ccabaa6229..0000000000000000000000000000000000000000 --- a/mace/kernels/space_to_depth.h +++ /dev/null @@ -1,76 +0,0 @@ -// -// Created by liutuo on 18-3-20. -// - -#ifndef MACE_KERNELS_SPACE_TO_DEPTH_H -#define MACE_KERNELS_SPACE_TO_DEPTH_H - -#include "mace/core/future.h" -#include "mace/core/tensor.h" - -namespace mace { -namespace kernels { - -template -struct SpaceToDepthOpFunctor { - explicit SpaceToDepthOpFunctor(const int block_size) : block_size_(block_size) {} - void operator()(const Tensor *input, - Tensor *output, - StatsFuture *future) { - - const int batch_size = input->dim(0); - const int input_height = input->dim(1); - const int input_width = input->dim(2); - const int input_depth = input->dim(3); - - const index_t output_depth = input_depth * block_size_ * block_size_; - const index_t output_width = input_width / block_size_; - const index_t output_height = input_height / block_size_; - - std::vector output_shape = {batch_size, output_height, output_width, output_depth}; - - output->Resize(output_shape); - - Tensor::MappingGuard logits_guard(input); - Tensor::MappingGuard output_guard(output); - const T *input_ptr = input->data(); - T *output_ptr = output->mutable_data(); - -#pragma omp parallel for - for (int b = 0; b < batch_size; ++b) { - for (int h = 0; h < input_height; ++h) { - const int out_h = h / block_size_; - const int offset_h = (h % block_size_); - for (int w = 0; w < input_width; ++w) { - const int out_w = w/ block_size_; - const int offset_w = (w % block_size_); - const int offset_d = (offset_h * block_size_ + offset_w) * input_depth; - for (int d = 0; d < input_depth; ++d) { - const int out_d = d + offset_d; - const int o_index = ((b * output_height + out_h) * output_width + out_w) * output_depth + out_d; - const int i_index = ((b * input_height + h) * input_width + w) * input_depth + d; - output_ptr[o_index] = input_ptr[i_index]; - } - } - } - } - - } - const int block_size_; -}; - -template -struct SpaceToDepthOpFunctor { - - SpaceToDepthOpFunctor(const int block_size) : block_size_(block_size) {} - void operator()(const Tensor *input, Tensor *output, StatsFuture *future); - - cl::Kernel kernel_; - const int block_size_; - std::vector input_shape_; -}; - -} // namespace kernels -} // namespace mace - -#endif //MACE_KERNELS_SPACE_TO_DEPTH_H diff --git a/mace/ops/depth_to_space.cc b/mace/ops/depth_to_space.cc index cfea8e8dbf6c3ce5e74cdb152fbb497e876d96b9..a8c4ef55bdef9dfe2c4290f7cf4e3215a852e6fb 100644 --- a/mace/ops/depth_to_space.cc +++ b/mace/ops/depth_to_space.cc @@ -19,13 +19,12 @@ void Register_DepthToSpace(OperatorRegistry *op_registry) { .TypeConstraint("T") .Build(), DepthToSpaceOp); - + REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace") .Device(DeviceType::OPENCL) .TypeConstraint("T") .Build(), DepthToSpaceOp); - } } // namespace ops diff --git a/mace/ops/depth_to_space.h b/mace/ops/depth_to_space.h index 979a08ede20911d5bc5acd14eb3d61b56f3cd0c4..78ff39191943f1cc7c215e219fcdec607d3e6718 100644 --- a/mace/ops/depth_to_space.h +++ b/mace/ops/depth_to_space.h @@ -16,33 +16,35 @@ namespace ops { template class DepthToSpaceOp : public Operator { - public: + public: DepthToSpaceOp(const OperatorDef &op_def, Workspace *ws) : Operator(op_def, ws), - functor_(OperatorBase::GetSingleArgument("block_size", 1)) {} + functor_(OperatorBase::GetSingleArgument("block_size", 1), true) {} bool Run(StatsFuture *future) override { - const Tensor *input = this->Input(INPUT); - Tensor *output = this->Output(OUTPUT); - MACE_CHECK(input->dim_size() == 4, "input dim should be 4"); - - const int block_size = OperatorBase::GetSingleArgument("block_size", 1); - - int input_depth = input->dim(3); - MACE_CHECK(input_depth % (block_size * block_size) == 0, - "input depth should be dividable by block_size * block_size", - input->dim(3)); - functor_(input, output, future); - return true; + const Tensor *input = this->Input(INPUT); + Tensor *output = this->Output(OUTPUT); + MACE_CHECK(input->dim_size() == 4, "input dim should be 4"); + + const int block_size = + OperatorBase::GetSingleArgument("block_size", 1); + + int input_depth = input->dim(3); + MACE_CHECK(input_depth % (block_size * block_size) == 0, + "input depth should be dividable by block_size * block_size", + input->dim(3)); + MACE_CHECK((input_depth % 4) == 0, + "input channel should be dividable by 4"); + functor_(input, output, future); + return true; } - - protected: - OP_INPUT_TAGS(INPUT); - OP_OUTPUT_TAGS(OUTPUT); - - private: - kernels::DepthToSpaceOpFunctor functor_; + protected: + OP_INPUT_TAGS(INPUT); + OP_OUTPUT_TAGS(OUTPUT); + + private: + kernels::DepthToSpaceOpFunctor functor_; }; } // namespace ops diff --git a/mace/ops/depth_to_space_benchmark.cc b/mace/ops/depth_to_space_benchmark.cc index beb1cc602a9e497a0f1473a7998f2fa3c4acdbd5..c90a8bd81c278dc5dfc3a2470097234c6dbb39f6 100644 --- a/mace/ops/depth_to_space_benchmark.cc +++ b/mace/ops/depth_to_space_benchmark.cc @@ -50,14 +50,14 @@ static void DepthToSpace( } #define BM_DEPTH_TO_SPACE_MACRO(N, C, H, W, G, TYPE, DEVICE) \ - static void \ + static void \ BM_DEPTH_TO_SPACE_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE( \ - int iters) { \ - const int64_t tot = static_cast(iters) * N * C * H * W; \ - mace::testing::MaccProcessed(tot); \ - mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ - DepthToSpace(iters, N, C, H, W, G); \ - } \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * C * H * W; \ + mace::testing::MaccProcessed(tot); \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + DepthToSpace(iters, N, C, H, W, G); \ + } \ BENCHMARK(BM_DEPTH_TO_SPACE_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE) #define BM_DEPTH_TO_SPACE(N, C, H, W, G) \ diff --git a/mace/ops/depth_to_space_test.cc b/mace/ops/depth_to_space_test.cc index 6ad830e68c8e90cc186c6c1352434fb9a8ef20b3..ba31174d5362001d5484bec51130a0a0b1f3c018 100644 --- a/mace/ops/depth_to_space_test.cc +++ b/mace/ops/depth_to_space_test.cc @@ -9,69 +9,169 @@ namespace mace { namespace ops { namespace test { -class DepthToSpaceOpTest : public OpsTestBase {}; - -TEST_F(DepthToSpaceOpTest, C8G4_CPU) { - // Construct graph +template +void RunDepthToSpace(const bool d2s, + const std::vector &input_shape, + const std::vector &input_data, + const int block_size, + const std::vector &expected_shape, + const std::vector &expected_data) { OpsTestNet net; - OpDefBuilder("DepthToSpace", "DepthToSpaceTest") - .Input("Input") - .Output("Output") - .AddIntArg("block_size", 2) - .Finalize(net.NewOperatorDef()); - - // Add input data - net.AddInputFromArray( - "Input", {1, 1, 2, 16}, - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); - + net.AddInputFromArray("Input", input_shape, input_data); + const char *ops_name = (d2s) ? "DepthToSpace" : "SpaceToDepth"; + const char *ops_test_name = (d2s) ? "DepthToSpaceTest" : "SpaceToDepthTest"; + // Construct graph + if (D == DeviceType::CPU) { + OpDefBuilder(ops_name, ops_test_name) + .Input("Input") + .Output("Output") + .AddIntArg("block_size", block_size) + .Finalize(net.NewOperatorDef()); + + } else { + BufferToImage(&net, "Input", "InputImage", + kernels::BufferType::IN_OUT_CHANNEL); + OpDefBuilder(ops_name, ops_test_name) + .Input("InputImage") + .Output("OutputImage") + .AddIntArg("block_size", block_size) + .Finalize(net.NewOperatorDef()); + } // Run - net.RunOp(); - - // Check - auto expected = CreateTensor( - {1, 2, 4, 4}, - {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, - 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}); + net.RunOp(D); + if (D == DeviceType::OPENCL) { + ImageToBuffer(&net, "OutputImage", "Output", + kernels::BufferType::IN_OUT_CHANNEL); + } + auto expected = CreateTensor(expected_shape, expected_data); ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); } -TEST_F(DepthToSpaceOpTest, C16G4_OPENCL) { - // Construct graph - OpsTestNet net; +class SpaceToDepthOpTest : public OpsTestBase {}; - // Add input data - net.AddInputFromArray( - "Input", {1, 1, 2, 16}, +TEST_F(SpaceToDepthOpTest, Input2x4x4_B2_CPU) { + RunDepthToSpace(false, {1, 2, 4, 4}, + {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, + 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}, + 2, + {1, 1, 2, 16}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); - BufferToImage(&net, "Input", "InputImage", - kernels::BufferType::IN_OUT_CHANNEL); + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); +} - OpDefBuilder("DepthToSpace", "DepthToSpaceTest") - .Input("InputImage") - .Output("OutputImage") - .AddIntArg("block_size", 2) - .Finalize(net.NewOperatorDef()); +TEST_F(SpaceToDepthOpTest, Input2x4x4_B2_OPENCL) { + RunDepthToSpace(false, {1, 2, 4, 4}, + {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, + 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}, + 2, + {1, 1, 2, 16}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); +} - // Run - net.RunOp(DeviceType::OPENCL); +TEST_F(SpaceToDepthOpTest, Input2x2x4_B2_CPU) { + RunDepthToSpace(false, {1, 2, 2, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}, + 2, + {1, 1, 1, 16}, + {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}); +} + +TEST_F(SpaceToDepthOpTest, Input4x4x1_B2_OPENCL) { + RunDepthToSpace(false, {1, 2, 2, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}, + 2, + {1, 1, 1, 16}, + {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}); +} - // Transfer output - ImageToBuffer(&net, "OutputImage", "Output", - kernels::BufferType::IN_OUT_CHANNEL); +class DepthToSpaceOpTest : public OpsTestBase {}; - // Check - auto expected = CreateTensor( +TEST_F(DepthToSpaceOpTest, Input1x2x16_B2_CPU) { + RunDepthToSpace(true, {1, 1, 2, 16}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, + 2, {1, 2, 4, 4}, - {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, - 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}); + {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, + 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}); +} - ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); +TEST_F(DepthToSpaceOpTest, Input1x2x16_B2_OPENCL) { + RunDepthToSpace(true, {1, 1, 2, 16}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}, + 2, + {1, 2, 4, 4}, + {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, + 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}); } +TEST_F(DepthToSpaceOpTest, Input1x1x16_B2_CPU) { + RunDepthToSpace(true, {1, 1, 1, 16}, + {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}, + 2, + {1, 2, 2, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}); +} + +TEST_F(DepthToSpaceOpTest, Input1x1x16_B2_OPENCL) { + RunDepthToSpace(true, {1, 1, 1, 16}, + {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}, + 2, + {1, 2, 2, 4}, + {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12, 13, 14, 15, 16}); +} + +/* +TEST_F(DepthToSpaceOpTest, Input2x2x3_B2_CPU) { + + RunDepthToSpace({1, 2, 2, 3}, + {1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12}, + 2, + {1, 1, 1, 12}, + {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12}); +} + +TEST_F(DepthToSpaceOpTest, Input2x2x3_B2_OPENCL) { + RunDepthToSpace({1, 2, 2, 6}, + {1, 2, 3, 4, 5, 6, + 7, 8, 9, 10, 11, 12 + }, + 2, + {1, 1, 1, 12}, + {1, 2, 3, 4, 5, 6, 7, 8, + 9, 10, 11, 12}); +} + +TEST_F(DepthToSpaceOpTest, Input2x2x2_B2_CPU) { + + RunDepthToSpace({1, 2, 2, 2}, + {1, 10, 2, 20, 3, 30, 4, 40}, + 2, + {1, 1, 1, 8}, + {1, 10, 2, 20, 3, 30, 4, 40}); +} + +TEST_F(DepthToSpaceOpTest, Input2x2x2_B2_OPENCL) { + + RunDepthToSpace({1, 2, 2, 2}, + {1, 10, 2, 20, 3, 30, 4, 40}, + 2, + {1, 1, 1, 8}, + {1, 10, 2, 20, 3, 30, 4, 40}); +}*/ } // namespace test } // namespace ops } // namespace mace diff --git a/mace/ops/space_to_depth.cc b/mace/ops/space_to_depth.cc index 6963007513fd4c9890c75f454f26e20c116b9c2d..55f1a13a4f80b5a88c1f318733f11db1abf2a872 100644 --- a/mace/ops/space_to_depth.cc +++ b/mace/ops/space_to_depth.cc @@ -19,13 +19,12 @@ void Register_SpaceToDepth(OperatorRegistry *op_registry) { .TypeConstraint("T") .Build(), SpaceToDepthOp); - + REGISTER_OPERATOR(op_registry, OpKeyBuilder("SpaceToDepth") .Device(DeviceType::OPENCL) .TypeConstraint("T") .Build(), SpaceToDepthOp); - } } // namespace ops diff --git a/mace/ops/space_to_depth.h b/mace/ops/space_to_depth.h index b21eeb492b8345fa5d787c24272363b4b11f149a..517d8ccc8f8938214aefc50cfea091133d455466 100644 --- a/mace/ops/space_to_depth.h +++ b/mace/ops/space_to_depth.h @@ -9,42 +9,44 @@ #include #include "mace/core/operator.h" -#include "mace/kernels/space_to_depth.h" +#include "mace/kernels/depth_to_space.h" namespace mace { namespace ops { template class SpaceToDepthOp : public Operator { - public: + public: SpaceToDepthOp(const OperatorDef &op_def, Workspace *ws) : Operator(op_def, ws), - functor_(OperatorBase::GetSingleArgument("block_size", 1)) {} + functor_(OperatorBase::GetSingleArgument("block_size", 1), false) { + } bool Run(StatsFuture *future) override { - const Tensor *input = this->Input(INPUT); - Tensor *output = this->Output(OUTPUT); - MACE_CHECK(input->dim_size() == 4, "input dim should be 4"); - - const int block_size = OperatorBase::GetSingleArgument("block_size", 1); - - const int input_height = input->dim(1); - const int input_width = input->dim(2); - const int input_depth = input->dim(3); - MACE_CHECK((input_width % block_size == 0) && (input_height % block_size == 0), - "input width and height should be dividable by block_size", - input->dim(3)); - functor_(input, output, future); - return true; + const Tensor *input = this->Input(INPUT); + Tensor *output = this->Output(OUTPUT); + MACE_CHECK(input->dim_size() == 4, "input dim should be 4"); + const int block_size = + OperatorBase::GetSingleArgument("block_size", 1); + const int input_height = input->dim(1); + const int input_width = input->dim(2); + const int input_depth = input->dim(3); + MACE_CHECK((input_depth % 4) == 0, + "input channel should be dividable by 4"); + MACE_CHECK( + (input_width%block_size == 0)&&(input_height%block_size == 0), + "input width and height should be dividable by block_size", + input->dim(3)); + functor_(input, output, future); + return true; } - - protected: - OP_INPUT_TAGS(INPUT); - OP_OUTPUT_TAGS(OUTPUT); - - private: - kernels::SpaceToDepthOpFunctor functor_; + protected: + OP_INPUT_TAGS(INPUT); + OP_OUTPUT_TAGS(OUTPUT); + + private: + kernels::DepthToSpaceOpFunctor functor_; }; } // namespace ops diff --git a/mace/ops/space_to_depth_benchmark.cc b/mace/ops/space_to_depth_benchmark.cc index f4d78898881f8cae8a59e8d51061e081726be308..c97028c4c85cd792769f4fd69fc19ffe9a1280c0 100644 --- a/mace/ops/space_to_depth_benchmark.cc +++ b/mace/ops/space_to_depth_benchmark.cc @@ -50,14 +50,14 @@ static void SpaceToDepth( } #define BM_SPACE_TO_DEPTH_MACRO(N, C, H, W, G, TYPE, DEVICE) \ - static void \ + static void \ BM_SPACE_TO_DEPTH_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE( \ - int iters) { \ - const int64_t tot = static_cast(iters) * N * C * H * W; \ - mace::testing::MaccProcessed(tot); \ - mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ - SpaceToDepth(iters, N, C, H, W, G); \ - } \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * C * H * W; \ + mace::testing::MaccProcessed(tot); \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + SpaceToDepth(iters, N, C, H, W, G); \ + } \ BENCHMARK(BM_SPACE_TO_DEPTH_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE) #define BM_SPACE_TO_DEPTH(N, C, H, W, G) \ diff --git a/mace/ops/space_to_depth_test.cc b/mace/ops/space_to_depth_test.cc deleted file mode 100644 index 37d020a9c8c68f7f85b574560225a456a2491eb6..0000000000000000000000000000000000000000 --- a/mace/ops/space_to_depth_test.cc +++ /dev/null @@ -1,77 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// - -#include "mace/core/operator.h" -#include "mace/ops/ops_test_util.h" - -namespace mace { -namespace ops { -namespace test { - -class SpaceToDepthOpTest : public OpsTestBase {}; - -TEST_F(SpaceToDepthOpTest, C8G4_CPU) { - // Construct graph - OpsTestNet net; - OpDefBuilder("SpaceToDepth", "SpaceToDepthTest") - .Input("Input") - .Output("Output") - .AddIntArg("block_size", 2) - .Finalize(net.NewOperatorDef()); - - // Add input data - net.AddInputFromArray( - "Input", {1, 2, 4, 4}, - {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, - 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}); - - // Run - net.RunOp(); - - // Check - auto expected = CreateTensor( - {1, 1, 2, 16}, - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); - - ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); -} - -TEST_F(SpaceToDepthOpTest, C16G4_OPENCL) { - // Construct graph - OpsTestNet net; - - // Add input data - net.AddInputFromArray( - "Input", {1, 2, 4, 4}, - {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, - 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}); - BufferToImage(&net, "Input", "InputImage", - kernels::BufferType::IN_OUT_CHANNEL); - - OpDefBuilder("SpaceToDepth", "SpaceToDepthTest") - .Input("InputImage") - .Output("OutputImage") - .AddIntArg("block_size", 2) - .Finalize(net.NewOperatorDef()); - - // Run - net.RunOp(DeviceType::OPENCL); - - // Transfer output - ImageToBuffer(&net, "OutputImage", "Output", - kernels::BufferType::IN_OUT_CHANNEL); - - // Check - auto expected = CreateTensor( - {1, 1, 2, 16}, - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, - 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); - - ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); -} - -} // namespace test -} // namespace ops -} // namespace mace