diff --git a/mace/core/operator.cc b/mace/core/operator.cc index 38a220eb4b2dcf1cf2091bc9a4f38742648cc977..60eabfdaa0a8de900690e2ef87e9719ead07d639 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -73,6 +73,7 @@ extern void Register_BufferToImage(OperatorRegistry *op_registry); extern void Register_ChannelShuffle(OperatorRegistry *op_registry); extern void Register_Concat(OperatorRegistry *op_registry); extern void Register_Conv2D(OperatorRegistry *op_registry); +extern void Register_CWise(OperatorRegistry *op_registry); extern void Register_DepthToSpace(OperatorRegistry *op_registry); extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry); extern void Register_Eltwise(OperatorRegistry *op_registry); @@ -109,6 +110,7 @@ OperatorRegistry::OperatorRegistry() { ops::Register_ChannelShuffle(this); ops::Register_Concat(this); ops::Register_Conv2D(this); + ops::Register_CWise(this); ops::Register_DepthToSpace(this); ops::Register_DepthwiseConv2d(this); ops::Register_Eltwise(this); diff --git a/mace/kernels/cwise.h b/mace/kernels/cwise.h new file mode 100644 index 0000000000000000000000000000000000000000..073f5c48dbbd6d576acc2e9c39492b7522af2b38 --- /dev/null +++ b/mace/kernels/cwise.h @@ -0,0 +1,123 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. 
+// +#ifndef MACE_KERNELS_CWISE_H_ +#define MACE_KERNELS_CWISE_H_ + +#include +#include + +#include "mace/core/future.h" +#include "mace/core/runtime/opencl/cl2_header.h" +#include "mace/core/tensor.h" + +namespace mace { +namespace kernels { + +enum CWiseType { + MUL = 0, + ADD = 1, + MAX = 2, + MIN = 3, + SUB = 4, + DIV = 5, + NEG = 6, + ABS = 7, +}; + +struct CWiseFunctorBase { + CWiseFunctorBase(const CWiseType type, const float coeff) + : type_(type), coeff_(coeff) {} + + CWiseType type_; + float coeff_; +}; + +template +struct CWiseFunctor : CWiseFunctorBase { + CWiseFunctor(const CWiseType type, const float coeff) + : CWiseFunctorBase(type, coeff) {} + + void operator()(const Tensor *input, + Tensor *output, + StatsFuture *future) { + Tensor::MappingGuard input_guard(input); + Tensor::MappingGuard output_guard(output); + + const T *input_ptr = input->data(); + T *output_ptr = output->mutable_data(); + const index_t size = input->size(); + + switch (type_) { + case MUL: +#pragma omp parallel for + for (index_t i = 0; i < size; ++i) { + output_ptr[i] = coeff_ * input_ptr[i]; + } + break; + case ADD: +#pragma omp parallel for + for (index_t i = 0; i < size; ++i) { + output_ptr[i] = coeff_ + input_ptr[i]; + } + break; + case MAX: +#pragma omp parallel for + for (index_t i = 0; i < size; ++i) { + output_ptr[i] = std::max(input_ptr[i], coeff_); + } + break; + case MIN: +#pragma omp parallel for + for (index_t i = 0; i < size; ++i) { + output_ptr[i] = std::min(input_ptr[i], coeff_); + } + break; + case SUB: +#pragma omp parallel for + for (index_t i = 0; i < size; ++i) { + output_ptr[i] = input_ptr[i] - coeff_; + } + break; + case DIV: +#pragma omp parallel for + for (index_t i = 0; i < size; ++i) { + output_ptr[i] = input_ptr[i] / coeff_; + } + break; + case NEG: +#pragma omp parallel for + for (index_t i = 0; i < size; ++i) { + output_ptr[i] = 0 - input_ptr[i]; + } + break; + case ABS: +#pragma omp parallel for + for (index_t i = 0; i < size; ++i) { + T val = 
input_ptr[i]; + output_ptr[i] = (val > 0)? val : 0 - val; + } + break; + default: + LOG(FATAL) << "CWise op not support type " << type_; + } + } +}; + +template +struct CWiseFunctor : CWiseFunctorBase { + CWiseFunctor(const CWiseType type, const float coeff) + : CWiseFunctorBase(type, coeff) {} + + void operator()(const Tensor *input, + Tensor *output, + StatsFuture *future); + + cl::Kernel kernel_; + std::vector input_shape_; +}; + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_CWISE_H_ diff --git a/mace/kernels/eltwise.h b/mace/kernels/eltwise.h index 9c7f0a901a5968f1d0f4cf5c7af8ceeebb465f7e..423a8f9fdf59abba9f3c92acdcd5aa0ba5ca40f2 100644 --- a/mace/kernels/eltwise.h +++ b/mace/kernels/eltwise.h @@ -19,6 +19,7 @@ enum EltwiseType { SUM = 1, MAX = 2, MIN = 3, + SUB = 4, }; struct EltwiseFunctorBase { @@ -80,6 +81,12 @@ struct EltwiseFunctor : EltwiseFunctorBase { output_ptr[i] = std::min(input0_ptr[i], input1_ptr[i]); } break; + case SUB: +#pragma omp parallel for + for (index_t i = 0; i < size; ++i) { + output_ptr[i] = input0_ptr[i] - input1_ptr[i]; + } + break; default: LOG(FATAL) << "Eltwise op not support type " << type_; } diff --git a/mace/kernels/opencl/cl/cwise.cl b/mace/kernels/opencl/cl/cwise.cl new file mode 100644 index 0000000000000000000000000000000000000000..16f1f0851f98abdae95fb936a7ee6c4f449d0b96 --- /dev/null +++ b/mace/kernels/opencl/cl/cwise.cl @@ -0,0 +1,42 @@ +#include + +__kernel void cwise(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ + __private const float value, + __write_only image2d_t output) { + const int w = get_global_id(0); + const int hb = get_global_id(1); + + DATA_TYPE4 in0 = READ_IMAGET(input, SAMPLER, (int2)(w, hb)); + DATA_TYPE4 in1 = (DATA_TYPE4){value, value, value, value}; + DATA_TYPE4 out; + +#if CWISE_TYPE == 0 + out = in0 * in1; +#elif CWISE_TYPE == 1 + out = in0 + in1; +#elif CWISE_TYPE == 2 + out.x = fmax(in0.x, value); + out.y = fmax(in0.y, value); + out.z = fmax(in0.z, value); + 
out.w = fmax(in0.w, value); /* 4th lane goes to .w (was a second write to .z) */ +#elif CWISE_TYPE == 3 + out.x = fmin(in0.x, value); + out.y = fmin(in0.y, value); + out.z = fmin(in0.z, value); + out.w = fmin(in0.w, value); /* 4th lane goes to .w (was a second write to .z) */ +#elif CWISE_TYPE == 4 + out = in0 - in1; +#elif CWISE_TYPE == 5 + out = in0 / in1; +#elif CWISE_TYPE == 6 + in1 = (DATA_TYPE4)(0, 0, 0, 0); + out = in1 - in0; +#elif CWISE_TYPE == 7 + out.x = fabs(in0.x); + out.y = fabs(in0.y); + out.z = fabs(in0.z); + out.w = fabs(in0.w); +#endif + + WRITE_IMAGET(output, (int2)(w, hb), out); +} diff --git a/mace/kernels/opencl/cl/depth_to_space.cl b/mace/kernels/opencl/cl/depth_to_space.cl index a52617c87367635f697fc29f7c56315b6347bf13..21045ec94fe2d2eac962294fb09bdf3041e20e49 100644 --- a/mace/kernels/opencl/cl/depth_to_space.cl +++ b/mace/kernels/opencl/cl/depth_to_space.cl @@ -4,36 +4,33 @@ __kernel void depth_to_space( UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3 __read_only image2d_t input, __private const int block_size, - __private const int output_depth, + __private const int input_height, + __private const int input_width, + __private const int input_depth_blocks, + __private const int output_height, + __private const int output_width, + __private const int output_depth_blocks, __write_only image2d_t output) { const int out_d = get_global_id(0); const int out_w = get_global_id(1); const int out_h = get_global_id(2); -#ifndef NON_UNIFORM_WORK_GROUP - if (out_d >= global_size_dim0 || out_w >= global_size_dim1 - || out_h >= global_size_dim2) { + if (out_d >= output_depth_blocks || out_h >= output_height || out_w >= output_width) return; - } - const int output_width = global_size_dim1; -#else - const int output_width = get_global_size(1); -#endif const int out_pos = mad24(out_d, output_width, out_w); - const int input_width = output_width / block_size; - const int in_h = out_h / block_size; const int offset_h = out_h % block_size; const int in_w = out_w / block_size; const int offset_w = out_w % block_size; - - const int offset_d = (offset_h * block_size + 
offset_w) * output_depth; + const int offset_d = (offset_h * block_size + offset_w) * output_depth_blocks; const int in_d = out_d + offset_d; - const int in_pos = mad24(in_d, input_width, in_w); + if (in_h >= input_height || in_w >= input_width || in_d >= input_depth_blocks) + return; + const int in_pos = mad24(in_d, input_width, in_w); DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, in_h)); WRITE_IMAGET(output, (int2)(out_pos, out_h), in_data); } @@ -42,35 +39,34 @@ __kernel void space_to_depth( UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3 __read_only image2d_t input, __private const int block_size, - __private const int input_depth, + __private const int input_height, + __private const int input_width, + __private const int input_depth_blocks, + __private const int output_height, + __private const int output_width, + __private const int output_depth_blocks, __write_only image2d_t output) { const int d = get_global_id(0); const int w = get_global_id(1); const int h = get_global_id(2); -#ifndef NON_UNIFORM_WORK_GROUP - if (d >= global_size_dim0 || w >= global_size_dim1 - || h >= global_size_dim2) { + if (h >= input_height || w >= input_width || d >= input_depth_blocks) return; - } - const int input_width = global_size_dim1; -#else - const int input_width = get_global_size(1); -#endif const int in_pos = mad24(d, input_width, w); - const int output_width = input_width / block_size; const int out_h = h / block_size; const int offset_h = h % block_size; const int out_w = w / block_size; const int offset_w = w % block_size; - const int offset_d = (offset_h * block_size + offset_w) * input_depth; + const int offset_d = (offset_h * block_size + offset_w) * input_depth_blocks; const int out_d = d + offset_d; - const int out_pos = mad24(out_d, output_width, out_w); - DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, h)); + if (out_d >= output_depth_blocks || out_h >= output_height || out_w >= output_width) + return; + const int out_pos = 
mad24(out_d, output_width, out_w); + DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, h)); WRITE_IMAGET(output, (int2)(out_pos, out_h), in_data); } diff --git a/mace/kernels/opencl/cl/eltwise.cl b/mace/kernels/opencl/cl/eltwise.cl index def21f0a993b75d321729e5c89b080555c1dcdf7..8509dc38286454d26ae46d85f82407a9c346e84f 100644 --- a/mace/kernels/opencl/cl/eltwise.cl +++ b/mace/kernels/opencl/cl/eltwise.cl @@ -33,8 +33,9 @@ __kernel void eltwise( out = fmax(in0, in1); #elif ELTWISE_TYPE == 3 out = fmin(in0, in1); +#elif ELTWISE_TYPE == 4 + out = in0 - in1; #endif WRITE_IMAGET(output, (int2)(w, hb), out); } - diff --git a/mace/kernels/opencl/cwise_opencl.cc b/mace/kernels/opencl/cwise_opencl.cc new file mode 100644 index 0000000000000000000000000000000000000000..bd839c556ede14ffce77e689f0d9476f3134e40e --- /dev/null +++ b/mace/kernels/opencl/cwise_opencl.cc @@ -0,0 +1,57 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/kernels/cwise.h" +#include "mace/core/runtime/opencl/opencl_runtime.h" +#include "mace/kernels/opencl/helper.h" +#include "mace/utils/tuner.h" + +namespace mace { +namespace kernels { + +template +void CWiseFunctor::operator()(const Tensor *input, + Tensor *output, + StatsFuture *future) { + const index_t batch = input->dim(0); + const index_t height = input->dim(1); + const index_t width = input->dim(2); + const index_t channels = input->dim(3); + + const index_t channel_blocks = RoundUpDiv4(channels); + const index_t width_pixels = channel_blocks * width; + const index_t batch_height_pixels = batch * height; + + if (kernel_.get() == nullptr) { + auto runtime = OpenCLRuntime::Global(); + std::set built_options; + auto dt = DataTypeToEnum::value; + std::string kernel_name = MACE_OBFUSCATE_SYMBOL("cwise"); + built_options.emplace("-Dcwise=" + kernel_name); + built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); + built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); + 
built_options.emplace(MakeString("-DCWISE_TYPE=", type_)); + kernel_ = runtime->BuildKernel("cwise", kernel_name, built_options); + } + if (!IsVecEqual(input_shape_, input->shape())) { + uint32_t idx = 0; + kernel_.setArg(idx++, *(input->opencl_image())); + kernel_.setArg(idx++, static_cast(coeff_)); + kernel_.setArg(idx++, *(output->opencl_image())); + input_shape_ = input->shape(); + } + + const uint32_t gws[2] = {static_cast(width_pixels), + static_cast(batch_height_pixels)}; + const std::vector lws = {64, 16, 1}; + std::stringstream ss; + ss << "cwise_opencl_kernel_" << output->dim(0) << "_" << output->dim(1) + << "_" << output->dim(2) << "_" << output->dim(3); + TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future); +} + +template struct CWiseFunctor; +template struct CWiseFunctor; +} // namespace kernels +} // namespace mace diff --git a/mace/kernels/opencl/depth_to_space_opencl.cc b/mace/kernels/opencl/depth_to_space_opencl.cc index 8fc0924704badcf1f37d9a55b8c0188e65b295de..1c0624365c96eb19f08a22f9055d75834a4d6b72 100644 --- a/mace/kernels/opencl/depth_to_space_opencl.cc +++ b/mace/kernels/opencl/depth_to_space_opencl.cc @@ -20,7 +20,6 @@ void DepthToSpaceOpFunctor::operator()( const index_t input_width = input->dim(2); const index_t input_depth = input->dim(3); - int depth_blocks = 1; const char *kernel_name = nullptr; index_t output_height, output_width, output_depth; @@ -28,15 +27,15 @@ void DepthToSpaceOpFunctor::operator()( output_height = input_height * block_size_; output_width = input_width * block_size_; output_depth = input_depth / (block_size_ * block_size_); - depth_blocks = RoundUpDiv4(output_depth); kernel_name = "depth_to_space"; } else { output_height = input_height / block_size_; output_width = input_width / block_size_; output_depth = input_depth * block_size_ * block_size_; - depth_blocks = RoundUpDiv4(input_depth); kernel_name = "space_to_depth"; } + const index_t input_depth_blocks = RoundUpDiv4(input_depth); + const index_t 
output_depth_blocks = RoundUpDiv4(output_depth); std::vector output_shape = {batch, output_height, output_width, output_depth}; @@ -54,13 +53,14 @@ void DepthToSpaceOpFunctor::operator()( kernel_name_ss << "-D" << kernel_name << "=" << obfuscated_kernel_name; built_options.emplace(kernel_name_ss.str()); auto dt = DataTypeToEnum::value; - built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); - built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); + built_options.emplace("-DDATA_TYPE=" + DtToCLDt(dt)); + built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(dt)); if (runtime->IsNonUniformWorkgroupsSupported()) { built_options.emplace("-DNON_UNIFORM_WORK_GROUP"); } kernel_ = - runtime->BuildKernel("depth_to_space", kernel_name, built_options); + runtime->BuildKernel("depth_to_space", + obfuscated_kernel_name, built_options); kwg_size_ = static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); @@ -70,13 +70,13 @@ void DepthToSpaceOpFunctor::operator()( std::stringstream ss; if (!IsVecEqual(input_shape_, input->shape())) { if (d2s_) { - gws[0] = static_cast(depth_blocks); + gws[0] = static_cast(output_depth_blocks); gws[1] = static_cast(output_width); gws[2] = static_cast(output_height * batch); ss << "depth_to_space_opencl_kernel_" << output->dim(0) << "_" << output->dim(1) << "_" << output->dim(2) << "_" << output->dim(3); } else { - gws[0] = static_cast(depth_blocks); + gws[0] = static_cast(input_depth_blocks); gws[1] = static_cast(input_width); gws[2] = static_cast(input_height * batch); ss << "space_to_depth_opencl_kernel_" << input->dim(0) << "_" @@ -90,8 +90,13 @@ void DepthToSpaceOpFunctor::operator()( kernel_.setArg(idx++, gws[2]); } kernel_.setArg(idx++, *(input->opencl_image())); - kernel_.setArg(idx++, block_size_); - kernel_.setArg(idx++, depth_blocks); + kernel_.setArg(idx++, static_cast(block_size_)); + kernel_.setArg(idx++, static_cast(input_height * batch)); /* row bound: gws[2] spans height * batch */ + kernel_.setArg(idx++, static_cast(input_width)); + kernel_.setArg(idx++, 
static_cast(input_depth_blocks)); + kernel_.setArg(idx++, static_cast(output_height * batch)); /* row bound: gws[2] spans height * batch */ + kernel_.setArg(idx++, static_cast(output_width)); + kernel_.setArg(idx++, static_cast(output_depth_blocks)); kernel_.setArg(idx++, *(output->opencl_image())); input_shape_ = input->shape(); diff --git a/mace/kernels/opencl/space_to_batch_opencl.cc b/mace/kernels/opencl/space_to_batch_opencl.cc index 31b5013b737335c40255d9d4163e1d2fb8572d68..b4ae998a56f34cf3f23266e549f7346360ae4113 100644 --- a/mace/kernels/opencl/space_to_batch_opencl.cc +++ b/mace/kernels/opencl/space_to_batch_opencl.cc @@ -51,7 +51,8 @@ void SpaceToBatchFunctor::operator()( built_options.emplace("-DNON_UNIFORM_WORK_GROUP"); } kernel_ = - runtime->BuildKernel("space_to_batch", kernel_name, built_options); + runtime->BuildKernel("space_to_batch", + obfuscated_kernel_name, built_options); kwg_size_ = static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); diff --git a/mace/ops/conv_2d.h b/mace/ops/conv_2d.h index 08f1bab24ea1b6b8b19faf34d57d38de0fcfb71b..528f1e1f67c32b037262e0019847eb25fdc62a4c 100644 --- a/mace/ops/conv_2d.h +++ b/mace/ops/conv_2d.h @@ -31,7 +31,6 @@ class Conv2dOp : public ConvPool2dOpBase { const Tensor *filter = this->Input(FILTER); const Tensor *bias = this->InputSize() >= 3 ? this->Input(BIAS) : nullptr; Tensor *output = this->Output(OUTPUT); - functor_(input, filter, bias, output, future); return true; diff --git a/mace/ops/cwise.cc b/mace/ops/cwise.cc new file mode 100644 index 0000000000000000000000000000000000000000..42439c3e254837638cd3a919a64e782f5a227a17 --- /dev/null +++ b/mace/ops/cwise.cc @@ -0,0 +1,31 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. 
+// + +#include "mace/ops/cwise.h" + +namespace mace { +namespace ops { + +void Register_CWise(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("CWise") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + CWiseOp); + + REGISTER_OPERATOR(op_registry, OpKeyBuilder("CWise") + .Device(DeviceType::OPENCL) + .TypeConstraint("T") + .Build(), + CWiseOp); + + REGISTER_OPERATOR(op_registry, OpKeyBuilder("CWise") + .Device(DeviceType::OPENCL) + .TypeConstraint("T") + .Build(), + CWiseOp); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/cwise.h b/mace/ops/cwise.h new file mode 100644 index 0000000000000000000000000000000000000000..75430183c43f371467e947968f65f958954ce1c2 --- /dev/null +++ b/mace/ops/cwise.h @@ -0,0 +1,49 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_OPS_CWISE_H_ +#define MACE_OPS_CWISE_H_ + +#include + +#include "mace/core/operator.h" +#include "mace/kernels/cwise.h" + +namespace mace { +namespace ops { + +template +class CWiseOp : public Operator { + public: + CWiseOp(const OperatorDef &operator_def, Workspace *ws) + : Operator(operator_def, ws), + x_(OperatorBase::GetSingleArgument("x", 1.0)), + functor_(static_cast( + OperatorBase::GetSingleArgument( + "type", static_cast( + kernels::CWiseType::ADD))), + this->x_) {} + + bool Run(StatsFuture *future) override { + const Tensor *input_tensor = this->Input(INPUT); + Tensor *output_tensor = this->Output(OUTPUT); + output_tensor->ResizeLike(input_tensor); + + functor_(input_tensor, output_tensor, future); + return true; + } + + protected: + const float x_; + OP_INPUT_TAGS(INPUT); + OP_OUTPUT_TAGS(OUTPUT); + + private: + kernels::CWiseFunctor functor_; +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_CWISE_H_ diff --git a/mace/ops/cwise_benchmark.cc b/mace/ops/cwise_benchmark.cc new file mode 100644 index 0000000000000000000000000000000000000000..6ab6aa543321f5102b731a735762eb2aaa97ca39 --- /dev/null 
+++ b/mace/ops/cwise_benchmark.cc @@ -0,0 +1,93 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/operator.h" +#include "mace/core/runtime/opencl/opencl_runtime.h" +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +template +static void CWise(int iters, int batch, int channels, + int height, int width, float x, int type) { + mace::testing::StopTiming(); + + OpsTestNet net; + + // Add input data + net.AddRandomInput("Input", {batch, height, width, channels}); + + if (D == DeviceType::OPENCL) { + BufferToImage(&net, "Input", "InputImage", + kernels::BufferType::IN_OUT_CHANNEL); + OpDefBuilder("CWise", "CWiseBM") + .Input("InputImage") + .Output("Output") + .AddIntArg("type", type) + .AddFloatArg("x", x) + .Finalize(net.NewOperatorDef()); + } else { + OpDefBuilder("CWise", "CWiseBM") + .Input("Input") + .Output("Output") + .AddIntArg("type", type) + .AddFloatArg("x", x) + .Finalize(net.NewOperatorDef()); + } + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + } + net.Sync(); +} + +#define BM_CWISE_MACRO(N, C, H, W, X, G, TYPE, DEVICE) \ + static void \ + BM_CWISE_##N##_##C##_##H##_##W##_##X##_##G##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * C * H * W; \ + mace::testing::MaccProcessed(tot); \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + CWise(iters, N, C, H, W, X, G); \ + } \ + BENCHMARK( \ + BM_CWISE_##N##_##C##_##H##_##W##_##X##_##G##_##TYPE##_##DEVICE) + +#define BM_CWISE(N, C, H, W, X, G) \ + BM_CWISE_MACRO(N, C, H, W, X, G, float, CPU); \ + BM_CWISE_MACRO(N, C, H, W, X, G, float, OPENCL); \ + BM_CWISE_MACRO(N, C, H, W, X, G, half, OPENCL); + +BM_CWISE(1, 1, 512, 512, 2, 0); +BM_CWISE(1, 3, 128, 128, 2, 1); +BM_CWISE(1, 3, 512, 512, 2, 4); +BM_CWISE(1, 32, 112, 112, 2, 5); +BM_CWISE(1, 
32, 112, 112, 2, 6); +BM_CWISE(1, 32, 112, 112, 2, 7); +BM_CWISE(1, 64, 256, 256, 3, 0); +BM_CWISE(1, 64, 512, 512, 3, 1); +BM_CWISE(1, 128, 56, 56, 3, 4); +BM_CWISE(1, 128, 256, 256, 3, 5); +BM_CWISE(1, 64, 512, 512, 3, 6); +BM_CWISE(1, 64, 512, 512, 3, 7); +BM_CWISE(1, 256, 14, 14, 3, 0); +BM_CWISE(1, 512, 14, 14, 3, 1); +BM_CWISE(1, 1024, 7, 7, 3, 4); +BM_CWISE(32, 1, 256, 256, 3, 5); +BM_CWISE(32, 1, 256, 256, 3, 6); +BM_CWISE(32, 1, 256, 256, 3, 7); + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/cwise_test.cc b/mace/ops/cwise_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7bd934f8d55c67dd8da8f91678838fe8c5b84bf9 --- /dev/null +++ b/mace/ops/cwise_test.cc @@ -0,0 +1,176 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" +#include "../kernels/cwise.h" + +namespace mace { +namespace ops { +namespace test { + +class CWiseOpTest : public OpsTestBase {}; + + +template +void Simple(const kernels::CWiseType type, + const std::vector &shape, + const std::vector &input0, + const float x, + const std::vector &output) { + // Construct graph + OpsTestNet net; + + // Add input data + net.AddInputFromArray("Input1", shape, input0); + + if (D == DeviceType::CPU) { + OpDefBuilder("CWise", "CWiseTest") + .Input("Input1") + .AddIntArg("type", static_cast(type)) + .AddFloatArg("x", x) + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(D); + } else { + BufferToImage(&net, "Input1", "InputImg1", + kernels::BufferType::IN_OUT_CHANNEL); + OpDefBuilder("CWise", "CWiseTest") + .Input("InputImg1") + .AddIntArg("type", static_cast(type)) + .AddFloatArg("x", x) + .Output("OutputImg") + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(D); + + ImageToBuffer(&net, "OutputImg", "Output", + kernels::BufferType::IN_OUT_CHANNEL); + } + + auto expected = CreateTensor(shape, output); + + ExpectTensorNear(*expected, 
*net.GetOutput("Output"), 1e-3); +} + +TEST_F(CWiseOpTest, CPUSimple) { + Simple(kernels::CWiseType::MUL, {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, 0.1, {0.1, 0.2, .3, .4, .5, .6}); + + Simple(kernels::CWiseType::ADD, {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, 2.0, {3, 4, 5, 6, 7, 8}); + + Simple(kernels::CWiseType::DIV, {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, 0.1, {10, 20, 30, 40, 50, 60}); + + Simple(kernels::CWiseType::SUB, {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, 2.0, {-1, 0, 1, 2, 3, 4}); + + Simple(kernels::CWiseType::NEG, {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, 2.0, {-1, -2, -3, -4, -5, -6}); + + Simple(kernels::CWiseType::ABS, {1, 1, 2, 3}, + {1, -2, -0.0001, 4, 5, 6}, 2.0, {1, 2, 0.0001, 4, 5, 6}); +} + +TEST_F(CWiseOpTest, GPUSimple) { + Simple(kernels::CWiseType::MUL, {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, 0.1, {0.1, 0.2, .3, .4, .5, .6}); + + Simple(kernels::CWiseType::ADD, {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, 2.0, {3, 4, 5, 6, 7, 8}); + + Simple(kernels::CWiseType::DIV, {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, 0.1, {10, 20, 30, 40, 50, 60}); + + Simple(kernels::CWiseType::SUB, {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, 2.0, {-1, 0, 1, 2, 3, 4}); + + Simple(kernels::CWiseType::NEG, {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, 2.0, {-1, -2, -3, -4, -5, -6}); + + Simple(kernels::CWiseType::ABS, {1, 1, 2, 3}, + {1, -2, -0.0001, 4, 5, 6}, 2.0, {1, 2, 0.0001, 4, 5, 6}); +} + +template +void RandomTest(const kernels::CWiseType type, + const std::vector &shape) { + testing::internal::LogToStderr(); + srand(time(NULL)); + + // Construct graph + OpsTestNet net; + + // Add input data + net.AddRandomInput("Input1", shape); + + OpDefBuilder("CWise", "CWiseTest") + .Input("Input1") + .AddIntArg("type", static_cast(type)) + .AddFloatArg("x", 1.2) + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(); + + BufferToImage(&net, "Input1", "InputImg1", + kernels::BufferType::IN_OUT_CHANNEL); + + OpDefBuilder("CWise", "CWiseTest") + .Input("InputImg1") + .AddIntArg("type", static_cast(type)) + 
.AddFloatArg("x", 1.2) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Output("OutputImg") + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(D); + + ImageToBuffer(&net, "OutputImg", "OPENCLOutput", + kernels::BufferType::IN_OUT_CHANNEL); + + if (DataTypeToEnum::value == DT_FLOAT) { + ExpectTensorNear(*net.GetTensor("Output"), + *net.GetOutput("OPENCLOutput"), 1e-3); + } else { + ExpectTensorNear(*net.GetTensor("Output"), + *net.GetOutput("OPENCLOutput"), 1e-1); + } +} + +TEST_F(CWiseOpTest, OPENCLRandomFloat) { + RandomTest(kernels::CWiseType::MUL, + {3, 23, 37, 19}); + RandomTest(kernels::CWiseType::ADD, + {13, 32, 32, 64}); + RandomTest(kernels::CWiseType::SUB, + {3, 32, 32, 64}); + RandomTest(kernels::CWiseType::DIV, + {13, 32, 32, 64}); + RandomTest(kernels::CWiseType::NEG, + {13, 32, 32, 64}); +} + +TEST_F(CWiseOpTest, OPENCLRandomHalf) { + RandomTest(kernels::CWiseType::MUL, + {3, 23, 37, 19}); + RandomTest(kernels::CWiseType::ADD, + {13, 32, 32, 64}); + RandomTest(kernels::CWiseType::SUB, + {3, 32, 32, 64}); + RandomTest(kernels::CWiseType::DIV, + {13, 32, 32, 64}); + RandomTest(kernels::CWiseType::NEG, + {13, 32, 32, 64}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/depth_to_space.h b/mace/ops/depth_to_space.h index 78ff39191943f1cc7c215e219fcdec607d3e6718..ad71396fa390df91d3f27898763d228a8a82f7eb 100644 --- a/mace/ops/depth_to_space.h +++ b/mace/ops/depth_to_space.h @@ -19,18 +19,16 @@ class DepthToSpaceOp : public Operator { public: DepthToSpaceOp(const OperatorDef &op_def, Workspace *ws) : Operator(op_def, ws), - functor_(OperatorBase::GetSingleArgument("block_size", 1), true) {} + block_size_(OperatorBase::GetSingleArgument("block_size", 1)), + functor_(this->block_size_, true) {} bool Run(StatsFuture *future) override { const Tensor *input = this->Input(INPUT); Tensor *output = this->Output(OUTPUT); MACE_CHECK(input->dim_size() == 4, "input dim should be 4"); - const int block_size = - 
OperatorBase::GetSingleArgument("block_size", 1); - int input_depth = input->dim(3); - MACE_CHECK(input_depth % (block_size * block_size) == 0, + MACE_CHECK(input_depth % (block_size_ * block_size_) == 0, "input depth should be dividable by block_size * block_size", input->dim(3)); MACE_CHECK((input_depth % 4) == 0, @@ -40,6 +38,7 @@ class DepthToSpaceOp : public Operator { } protected: + const int block_size_; OP_INPUT_TAGS(INPUT); OP_OUTPUT_TAGS(OUTPUT); diff --git a/mace/ops/depth_to_space_test.cc b/mace/ops/depth_to_space_test.cc index ba31174d5362001d5484bec51130a0a0b1f3c018..835e39b3cc6dada31ac72f89f1d756a6ea430ddd 100644 --- a/mace/ops/depth_to_space_test.cc +++ b/mace/ops/depth_to_space_test.cc @@ -1,7 +1,9 @@ // // Copyright (c) 2017 XiaoMi All rights reserved. // +#include +#include #include "mace/core/operator.h" #include "mace/ops/ops_test_util.h" @@ -48,6 +50,7 @@ void RunDepthToSpace(const bool d2s, ExpectTensorNear(*expected, *net.GetOutput("Output"), 0.001); } + class SpaceToDepthOpTest : public OpsTestBase {}; TEST_F(SpaceToDepthOpTest, Input2x4x4_B2_CPU) { @@ -70,6 +73,8 @@ TEST_F(SpaceToDepthOpTest, Input2x4x4_B2_OPENCL) { 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31}); } + + TEST_F(SpaceToDepthOpTest, Input2x2x4_B2_CPU) { RunDepthToSpace(false, {1, 2, 2, 4}, {1, 2, 3, 4, 5, 6, 7, 8, @@ -132,46 +137,83 @@ TEST_F(DepthToSpaceOpTest, Input1x1x16_B2_OPENCL) { 9, 10, 11, 12, 13, 14, 15, 16}); } -/* -TEST_F(DepthToSpaceOpTest, Input2x2x3_B2_CPU) { - RunDepthToSpace({1, 2, 2, 3}, - {1, 2, 3, 4, 5, 6, - 7, 8, 9, 10, 11, 12}, - 2, - {1, 1, 1, 12}, - {1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12}); +TEST_F(DepthToSpaceOpTest, InputLarger_B2_OPENCL) { + const std::vector in = std::vector(192 * 192 *128, 1.0); + + RunDepthToSpace(true, {1, 192, 192, 128}, + in, + 2, + {1, 384, 384, 32}, + in); +} + + +template +void RandomTest(const bool d2s, const int block_size, + const std::vector &shape) { + testing::internal::LogToStderr(); + 
srand(time(NULL)); + + // Construct graph + OpsTestNet net; + + const char *ops_name = (d2s) ? "DepthToSpace" : "SpaceToDepth"; + const char *ops_test_name = (d2s) ? "DepthToSpaceTest" : "SpaceToDepthTest"; + + // Add input data + net.AddRandomInput("Input1", shape); + + OpDefBuilder(ops_name, ops_test_name) + .Input("Input1") + .AddIntArg("block_size", block_size) + .Output("Output") + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(); + + BufferToImage(&net, "Input1", "InputImg1", + kernels::BufferType::IN_OUT_CHANNEL); + + OpDefBuilder(ops_name, ops_test_name) + .Input("InputImg1") + .AddIntArg("block_size", block_size) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Output("OutputImg") + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(D); + + ImageToBuffer(&net, "OutputImg", "OPENCLOutput", + kernels::BufferType::IN_OUT_CHANNEL); + + if (DataTypeToEnum::value == DT_FLOAT) { + ExpectTensorNear(*net.GetTensor("Output"), + *net.GetOutput("OPENCLOutput"), 1e-3); + } else { + ExpectTensorNear(*net.GetTensor("Output"), + *net.GetOutput("OPENCLOutput"), 1e-1); + } } -TEST_F(DepthToSpaceOpTest, Input2x2x3_B2_OPENCL) { - RunDepthToSpace({1, 2, 2, 6}, - {1, 2, 3, 4, 5, 6, - 7, 8, 9, 10, 11, 12 - }, - 2, - {1, 1, 1, 12}, - {1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12}); +TEST_F(DepthToSpaceOpTest, OPENCLRandomFloat) { + RandomTest(true, 2, {1, 192, 192, 128}); } -TEST_F(DepthToSpaceOpTest, Input2x2x2_B2_CPU) { +TEST_F(DepthToSpaceOpTest, OPENCLRandomHalf) { +RandomTest(true, 2, {1, 192, 192, 128}); +} - RunDepthToSpace({1, 2, 2, 2}, - {1, 10, 2, 20, 3, 30, 4, 40}, - 2, - {1, 1, 1, 8}, - {1, 10, 2, 20, 3, 30, 4, 40}); +TEST_F(SpaceToDepthOpTest, OPENCLRandomFloat) { +RandomTest(false, 2, {1, 384, 384, 32}); } -TEST_F(DepthToSpaceOpTest, Input2x2x2_B2_OPENCL) { +TEST_F(SpaceToDepthOpTest, OPENCLRandomHalf) { +RandomTest(false, 2, {1, 384, 384, 32}); +} - RunDepthToSpace({1, 2, 2, 2}, - {1, 10, 2, 20, 3, 30, 4, 40}, - 2, - {1, 1, 1, 8}, - {1, 10, 2, 20, 
3, 30, 4, 40}); -}*/ } // namespace test } // namespace ops } // namespace mace diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py index 01e73645e8439d77ba5b4b6bada8f84e7c3eae9a..93fbea92a8f283a3646e3746ad496bf35a050f15 100644 --- a/mace/python/tools/tf_converter_lib.py +++ b/mace/python/tools/tf_converter_lib.py @@ -19,6 +19,21 @@ pooling_type_mode = { 'MaxPool': 2 } +# the order should be the same as +# eltwise type's in mace/kernels/eltwise.h +# and also cwise type's in mace/kernels/cwise.h +# cuz these math ops should have compatible with "EltWise" and "CWise" +math_type_mode = { + 'MUL': 0, + 'ADD': 1, + 'MAX': 2, + 'MIN': 3, + 'SUB': 4, + 'DIV': 5, + 'NEG': 6, + 'ABS': 7 +} + buffer_type_map = { 'CONV2D_FILTER' : 0, 'IN_OUT_CHANNEL' : 1, @@ -622,6 +637,59 @@ class TFConverter(object): self.add_output_shape(op.outputs, op_def) self.resolved_ops[op.name] = 1 self.unused_tensor.add(get_input_tensor(op, 1).name) + + def convert_math(self, op, math_type): + op_def = self.net_def.op.add() + arg = op_def.arg.add() + arg.name = 'T' + arg.i = self.dt + op_def.name = op.name + + if len(op.inputs) == 1: + op_def.type = "CWise" + op_def.input.extend([input.name for input in op.inputs]) + x_arg = op_def.arg.add() + x_arg.name = 'x' + x_arg.f = 0 + elif len(op.inputs) >= 2: + input_tensor0 = get_input_tensor(op, 0) + input_tensor1 = get_input_tensor(op, 1) + if input_tensor0.shape == input_tensor1.shape: + op_def.type = "Eltwise" + op_def.input.extend([input.name for input in op.inputs]) + else: + op_def.type = "CWise" + x_value = 0 + if len(input_tensor1.shape)==4: + op_def.input.extend([op.inputs[1].name]) + x_value = get_input_tensor(op, 0).eval().astype(np.float32) + else: + op_def.input.extend([op.inputs[0].name]) + x_value = get_input_tensor(op, 1).eval().astype(np.float32) + x_arg = op_def.arg.add() + x_arg.name = 'x' + x_arg.f = x_value + type_arg = op_def.arg.add() + type_arg.name = 'type' + type_arg.i = 
math_type_mode[math_type] + op_def.output.extend([output.name for output in op.outputs]) + self.add_output_shape(op.outputs, op_def) + self.resolved_ops[op.name] = 1 + + def convert_depth_to_space(self, op, d2s): + op_def = self.net_def.op.add() + arg = op_def.arg.add() + arg.name = 'T' + arg.i = self.dt + op_def.name = op.name + op_def.type = op.type + op_def.input.extend([op.inputs[0].name]) + op_def.output.extend([output.name for output in op.outputs]) + size_arg = op_def.arg.add() + size_arg.name = 'block_size' + size_arg.i = op.get_attr('block_size') + self.add_output_shape(op.outputs, op_def) + self.resolved_ops[op.name] = 1 def convert_bias_add(self, op): op_def = mace_pb2.OperatorDef() @@ -850,6 +918,16 @@ class TFConverter(object): self.convert_space_to_batch(op, False) elif op.type == 'BatchToSpaceND': self.convert_space_to_batch(op, True) + elif op.type == 'DepthToSpace': + self.convert_depth_to_space(op, True) + elif op.type == 'SpaceToDepth': + self.convert_depth_to_space(op, False) + elif op.type in ['Neg', 'neg', 'Negative', 'negative']: + self.convert_math(op, 'NEG') + elif op.type == 'Mul': + self.convert_math(op, 'MUL') + elif op.type == 'Sub': + self.convert_math(op, 'SUB') elif self.is_softmax(op): self.convert_softmax(op) elif op.type in ['Relu', 'Sigmoid', 'Tanh']: