提交 3501e1cc 编写于 作者: U Unknown 提交者: liutuo

combine space2depth and depth2space

上级 9f7ab38b
// //
// Created by liutuo on 18-3-20. // Copyright (c) 2017 XiaoMi All rights reserved.
// //
#ifndef MACE_KERNELS_DEPTH_TO_SPACE_H #ifndef MACE_KERNELS_DEPTH_TO_SPACE_H_
#define MACE_KERNELS_DEPTH_TO_SPACE_H #define MACE_KERNELS_DEPTH_TO_SPACE_H_
#include <vector>
#include "mace/core/future.h" #include "mace/core/future.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
#include "mace/public/mace.h"
namespace mace { namespace mace {
namespace kernels { namespace kernels {
template <DeviceType D, typename T> template <DeviceType D, typename T>
struct DepthToSpaceOpFunctor { struct DepthToSpaceOpFunctor {
explicit DepthToSpaceOpFunctor(const int block_size) : block_size_(block_size) {} explicit DepthToSpaceOpFunctor(const int block_size, bool d2s)
void operator()(const Tensor *input, : block_size_(block_size), d2s_(d2s) {}
Tensor *output, void operator()(const Tensor *input, Tensor *output, StatsFuture *future) {
StatsFuture *future) {
std::vector<index_t> output_shape(input->shape());
const int batch_size = input->dim(0); const int batch_size = input->dim(0);
const int input_height = input->dim(1); const int input_height = input->dim(1);
const int input_width = input->dim(2); const int input_width = input->dim(2);
const int input_depth = input->dim(3); const int input_depth = input->dim(3);
const index_t output_depth = input_depth / (block_size_ * block_size_); index_t output_depth, output_width, output_height;
const index_t output_width = input_width * block_size_;
const index_t output_height = input_height * block_size_; if (d2s_) {
output_shape[0] = batch_size; output_depth = input_depth / (block_size_ * block_size_);
output_shape[1] = output_height; output_width = input_width * block_size_;
output_shape[2] = output_width; output_height = input_height * block_size_;
output_shape[3] = output_depth; } else {
output_depth = input_depth * block_size_ * block_size_;
output_width = input_width / block_size_;
output_height = input_height / block_size_;
}
std::vector<index_t> output_shape = {batch_size, output_height,
output_width, output_depth};
output->Resize(output_shape); output->Resize(output_shape);
Tensor::MappingGuard logits_guard(input); Tensor::MappingGuard logits_guard(input);
...@@ -38,41 +45,75 @@ struct DepthToSpaceOpFunctor { ...@@ -38,41 +45,75 @@ struct DepthToSpaceOpFunctor {
const T *input_ptr = input->data<T>(); const T *input_ptr = input->data<T>();
T *output_ptr = output->mutable_data<T>(); T *output_ptr = output->mutable_data<T>();
if (d2s_) {
#pragma omp parallel for #pragma omp parallel for
for (int b = 0; b < batch_size; ++b) { for (int b = 0; b < batch_size; ++b) {
for (int h = 0; h < output_height; ++h) { for (int h = 0; h < output_height; ++h) {
const int in_h = h / block_size_; const int in_h = h / block_size_;
const int offset_h = (h % block_size_); const int offset_h = (h % block_size_);
for (int w = 0; w < output_width; ++w) { for (int w = 0; w < output_width; ++w) {
const int in_w = w / block_size_; const int in_w = w / block_size_;
const int offset_w = w % block_size_; const int offset_w = w % block_size_;
const int offset_d = (offset_h * block_size_ + offset_w) * output_depth; const int offset_d =
for (int d = 0; d < output_depth; ++d) { (offset_h * block_size_ + offset_w) * output_depth;
const int in_d = d + offset_d; for (int d = 0; d < output_depth; ++d) {
const int o_index = ((b * output_height + h) * output_width + w) * output_depth + d; const int in_d = d + offset_d;
const int i_index = ((b * input_height + in_h) * input_width + in_w) * input_depth + in_d; const int o_index =
output_ptr[o_index] = input_ptr[i_index]; ((b * output_height + h) * output_width + w) * output_depth +
d;
const int i_index =
((b * input_height + in_h) * input_width + in_w) *
input_depth +
in_d;
output_ptr[o_index] = input_ptr[i_index];
}
}
}
}
} else {
#pragma omp parallel for
for (int b = 0; b < batch_size; ++b) {
for (int h = 0; h < input_height; ++h) {
const int out_h = h / block_size_;
const int offset_h = (h % block_size_);
for (int w = 0; w < input_width; ++w) {
const int out_w = w / block_size_;
const int offset_w = (w % block_size_);
const int offset_d =
(offset_h * block_size_ + offset_w) * input_depth;
for (int d = 0; d < input_depth; ++d) {
const int out_d = d + offset_d;
const int o_index =
((b * output_height + out_h) * output_width + out_w) *
output_depth +
out_d;
const int i_index =
((b * input_height + h) * input_width + w) * input_depth + d;
output_ptr[o_index] = input_ptr[i_index];
}
} }
} }
} }
} }
} }
const int block_size_; const int block_size_;
bool d2s_;
}; };
template <typename T> template <typename T>
struct DepthToSpaceOpFunctor<DeviceType::OPENCL, T> { struct DepthToSpaceOpFunctor<DeviceType::OPENCL, T> {
DepthToSpaceOpFunctor(const int block_size, bool d2s)
DepthToSpaceOpFunctor(const int block_size) : block_size_(block_size) {} : block_size_(block_size), d2s_(d2s) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future); void operator()(const Tensor *input, Tensor *output, StatsFuture *future);
cl::Kernel kernel_; cl::Kernel kernel_;
const int block_size_; const int block_size_;
bool d2s_;
std::vector<index_t> input_shape_; std::vector<index_t> input_shape_;
}; };
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
#endif //MACE_KERNELS_DEPTH_TO_SPACE_H #endif // MACE_KERNELS_DEPTH_TO_SPACE_H_
#include <common.h> #include <common.h>
__kernel void depth_to_space(__read_only image2d_t input, __kernel void depth_to_space(__read_only image2d_t input,
__private const int block_size, __private const int block_size,
__private const int output_depth, __private const int output_depth,
__write_only image2d_t output) { __write_only image2d_t output) {
const int out_d = get_global_id(0); const int out_d = get_global_id(0);
const int out_w = get_global_id(1); const int out_w = get_global_id(1);
const int out_h = get_global_id(2); const int out_h = get_global_id(2);
const int output_width = get_global_size(1); const int output_width = get_global_size(1);
const int out_pos = mad24(out_d, output_width, out_w); const int out_pos = mad24(out_d, output_width, out_w);
const int input_width = output_width / block_size; const int input_width = output_width / block_size;
const int in_h = out_h / block_size; const int in_h = out_h / block_size;
const int offset_h = out_h % block_size; const int offset_h = out_h % block_size;
const int in_w = out_w / block_size; const int in_w = out_w / block_size;
const int offset_w = out_w % block_size; const int offset_w = out_w % block_size;
const int offset_d = (offset_h * block_size + offset_w) * output_depth; const int offset_d = (offset_h * block_size + offset_w) * output_depth;
const int in_d = out_d + offset_d; const int in_d = out_d + offset_d;
const int in_pos = mad24(in_d, input_width, in_w); const int in_pos = mad24(in_d, input_width, in_w);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, in_h)); DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, in_h));
WRITE_IMAGET(output, (int2)(out_pos, out_h), in_data); WRITE_IMAGET(output, (int2)(out_pos, out_h), in_data);
} }
__kernel void space_to_depth(__read_only image2d_t input,
__private const int block_size,
__private const int input_depth,
__write_only image2d_t output) {
const int d = get_global_id(0);
const int w = get_global_id(1);
const int h = get_global_id(2);
const int input_width = get_global_size(1);
const int in_pos = mad24(d, input_width, w);
const int output_width = input_width / block_size;
const int out_h = h / block_size;
const int offset_h = h % block_size;
const int out_w = w / block_size;
const int offset_w = w % block_size;
const int offset_d = (offset_h * block_size + offset_w) * input_depth;
const int out_d = d + offset_d;
const int out_pos = mad24(out_d, output_width, out_w);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, h));
WRITE_IMAGET(output, (int2)(out_pos, out_h), in_data);
}
#include <common.h>
__kernel void space_to_depth(__read_only image2d_t input,
__private const int block_size,
__private const int input_depth,
__write_only image2d_t output) {
const int d = get_global_id(0);
const int w = get_global_id(1);
const int h = get_global_id(2);
const int input_width = get_global_size(1);
const int in_pos = mad24(d, input_width, w);
const int output_width = input_width / block_size;
const int out_h = h / block_size;
const int offset_h = h % block_size;
const int out_w = w / block_size;
const int offset_w = w % block_size;
const int offset_d = (offset_h * block_size + offset_w) * input_depth;
const int out_d = d + offset_d;
const int out_pos = mad24(out_d, output_width, out_w);
DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, h));
WRITE_IMAGET(output, (int2)(out_pos, out_h), in_data);
}
...@@ -6,72 +6,89 @@ ...@@ -6,72 +6,89 @@
#include "mace/core/runtime/opencl/cl2_header.h" #include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h" #include "mace/kernels/opencl/helper.h"
#include "mace/utils/utils.h"
#include "mace/utils/tuner.h" #include "mace/utils/tuner.h"
#include "mace/utils/utils.h"
namespace mace { namespace mace {
namespace kernels { namespace kernels {
template <typename T> template <typename T>
void DepthToSpaceOpFunctor<DeviceType::OPENCL, T>::operator()( void DepthToSpaceOpFunctor<DeviceType::OPENCL, T>::operator()(
const Tensor *input, const Tensor *input, Tensor *output, StatsFuture *future) {
Tensor *output,
StatsFuture *future) {
const index_t batch = input->dim(0); const index_t batch = input->dim(0);
const index_t input_h = input->dim(1); const index_t input_height = input->dim(1);
const index_t input_w = input->dim(2); const index_t input_width = input->dim(2);
const index_t input_d = input->dim(3); const index_t input_depth = input->dim(3);
const index_t output_h = input_h * block_size_; index_t output_height, output_width, output_depth;
const index_t output_w = input_w * block_size_; if (d2s_) {
const index_t output_d = input_d / (block_size_ * block_size_); output_height = input_height * block_size_;
output_width = input_width * block_size_;
std::vector<index_t> output_shape = {batch, output_h, output_w, output_d}; output_depth = input_depth / (block_size_ * block_size_);
} else {
output_height = input_height / block_size_;
output_width = input_width / block_size_;
output_depth = input_depth * block_size_ * block_size_;
}
std::vector<index_t> output_shape = {batch, output_height, output_width,
output_depth};
std::vector<size_t> image_shape; std::vector<size_t> image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape); CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape);
output->ResizeImage(output_shape, image_shape); output->ResizeImage(output_shape, image_shape);
const int output_depth_blocks = RoundUpDiv4(output_d); const int depth_blocks =
(d2s_) ? RoundUpDiv4(output_depth) : RoundUpDiv4(input_depth);
const char *kernel_name = (d2s_) ? "depth_to_space" : "space_to_depth";
if (kernel_.get() == nullptr) { if (kernel_.get() == nullptr) {
auto runtime = OpenCLRuntime::Global(); auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options; std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("depth_to_space"); std::string obfuscated_kernel_name = MACE_OBFUSCATE_SYMBOL(kernel_name);
built_options.emplace("-Ddepth_to_space=" + kernel_name); std::stringstream kernel_name_ss;
kernel_name_ss << "-D" << kernel_name << "=" << obfuscated_kernel_name;
built_options.emplace(kernel_name_ss.str());
auto dt = DataTypeToEnum<T>::value; auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
kernel_ = runtime->BuildKernel("depth_to_space", kernel_name, kernel_ =
built_options); runtime->BuildKernel("depth_to_space", kernel_name, built_options);
} }
if (!IsVecEqual(input_shape_, input->shape())) { if (!IsVecEqual(input_shape_, input->shape())) {
uint32_t idx = 0; uint32_t idx = 0;
kernel_.setArg(idx++, *(input->opencl_image())); kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, block_size_); kernel_.setArg(idx++, block_size_);
kernel_.setArg(idx++, output_depth_blocks); kernel_.setArg(idx++, depth_blocks);
kernel_.setArg(idx++, *(output->opencl_image())); kernel_.setArg(idx++, *(output->opencl_image()));
input_shape_ = input->shape(); input_shape_ = input->shape();
} }
const uint32_t gws[3] = {static_cast<uint32_t>(output_depth_blocks), if (d2s_) {
static_cast<uint32_t>(output_w), const uint32_t gws[3] = {static_cast<uint32_t>(depth_blocks),
static_cast<uint32_t>(output_h * batch)}; static_cast<uint32_t>(output_width),
const std::vector<uint32_t> lws = {8, 16, 8, 1}; static_cast<uint32_t>(output_height * batch)};
std::stringstream ss; const std::vector<uint32_t> lws = {8, 16, 8, 1};
ss << "depth_to_space_opencl_kernel_" std::stringstream ss;
<< output->dim(0) << "_" ss << "depth_to_space_opencl_kernel_" << output->dim(0) << "_"
<< output->dim(1) << "_" << output->dim(1) << "_" << output->dim(2) << "_" << output->dim(3);
<< output->dim(2) << "_"
<< output->dim(3); TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future);
TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future); } else {
const uint32_t gws[3] = {static_cast<uint32_t>(depth_blocks),
static_cast<uint32_t>(input_width),
static_cast<uint32_t>(input_height * batch)};
const std::vector<uint32_t> lws = {8, 16, 8, 1};
std::stringstream ss;
ss << "space_to_depth_opencl_kernel_" << input->dim(0) << "_"
<< input->dim(1) << "_" << input->dim(2) << "_" << input->dim(3);
TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future);
}
} }
template template struct DepthToSpaceOpFunctor<DeviceType::OPENCL, float>;
struct DepthToSpaceOpFunctor<DeviceType::OPENCL, float>; template struct DepthToSpaceOpFunctor<DeviceType::OPENCL, half>;
template
struct DepthToSpaceOpFunctor<DeviceType::OPENCL, half>;
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
//
// Copyright (c) 2018 XiaoMi All rights reserved.
//
#include "mace/kernels/space_to_depth.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/utils.h"
#include "mace/utils/tuner.h"
namespace mace {
namespace kernels {
template <typename T>
void SpaceToDepthOpFunctor<DeviceType::OPENCL, T>::operator()(
const Tensor *input,
Tensor *output,
StatsFuture *future) {
const index_t batch_size = input->dim(0);
const index_t input_height = input->dim(1);
const index_t input_width = input->dim(2);
const index_t input_depth = input->dim(3);
const index_t output_height = input_height / block_size_;
const index_t output_width = input_width / block_size_;
const index_t output_depth = input_depth * block_size_ * block_size_;
std::vector<index_t> output_shape = {batch_size, output_height, output_width, output_depth};
std::vector<size_t> image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT_CHANNEL, &image_shape);
output->ResizeImage(output_shape, image_shape);
const int input_depth_blocks = RoundUpDiv4(input_depth);
if (kernel_.get() == nullptr) {
auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("space_to_depth");
built_options.emplace("-Dspace_to_depth=" + kernel_name);
auto dt = DataTypeToEnum<T>::value;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
kernel_ = runtime->BuildKernel("space_to_depth", kernel_name,
built_options);
}
if (!IsVecEqual(input_shape_, input->shape())) {
uint32_t idx = 0;
kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, block_size_);
kernel_.setArg(idx++, input_depth_blocks);
kernel_.setArg(idx++, *(output->opencl_image()));
input_shape_ = input->shape();
}
const uint32_t gws[3] = {static_cast<uint32_t>(input_depth_blocks),
static_cast<uint32_t>(input_width),
static_cast<uint32_t>(input_height * batch_size)};
const std::vector<uint32_t> lws = {8, 16, 8, 1};
std::stringstream ss;
ss << "space_to_depth_opencl_kernel_"
<< input->dim(0) << "_"
<< input->dim(1) << "_"
<< input->dim(2) << "_"
<< input->dim(3);
TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future);
}
template
struct SpaceToDepthOpFunctor<DeviceType::OPENCL, float>;
template
struct SpaceToDepthOpFunctor<DeviceType::OPENCL, half>;
} // namespace kernels
} // namespace mace
//
// Created by liutuo on 18-3-20.
//
#ifndef MACE_KERNELS_SPACE_TO_DEPTH_H
#define MACE_KERNELS_SPACE_TO_DEPTH_H
#include "mace/core/future.h"
#include "mace/core/tensor.h"
namespace mace {
namespace kernels {
template <DeviceType D, typename T>
struct SpaceToDepthOpFunctor {
explicit SpaceToDepthOpFunctor(const int block_size) : block_size_(block_size) {}
void operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
const int batch_size = input->dim(0);
const int input_height = input->dim(1);
const int input_width = input->dim(2);
const int input_depth = input->dim(3);
const index_t output_depth = input_depth * block_size_ * block_size_;
const index_t output_width = input_width / block_size_;
const index_t output_height = input_height / block_size_;
std::vector<index_t> output_shape = {batch_size, output_height, output_width, output_depth};
output->Resize(output_shape);
Tensor::MappingGuard logits_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_ptr = input->data<T>();
T *output_ptr = output->mutable_data<T>();
#pragma omp parallel for
for (int b = 0; b < batch_size; ++b) {
for (int h = 0; h < input_height; ++h) {
const int out_h = h / block_size_;
const int offset_h = (h % block_size_);
for (int w = 0; w < input_width; ++w) {
const int out_w = w/ block_size_;
const int offset_w = (w % block_size_);
const int offset_d = (offset_h * block_size_ + offset_w) * input_depth;
for (int d = 0; d < input_depth; ++d) {
const int out_d = d + offset_d;
const int o_index = ((b * output_height + out_h) * output_width + out_w) * output_depth + out_d;
const int i_index = ((b * input_height + h) * input_width + w) * input_depth + d;
output_ptr[o_index] = input_ptr[i_index];
}
}
}
}
}
const int block_size_;
};
template <typename T>
struct SpaceToDepthOpFunctor<DeviceType::OPENCL, T> {
SpaceToDepthOpFunctor(const int block_size) : block_size_(block_size) {}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future);
cl::Kernel kernel_;
const int block_size_;
std::vector<index_t> input_shape_;
};
} // namespace kernels
} // namespace mace
#endif //MACE_KERNELS_SPACE_TO_DEPTH_H
...@@ -19,13 +19,12 @@ void Register_DepthToSpace(OperatorRegistry *op_registry) { ...@@ -19,13 +19,12 @@ void Register_DepthToSpace(OperatorRegistry *op_registry) {
.TypeConstraint<float>("T") .TypeConstraint<float>("T")
.Build(), .Build(),
DepthToSpaceOp<DeviceType::OPENCL, float>); DepthToSpaceOp<DeviceType::OPENCL, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace") REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace")
.Device(DeviceType::OPENCL) .Device(DeviceType::OPENCL)
.TypeConstraint<half>("T") .TypeConstraint<half>("T")
.Build(), .Build(),
DepthToSpaceOp<DeviceType::OPENCL, half>); DepthToSpaceOp<DeviceType::OPENCL, half>);
} }
} // namespace ops } // namespace ops
......
...@@ -16,33 +16,35 @@ namespace ops { ...@@ -16,33 +16,35 @@ namespace ops {
template <DeviceType D, typename T> template <DeviceType D, typename T>
class DepthToSpaceOp : public Operator<D, T> { class DepthToSpaceOp : public Operator<D, T> {
public: public:
DepthToSpaceOp(const OperatorDef &op_def, Workspace *ws) DepthToSpaceOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws), : Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetSingleArgument<int>("block_size", 1)) {} functor_(OperatorBase::GetSingleArgument<int>("block_size", 1), true) {}
bool Run(StatsFuture *future) override { bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT); const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT); Tensor *output = this->Output(OUTPUT);
MACE_CHECK(input->dim_size() == 4, "input dim should be 4"); MACE_CHECK(input->dim_size() == 4, "input dim should be 4");
const int block_size = OperatorBase::GetSingleArgument<int>("block_size", 1); const int block_size =
OperatorBase::GetSingleArgument<int>("block_size", 1);
int input_depth = input->dim(3);
MACE_CHECK(input_depth % (block_size * block_size) == 0, int input_depth = input->dim(3);
"input depth should be dividable by block_size * block_size", MACE_CHECK(input_depth % (block_size * block_size) == 0,
input->dim(3)); "input depth should be dividable by block_size * block_size",
functor_(input, output, future); input->dim(3));
return true; MACE_CHECK((input_depth % 4) == 0,
"input channel should be dividable by 4");
functor_(input, output, future);
return true;
} }
protected:
OP_INPUT_TAGS(INPUT);
OP_OUTPUT_TAGS(OUTPUT);
private:
kernels::DepthToSpaceOpFunctor<D, T> functor_;
protected:
OP_INPUT_TAGS(INPUT);
OP_OUTPUT_TAGS(OUTPUT);
private:
kernels::DepthToSpaceOpFunctor<D, T> functor_;
}; };
} // namespace ops } // namespace ops
......
...@@ -50,14 +50,14 @@ static void DepthToSpace( ...@@ -50,14 +50,14 @@ static void DepthToSpace(
} }
#define BM_DEPTH_TO_SPACE_MACRO(N, C, H, W, G, TYPE, DEVICE) \ #define BM_DEPTH_TO_SPACE_MACRO(N, C, H, W, G, TYPE, DEVICE) \
static void \ static void \
BM_DEPTH_TO_SPACE_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE( \ BM_DEPTH_TO_SPACE_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE( \
int iters) { \ int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \ const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::MaccProcessed(tot); \ mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
DepthToSpace<DEVICE, TYPE>(iters, N, C, H, W, G); \ DepthToSpace<DEVICE, TYPE>(iters, N, C, H, W, G); \
} \ } \
BENCHMARK(BM_DEPTH_TO_SPACE_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE) BENCHMARK(BM_DEPTH_TO_SPACE_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE)
#define BM_DEPTH_TO_SPACE(N, C, H, W, G) \ #define BM_DEPTH_TO_SPACE(N, C, H, W, G) \
......
...@@ -9,69 +9,169 @@ namespace mace { ...@@ -9,69 +9,169 @@ namespace mace {
namespace ops { namespace ops {
namespace test { namespace test {
class DepthToSpaceOpTest : public OpsTestBase {}; template <DeviceType D>
void RunDepthToSpace(const bool d2s,
TEST_F(DepthToSpaceOpTest, C8G4_CPU) { const std::vector<index_t> &input_shape,
// Construct graph const std::vector<float> &input_data,
const int block_size,
const std::vector<index_t> &expected_shape,
const std::vector<float> &expected_data) {
OpsTestNet net; OpsTestNet net;
OpDefBuilder("DepthToSpace", "DepthToSpaceTest") net.AddInputFromArray<D, float>("Input", input_shape, input_data);
.Input("Input") const char *ops_name = (d2s) ? "DepthToSpace" : "SpaceToDepth";
.Output("Output") const char *ops_test_name = (d2s) ? "DepthToSpaceTest" : "SpaceToDepthTest";
.AddIntArg("block_size", 2) // Construct graph
.Finalize(net.NewOperatorDef()); if (D == DeviceType::CPU) {
OpDefBuilder(ops_name, ops_test_name)
// Add input data .Input("Input")
net.AddInputFromArray<DeviceType::CPU, float>( .Output("Output")
"Input", {1, 1, 2, 16}, .AddIntArg("block_size", block_size)
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, .Finalize(net.NewOperatorDef());
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31});
} else {
BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder(ops_name, ops_test_name)
.Input("InputImage")
.Output("OutputImage")
.AddIntArg("block_size", block_size)
.Finalize(net.NewOperatorDef());
}
// Run // Run
net.RunOp(); net.RunOp(D);
// Check
auto expected = CreateTensor<float>(
{1, 2, 4, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23,
8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31});
if (D == DeviceType::OPENCL) {
ImageToBuffer<DeviceType::OPENCL, float>(&net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL);
}
auto expected = CreateTensor<float>(expected_shape, expected_data);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
} }
TEST_F(DepthToSpaceOpTest, C16G4_OPENCL) { class SpaceToDepthOpTest : public OpsTestBase {};
// Construct graph
OpsTestNet net;
// Add input data TEST_F(SpaceToDepthOpTest, Input2x4x4_B2_CPU) {
net.AddInputFromArray<DeviceType::OPENCL, float>( RunDepthToSpace<DeviceType::CPU>(false, {1, 2, 4, 4},
"Input", {1, 1, 2, 16}, {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23,
8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31},
2,
{1, 1, 2, 16},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31});
BufferToImage<DeviceType::OPENCL, float>(&net, "Input", "InputImage", }
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("DepthToSpace", "DepthToSpaceTest") TEST_F(SpaceToDepthOpTest, Input2x4x4_B2_OPENCL) {
.Input("InputImage") RunDepthToSpace<DeviceType::OPENCL>(false, {1, 2, 4, 4},
.Output("OutputImage") {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23,
.AddIntArg("block_size", 2) 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31},
.Finalize(net.NewOperatorDef()); 2,
{1, 1, 2, 16},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31});
}
// Run TEST_F(SpaceToDepthOpTest, Input2x2x4_B2_CPU) {
net.RunOp(DeviceType::OPENCL); RunDepthToSpace<DeviceType::CPU>(false, {1, 2, 2, 4},
{1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16},
2,
{1, 1, 1, 16},
{1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16});
}
TEST_F(SpaceToDepthOpTest, Input4x4x1_B2_OPENCL) {
RunDepthToSpace<DeviceType::OPENCL>(false, {1, 2, 2, 4},
{1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16},
2,
{1, 1, 1, 16},
{1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16});
}
// Transfer output class DepthToSpaceOpTest : public OpsTestBase {};
ImageToBuffer<DeviceType::OPENCL, float>(&net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL);
// Check TEST_F(DepthToSpaceOpTest, Input1x2x16_B2_CPU) {
auto expected = CreateTensor<float>( RunDepthToSpace<DeviceType::CPU>(true, {1, 1, 2, 16},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31},
2,
{1, 2, 4, 4}, {1, 2, 4, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23, {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23,
8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}); 8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31});
}
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001); TEST_F(DepthToSpaceOpTest, Input1x2x16_B2_OPENCL) {
RunDepthToSpace<DeviceType::OPENCL>(true, {1, 1, 2, 16},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31},
2,
{1, 2, 4, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23,
8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31});
} }
TEST_F(DepthToSpaceOpTest, Input1x1x16_B2_CPU) {
RunDepthToSpace<DeviceType::CPU>(true, {1, 1, 1, 16},
{1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16},
2,
{1, 2, 2, 4},
{1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16});
}
TEST_F(DepthToSpaceOpTest, Input1x1x16_B2_OPENCL) {
RunDepthToSpace<DeviceType::OPENCL>(true, {1, 1, 1, 16},
{1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16},
2,
{1, 2, 2, 4},
{1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16});
}
/*
TEST_F(DepthToSpaceOpTest, Input2x2x3_B2_CPU) {
RunDepthToSpace<DeviceType::CPU>({1, 2, 2, 3},
{1, 2, 3, 4, 5, 6,
7, 8, 9, 10, 11, 12},
2,
{1, 1, 1, 12},
{1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12});
}
TEST_F(DepthToSpaceOpTest, Input2x2x3_B2_OPENCL) {
RunDepthToSpace<DeviceType::OPENCL>({1, 2, 2, 6},
{1, 2, 3, 4, 5, 6,
7, 8, 9, 10, 11, 12
},
2,
{1, 1, 1, 12},
{1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12});
}
TEST_F(DepthToSpaceOpTest, Input2x2x2_B2_CPU) {
RunDepthToSpace<DeviceType::CPU>({1, 2, 2, 2},
{1, 10, 2, 20, 3, 30, 4, 40},
2,
{1, 1, 1, 8},
{1, 10, 2, 20, 3, 30, 4, 40});
}
TEST_F(DepthToSpaceOpTest, Input2x2x2_B2_OPENCL) {
RunDepthToSpace<DeviceType::OPENCL>({1, 2, 2, 2},
{1, 10, 2, 20, 3, 30, 4, 40},
2,
{1, 1, 1, 8},
{1, 10, 2, 20, 3, 30, 4, 40});
}*/
} // namespace test } // namespace test
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
...@@ -19,13 +19,12 @@ void Register_SpaceToDepth(OperatorRegistry *op_registry) { ...@@ -19,13 +19,12 @@ void Register_SpaceToDepth(OperatorRegistry *op_registry) {
.TypeConstraint<float>("T") .TypeConstraint<float>("T")
.Build(), .Build(),
SpaceToDepthOp<DeviceType::OPENCL, float>); SpaceToDepthOp<DeviceType::OPENCL, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("SpaceToDepth") REGISTER_OPERATOR(op_registry, OpKeyBuilder("SpaceToDepth")
.Device(DeviceType::OPENCL) .Device(DeviceType::OPENCL)
.TypeConstraint<half>("T") .TypeConstraint<half>("T")
.Build(), .Build(),
SpaceToDepthOp<DeviceType::OPENCL, half>); SpaceToDepthOp<DeviceType::OPENCL, half>);
} }
} // namespace ops } // namespace ops
......
...@@ -9,42 +9,44 @@ ...@@ -9,42 +9,44 @@
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/operator.h"
#include "mace/kernels/space_to_depth.h" #include "mace/kernels/depth_to_space.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
template <DeviceType D, typename T> template <DeviceType D, typename T>
class SpaceToDepthOp : public Operator<D, T> { class SpaceToDepthOp : public Operator<D, T> {
public: public:
SpaceToDepthOp(const OperatorDef &op_def, Workspace *ws) SpaceToDepthOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws), : Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetSingleArgument<int>("block_size", 1)) {} functor_(OperatorBase::GetSingleArgument<int>("block_size", 1), false) {
}
bool Run(StatsFuture *future) override { bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT); const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT); Tensor *output = this->Output(OUTPUT);
MACE_CHECK(input->dim_size() == 4, "input dim should be 4"); MACE_CHECK(input->dim_size() == 4, "input dim should be 4");
const int block_size =
const int block_size = OperatorBase::GetSingleArgument<int>("block_size", 1); OperatorBase::GetSingleArgument<int>("block_size", 1);
const int input_height = input->dim(1);
const int input_height = input->dim(1); const int input_width = input->dim(2);
const int input_width = input->dim(2); const int input_depth = input->dim(3);
const int input_depth = input->dim(3); MACE_CHECK((input_depth % 4) == 0,
MACE_CHECK((input_width % block_size == 0) && (input_height % block_size == 0), "input channel should be dividable by 4");
"input width and height should be dividable by block_size", MACE_CHECK(
input->dim(3)); (input_width%block_size == 0)&&(input_height%block_size == 0),
functor_(input, output, future); "input width and height should be dividable by block_size",
return true; input->dim(3));
functor_(input, output, future);
return true;
} }
protected:
OP_INPUT_TAGS(INPUT);
OP_OUTPUT_TAGS(OUTPUT);
private:
kernels::SpaceToDepthOpFunctor<D, T> functor_;
protected:
OP_INPUT_TAGS(INPUT);
OP_OUTPUT_TAGS(OUTPUT);
private:
kernels::DepthToSpaceOpFunctor<D, T> functor_;
}; };
} // namespace ops } // namespace ops
......
...@@ -50,14 +50,14 @@ static void SpaceToDepth( ...@@ -50,14 +50,14 @@ static void SpaceToDepth(
} }
#define BM_SPACE_TO_DEPTH_MACRO(N, C, H, W, G, TYPE, DEVICE) \ #define BM_SPACE_TO_DEPTH_MACRO(N, C, H, W, G, TYPE, DEVICE) \
static void \ static void \
BM_SPACE_TO_DEPTH_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE( \ BM_SPACE_TO_DEPTH_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE( \
int iters) { \ int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \ const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::MaccProcessed(tot); \ mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
SpaceToDepth<DEVICE, TYPE>(iters, N, C, H, W, G); \ SpaceToDepth<DEVICE, TYPE>(iters, N, C, H, W, G); \
} \ } \
BENCHMARK(BM_SPACE_TO_DEPTH_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE) BENCHMARK(BM_SPACE_TO_DEPTH_##N##_##C##_##H##_##W##_##G##_##TYPE##_##DEVICE)
#define BM_SPACE_TO_DEPTH(N, C, H, W, G) \ #define BM_SPACE_TO_DEPTH(N, C, H, W, G) \
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class SpaceToDepthOpTest : public OpsTestBase {};
TEST_F(SpaceToDepthOpTest, C8G4_CPU) {
// Construct graph
OpsTestNet net;
OpDefBuilder("SpaceToDepth", "SpaceToDepthTest")
.Input("Input")
.Output("Output")
.AddIntArg("block_size", 2)
.Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>(
"Input", {1, 2, 4, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23,
8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31});
// Run
net.RunOp();
// Check
auto expected = CreateTensor<float>(
{1, 1, 2, 16},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
TEST_F(SpaceToDepthOpTest, C16G4_OPENCL) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<DeviceType::OPENCL, float>(
"Input", {1, 2, 4, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23,
8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31});
BufferToImage<DeviceType::OPENCL, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("SpaceToDepth", "SpaceToDepthTest")
.Input("InputImage")
.Output("OutputImage")
.AddIntArg("block_size", 2)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(DeviceType::OPENCL);
// Transfer output
ImageToBuffer<DeviceType::OPENCL, float>(&net, "OutputImage", "Output",
kernels::BufferType::IN_OUT_CHANNEL);
// Check
auto expected = CreateTensor<float>(
{1, 1, 2, 16},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
} // namespace test
} // namespace ops
} // namespace mace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册