From 26d9b5cbca6147dba16211b8beecdb70f24857f9 Mon Sep 17 00:00:00 2001
From: Unknown
Date: Wed, 28 Mar 2018 16:28:13 +0800
Subject: [PATCH] Support super resolution & fix d2s opencl bugs, add neg/scalar_math

---
 mace/core/operator.cc                        |   4 +
 mace/kernels/eltwise.h                       |  12 +-
 mace/kernels/negative.h                      |  69 ++++++++
 mace/kernels/opencl/cl/neg.cl                |  14 ++
 mace/kernels/opencl/cl/scalar_math.cl        |  27 ++++
 mace/kernels/opencl/depth_to_space_opencl.cc |   3 +
 mace/kernels/opencl/eltwise_opencl.cc        |   2 +-
 mace/kernels/opencl/negative_opencl.cc       |  65 ++++++++
 mace/kernels/opencl/scalar_math_opencl.cc    |  57 +++++++
 mace/kernels/scalar_math.h                   |  96 +++++++++++
 mace/ops/neg.cc                              |  31 ++++
 mace/ops/neg.h                               |  39 +++++
 mace/ops/neg_benchmark.cc                    |  82 ++++++++++
 mace/ops/neg_test.cc                         |  61 +++++++
 mace/ops/scalar_math.cc                      |  31 ++++
 mace/ops/scalar_math.h                       |  49 ++++++
 mace/ops/scalar_math_benchmark.cc            |  88 ++++++++++
 mace/ops/scalar_math_test.cc                 | 160 +++++++++++++++++++
 18 files changed, 883 insertions(+), 7 deletions(-)
 create mode 100644 mace/kernels/negative.h
 create mode 100644 mace/kernels/opencl/cl/neg.cl
 create mode 100644 mace/kernels/opencl/cl/scalar_math.cl
 create mode 100644 mace/kernels/opencl/negative_opencl.cc
 create mode 100644 mace/kernels/opencl/scalar_math_opencl.cc
 create mode 100644 mace/kernels/scalar_math.h
 create mode 100644 mace/ops/neg.cc
 create mode 100644 mace/ops/neg.h
 create mode 100644 mace/ops/neg_benchmark.cc
 create mode 100644 mace/ops/neg_test.cc
 create mode 100644 mace/ops/scalar_math.cc
 create mode 100644 mace/ops/scalar_math.h
 create mode 100644 mace/ops/scalar_math_benchmark.cc
 create mode 100644 mace/ops/scalar_math_test.cc

diff --git a/mace/core/operator.cc b/mace/core/operator.cc
index 60eabfda..fcf9069d 100644
--- a/mace/core/operator.cc
+++ b/mace/core/operator.cc
@@ -83,12 +83,14 @@ extern void Register_FusedConv2D(OperatorRegistry *op_registry);
 extern void Register_GlobalAvgPooling(OperatorRegistry *op_registry);
 extern void Register_ImageToBuffer(OperatorRegistry *op_registry);
 extern void Register_MatMul(OperatorRegistry *op_registry);
+extern void Register_Neg(OperatorRegistry *op_registry);
 extern void Register_Pooling(OperatorRegistry *op_registry);
 extern void Register_Proposal(OperatorRegistry *op_registry);
 extern void Register_PSROIAlign(OperatorRegistry *op_registry);
 extern void Register_ReOrganize(OperatorRegistry *op_registry);
 extern void Register_Reshape(OperatorRegistry *op_registry);
 extern void Register_ResizeBilinear(OperatorRegistry *op_registry);
+extern void Register_ScalarMath(OperatorRegistry *op_registry);
 extern void Register_Slice(OperatorRegistry *op_registry);
 extern void Register_Softmax(OperatorRegistry *op_registry);
 extern void Register_SpaceToBatchND(OperatorRegistry *op_registry);
@@ -120,12 +122,14 @@ OperatorRegistry::OperatorRegistry() {
   ops::Register_GlobalAvgPooling(this);
   ops::Register_ImageToBuffer(this);
   ops::Register_MatMul(this);
+  ops::Register_Neg(this);
   ops::Register_Pooling(this);
   ops::Register_Proposal(this);
   ops::Register_PSROIAlign(this);
   ops::Register_ReOrganize(this);
   ops::Register_Reshape(this);
   ops::Register_ResizeBilinear(this);
+  ops::Register_ScalarMath(this);
   ops::Register_Slice(this);
   ops::Register_Softmax(this);
   ops::Register_SpaceToBatchND(this);
diff --git a/mace/kernels/eltwise.h b/mace/kernels/eltwise.h
index 15f81a3c..09e3d9ea 100644
--- a/mace/kernels/eltwise.h
+++ b/mace/kernels/eltwise.h
@@ -41,7 +41,7 @@ struct EltwiseFunctor : EltwiseFunctorBase {
                   StatsFuture *future) {
     Tensor::MappingGuard input0_guard(input0);
     Tensor::MappingGuard input1_guard(input1);
-    Tensor::MappingGuard output_guard(output);
+    Tensor::MappingGuard output_guard(output);
 
     const T *input0_ptr = input0->data<T>();
     const T *input1_ptr = input1->data<T>();
@@ -56,12 +56,12 @@ struct EltwiseFunctor : EltwiseFunctorBase {
         }
         break;
       case SUM:
-        if (coeff_.empty()) {
+        if (coeff_.empty()) {
 #pragma omp parallel for
-          for (index_t i = 0; i < size; ++i) {
+          for (index_t i = 0; i < size; ++i) {
             output_ptr[i] = input0_ptr[i] + input1_ptr[i];
           }
-        } else {
+        } else {
 #pragma omp parallel for
           for (index_t i = 0; i < size; ++i) {
             output_ptr[i] =
@@ -69,13 +69,13 @@ struct EltwiseFunctor : EltwiseFunctorBase {
           }
         }
         break;
-      case MAX:
+      case MAX:
 #pragma omp parallel for
         for (index_t i = 0; i < size; ++i) {
           output_ptr[i] = std::max(input0_ptr[i], input1_ptr[i]);
         }
         break;
-      case MIN:
+      case MIN:
 #pragma omp parallel for
         for (index_t i = 0; i < size; ++i) {
           output_ptr[i] = std::min(input0_ptr[i], input1_ptr[i]);
diff --git a/mace/kernels/negative.h b/mace/kernels/negative.h
new file mode 100644
index 00000000..544a1854
--- /dev/null
+++ b/mace/kernels/negative.h
@@ -0,0 +1,69 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_KERNELS_NEGATIVE_H_
+#define MACE_KERNELS_NEGATIVE_H_
+
+#include <vector>
+
+#include "mace/core/future.h"
+#include "mace/core/runtime/opencl/cl2_header.h"
+#include "mace/core/tensor.h"
+#include "mace/public/mace.h"
+
+namespace mace {
+namespace kernels {
+
+template <DeviceType D, typename T>
+struct NegFunctor {
+  void operator()(const Tensor *input,
+                  Tensor *output,
+                  StatsFuture *future) {
+    const index_t batch = input->dim(0);
+    const index_t height = input->dim(1);
+    const index_t width = input->dim(2);
+    const index_t channels = input->dim(3);
+
+    Tensor::MappingGuard input_mapper(input);
+    Tensor::MappingGuard output_mapper(output);
+
+    const T *input_ptr = input->data<T>();
+    T *output_ptr = output->mutable_data<T>();
+
+#pragma omp parallel for collapse(4)
+    for (index_t n = 0; n < batch; ++n) {
+      for (index_t h = 0; h < height; ++h) {
+        for (index_t w = 0; w < width; ++w) {
+          for (index_t c = 0; c < channels; ++c) {
+            index_t pos = (((n * height) + h) * width + w) * channels + c;
+            output_ptr[pos] = 0 - input_ptr[pos];
+          }
+        }
+      }
+    }
+  }
+};
+
+/*
+template <>
+void NegFunctor<DeviceType::NEON, float>::operator()(
+    const Tensor *input,
+    const Tensor *bias,
+    Tensor *output,
+    StatsFuture *future);
+*/
+
+template <typename T>
+struct NegFunctor<DeviceType::OPENCL, T> {
+  void operator()(const Tensor *input,
+                  Tensor *output,
+                  StatsFuture *future);
+  cl::Kernel kernel_;
+  std::vector<index_t> input_shape_;
+};
+
+}  // namespace kernels
+}  // namespace mace
+
+#endif  // MACE_KERNELS_NEGATIVE_H_
diff --git a/mace/kernels/opencl/cl/neg.cl b/mace/kernels/opencl/cl/neg.cl
new file mode 100644
index 00000000..7b539dda
--- /dev/null
+++ b/mace/kernels/opencl/cl/neg.cl
@@ -0,0 +1,14 @@
+#include <common.h>
+// Supported data types: half/float
+__kernel void neg(__read_only image2d_t input,
+                  __write_only image2d_t output) {
+  const int ch_blk = get_global_id(0);
+  const int w = get_global_id(1);
+  const int hb = get_global_id(2);
+  const int width = get_global_size(1);
+
+  const int pos = mad24(ch_blk, width, w);
+  DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
+  DATA_TYPE4 out = 0 - in;
+  WRITE_IMAGET(output, (int2)(pos, hb), out);
+}
diff --git a/mace/kernels/opencl/cl/scalar_math.cl b/mace/kernels/opencl/cl/scalar_math.cl
new file mode 100644
index 00000000..19678b08
--- /dev/null
+++ b/mace/kernels/opencl/cl/scalar_math.cl
@@ -0,0 +1,27 @@
+#include <common.h>
+
+__kernel void scalar_math(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
+                          __private const float scalar,
+                          __write_only image2d_t output) {
+  const int w = get_global_id(0);
+  const int hb = get_global_id(1);
+
+  DATA_TYPE4 in0 = READ_IMAGET(input, SAMPLER, (int2)(w, hb));
+  DATA_TYPE4 in1;
+  in1.x = scalar;
+  in1.y = scalar;
+  in1.z = scalar;
+  in1.w = scalar;
+  DATA_TYPE4 out;
+#if SCALAR_MATH_TYPE == 1
+  out = in0 + in1;
+#elif SCALAR_MATH_TYPE == 4
+  out = in0 - in1;
+#elif SCALAR_MATH_TYPE == 0
+  out = in0 * in1;
+#elif SCALAR_MATH_TYPE == 5
+  out = in0 / in1;
+#endif
+
+  WRITE_IMAGET(output, (int2)(w, hb), out);
+}
diff --git a/mace/kernels/opencl/depth_to_space_opencl.cc b/mace/kernels/opencl/depth_to_space_opencl.cc
index c22a6e4c..6d752642 100644
--- a/mace/kernels/opencl/depth_to_space_opencl.cc
+++ b/mace/kernels/opencl/depth_to_space_opencl.cc
@@ -21,6 +21,7 @@ void DepthToSpaceOpFunctor<DeviceType::OPENCL, T>::operator()(
   const index_t input_depth = input->dim(3);
 
   const char *kernel_name = nullptr;
+  index_t kernel_width = input_width;
 
   index_t output_height, output_width, output_depth;
   if (d2s_) {
@@ -28,11 +29,13 @@ void DepthToSpaceOpFunctor<DeviceType::OPENCL, T>::operator()(
     output_width = input_width * block_size_;
     output_depth = input_depth / (block_size_ * block_size_);
     kernel_name = "depth_to_space";
+    kernel_width = output_width;
   } else {
     output_height = input_height / block_size_;
     output_width = input_width / block_size_;
     output_depth = input_depth * block_size_ * block_size_;
     kernel_name = "space_to_depth";
+    kernel_width = input_width;
   }
   const index_t input_depth_blocks = RoundUpDiv4(input_depth);
   const index_t output_depth_blocks = RoundUpDiv4(output_depth);
diff --git a/mace/kernels/opencl/eltwise_opencl.cc b/mace/kernels/opencl/eltwise_opencl.cc
index c23534bb..132549a3 100644
--- a/mace/kernels/opencl/eltwise_opencl.cc
+++ b/mace/kernels/opencl/eltwise_opencl.cc
@@ -22,7 +22,7 @@ void EltwiseFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input0,
 
   const index_t channel_blocks = RoundUpDiv4(channels);
   const index_t width_pixels = channel_blocks * width;
-  const index_t batch_height_pixels = batch * height;
+  const index_t batch_height_pixels = batch * height;
 
   const uint32_t gws[2] = {static_cast<uint32_t>(width_pixels),
                            static_cast<uint32_t>(batch_height_pixels)};
diff --git a/mace/kernels/opencl/negative_opencl.cc b/mace/kernels/opencl/negative_opencl.cc
new file mode 100644
index 00000000..70f866d8
--- /dev/null
+++ b/mace/kernels/opencl/negative_opencl.cc
@@ -0,0 +1,65 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/kernels/negative.h"
+#include "mace/core/runtime/opencl/cl2_header.h"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
+#include "mace/kernels/opencl/helper.h"
+#include "mace/utils/utils.h"
+
+namespace mace {
+namespace kernels {
+
+template <typename T>
+void NegFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
+                                                   Tensor *output,
+                                                   StatsFuture *future) {
+  const index_t batch = input->dim(0);
+  const index_t height = input->dim(1);
+  const index_t width = input->dim(2);
+  const index_t channels = input->dim(3);
+
+  const index_t channel_blocks = RoundUpDiv4(channels);
+
+  auto runtime = OpenCLRuntime::Global();
+  if (kernel_.get() == nullptr) {
+    std::set<std::string> built_options;
+    auto dt = DataTypeToEnum<T>::value;
+    std::string kernel_name = MACE_OBFUSCATE_SYMBOL("neg");
+    built_options.emplace("-Dneg=" + kernel_name);
+    built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
+    built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
+    kernel_ = runtime->BuildKernel("neg", kernel_name, built_options);
+  }
+  if (!IsVecEqual(input_shape_, input->shape())) {
+    uint32_t idx = 0;
+    kernel_.setArg(idx++, *(input->opencl_image()));
+    kernel_.setArg(idx++, *(output->opencl_image()));
+    input_shape_ = input->shape();
+  }
+
+  const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
+                           static_cast<uint32_t>(width),
+                           static_cast<uint32_t>(height * batch)};
+  const std::vector<uint32_t> lws = {8, 16, 8};
+
+  cl::Event event;
+  cl_int error = runtime->command_queue().enqueueNDRangeKernel(
+      kernel_, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]),
+      cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event);
+  MACE_CHECK(error == CL_SUCCESS);
+  if (future != nullptr) {
+    future->wait_fn = [runtime, event](CallStats *stats) {
+      event.wait();
+      if (stats != nullptr) {
+        runtime->GetCallStats(event, stats);
+      }
+    };
+  }
+}
+
+template struct NegFunctor<DeviceType::OPENCL, float>;
+template struct NegFunctor<DeviceType::OPENCL, half>;
+}  // namespace kernels
+}  // namespace mace
diff --git a/mace/kernels/opencl/scalar_math_opencl.cc b/mace/kernels/opencl/scalar_math_opencl.cc
new file mode 100644
index 00000000..42ad7518
--- /dev/null
+++ b/mace/kernels/opencl/scalar_math_opencl.cc
@@ -0,0 +1,57 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/kernels/scalar_math.h"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
+#include "mace/kernels/opencl/helper.h"
+#include "mace/utils/tuner.h"
+
+namespace mace {
+namespace kernels {
+
+template <typename T>
+void ScalarMathFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
+                                                          Tensor *output,
+                                                          StatsFuture *future) {
+  const index_t batch = input->dim(0);
+  const index_t height = input->dim(1);
+  const index_t width = input->dim(2);
+  const index_t channels = input->dim(3);
+
+  const index_t channel_blocks = RoundUpDiv4(channels);
+  const index_t width_pixels = channel_blocks * width;
+  const index_t batch_height_pixels = batch * height;
+
+  if (kernel_.get() == nullptr) {
+    auto runtime = OpenCLRuntime::Global();
+    std::set<std::string> built_options;
+    auto dt = DataTypeToEnum<T>::value;
+    std::string kernel_name = MACE_OBFUSCATE_SYMBOL("scalar_math");
+    built_options.emplace("-Dscalar_math=" + kernel_name);
+    built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
+    built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
+    built_options.emplace(MakeString("-DSCALAR_MATH_TYPE=", type_));
+    kernel_ = runtime->BuildKernel("scalar_math", kernel_name, built_options);
+  }
+  if (!IsVecEqual(input_shape_, input->shape())) {
+    uint32_t idx = 0;
+    kernel_.setArg(idx++, *(input->opencl_image()));
+    kernel_.setArg(idx++, static_cast<float>(coeff_));
+    kernel_.setArg(idx++, *(output->opencl_image()));
+    input_shape_ = input->shape();
+  }
+
+  const uint32_t gws[2] = {static_cast<uint32_t>(width_pixels),
+                           static_cast<uint32_t>(batch_height_pixels)};
+  const std::vector<uint32_t> lws = {64, 16, 1};
+  std::stringstream ss;
+  ss << "scalar_math_opencl_kernel_" << output->dim(0) << "_" << output->dim(1)
+     << "_" << output->dim(2) << "_" << output->dim(3);
+  TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future);
+}
+
+template struct ScalarMathFunctor<DeviceType::OPENCL, float>;
+template struct ScalarMathFunctor<DeviceType::OPENCL, half>;
+}  // namespace kernels
+}  // namespace mace
diff --git a/mace/kernels/scalar_math.h b/mace/kernels/scalar_math.h
new file mode 100644
index 00000000..75c52867
--- /dev/null
+++ b/mace/kernels/scalar_math.h
@@ -0,0 +1,96 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+#ifndef MACE_KERNELS_SCALAR_MATH_H_
+#define MACE_KERNELS_SCALAR_MATH_H_
+
+#include <algorithm>
+#include <vector>
+
+#include "mace/core/future.h"
+#include "mace/core/runtime/opencl/cl2_header.h"
+#include "mace/core/tensor.h"
+
+namespace mace {
+namespace kernels {
+
+enum ScalarMathType {
+  MUL = 0,
+  ADD = 1,
+  MAX = 2,
+  MIN = 3,
+  SUB = 4,
+  DIV = 5,
+};
+
+struct ScalarMathFunctorBase {
+  ScalarMathFunctorBase(const ScalarMathType type, const float coeff)
+      : type_(type), coeff_(coeff) {}
+
+  ScalarMathType type_;
+  float coeff_;
+};
+
+template <DeviceType D, typename T>
+struct ScalarMathFunctor : ScalarMathFunctorBase {
+  ScalarMathFunctor(const ScalarMathType type, const float coeff)
+      : ScalarMathFunctorBase(type, coeff) {}
+
+  void operator()(const Tensor *input,
+                  Tensor *output,
+                  StatsFuture *future) {
+    Tensor::MappingGuard input_guard(input);
+    Tensor::MappingGuard output_guard(output);
+
+    const T *input_ptr = input->data<T>();
+    T *output_ptr = output->mutable_data<T>();
+    const index_t size = input->size();
+
+    switch (type_) {
+      case MUL:
+#pragma omp parallel for
+        for (index_t i = 0; i < size; ++i) {
+          output_ptr[i] = coeff_ * input_ptr[i];
+        }
+        break;
+      case ADD:
+#pragma omp parallel for
+        for (index_t i = 0; i < size; ++i) {
+          output_ptr[i] = coeff_ + input_ptr[i];
+        }
+        break;
+      case SUB:
+#pragma omp parallel for
+        for (index_t i = 0; i < size; ++i) {
+          output_ptr[i] = input_ptr[i] - coeff_;
+        }
+        break;
+      case DIV:
+#pragma omp parallel for
+        for (index_t i = 0; i < size; ++i) {
+          output_ptr[i] = input_ptr[i] / coeff_;
+        }
+        break;
+      default:
+        LOG(FATAL) << "ScalarMath op does not support type " << type_;
+    }
+  }
+};
+
+template <typename T>
+struct ScalarMathFunctor<DeviceType::OPENCL, T> : ScalarMathFunctorBase {
+  ScalarMathFunctor(const ScalarMathType type, const float coeff)
+      : ScalarMathFunctorBase(type, coeff) {}
+
+  void operator()(const Tensor *input,
+                  Tensor *output,
+                  StatsFuture *future);
+
+  cl::Kernel kernel_;
+  std::vector<index_t> input_shape_;
+};
+
+}  // namespace kernels
+}  // namespace mace
+
+#endif  // MACE_KERNELS_SCALAR_MATH_H_
diff --git a/mace/ops/neg.cc b/mace/ops/neg.cc
new file mode 100644
index 00000000..c4dffe70
--- /dev/null
+++ b/mace/ops/neg.cc
@@ -0,0 +1,31 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/ops/neg.h"
+
+namespace mace {
+namespace ops {
+
+void Register_Neg(OperatorRegistry *op_registry) {
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Neg")
+                                     .Device(DeviceType::CPU)
+                                     .TypeConstraint<float>("T")
+                                     .Build(),
+                    NegOp<DeviceType::CPU, float>);
+
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Neg")
+                                     .Device(DeviceType::OPENCL)
+                                     .TypeConstraint<float>("T")
+                                     .Build(),
+                    NegOp<DeviceType::OPENCL, float>);
+
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("Neg")
+                                     .Device(DeviceType::OPENCL)
+                                     .TypeConstraint<half>("T")
+                                     .Build(),
+                    NegOp<DeviceType::OPENCL, half>);
+}
+
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/neg.h b/mace/ops/neg.h
new file mode 100644
index 00000000..0e3be04c
--- /dev/null
+++ b/mace/ops/neg.h
@@ -0,0 +1,39 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_OPS_NEG_H_
+#define MACE_OPS_NEG_H_
+
+#include <memory>
+
+#include "mace/core/operator.h"
+#include "mace/kernels/negative.h"
+
+namespace mace {
+namespace ops {
+
+template <DeviceType D, typename T>
+class NegOp : public Operator<D, T> {
+ public:
+  NegOp(const OperatorDef &operator_def, Workspace *ws)
+      : Operator<D, T>(operator_def, ws),
+        functor_() {}
+
+  bool Run(StatsFuture *future) override {
+    const Tensor *input_tensor = this->Input(0);
+    Tensor *output_tensor = this->outputs_[0];
+    output_tensor->ResizeLike(input_tensor);
+
+    functor_(input_tensor, output_tensor, future);
+    return true;
+  }
+
+ private:
+  kernels::NegFunctor<D, T> functor_;
+};
+
+}  // namespace ops
+}  // namespace mace
+
+#endif  // MACE_OPS_NEG_H_
diff --git a/mace/ops/neg_benchmark.cc b/mace/ops/neg_benchmark.cc
new file mode 100644
index 00000000..4d8f6cce
--- /dev/null
+++ b/mace/ops/neg_benchmark.cc
@@ -0,0 +1,82 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/core/operator.h"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
+#include "mace/core/testing/test_benchmark.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+template <DeviceType D, typename T>
+static void Neg(int iters, int batch, int channels, int height, int width) {
+  mace::testing::StopTiming();
+
+  OpsTestNet net;
+
+  // Add input data
+  net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
+
+  if (D == DeviceType::OPENCL) {
+    BufferToImage<D, T>(&net, "Input", "InputImage",
+                        kernels::BufferType::IN_OUT_CHANNEL);
+
+    OpDefBuilder("Neg", "NegBM")
+        .Input("InputImage")
+        .Output("Output")
+        .Finalize(net.NewOperatorDef());
+  } else {
+    OpDefBuilder("Neg", "NegBM")
+        .Input("Input")
+        .Output("Output")
+        .Finalize(net.NewOperatorDef());
+  }
+
+  // Warm-up
+  for (int i = 0; i < 5; ++i) {
+    net.RunOp(D);
+  }
+  net.Sync();
+
+  mace::testing::StartTiming();
+  while (iters--) {
+    net.RunOp(D);
+  }
+  net.Sync();
+}
+
+#define BM_NEG_MACRO(N, C, H, W, TYPE, DEVICE)                        \
+  static void BM_NEG_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE(       \
+      int iters) {                                                    \
+    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W;  \
+    mace::testing::MaccProcessed(tot);                                \
+    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));               \
+    Neg<DEVICE, TYPE>(iters, N, C, H, W);                             \
+  }                                                                   \
+  BENCHMARK(BM_NEG_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
+
+#define BM_NEG(N, C, H, W)                 \
+  BM_NEG_MACRO(N, C, H, W, float, CPU);    \
+  BM_NEG_MACRO(N, C, H, W, float, OPENCL); \
+  BM_NEG_MACRO(N, C, H, W, half, OPENCL);
+
+BM_NEG(1, 1, 512, 512);
+BM_NEG(1, 3, 128, 128);
+BM_NEG(1, 3, 512, 512);
+BM_NEG(1, 32, 112, 112);
+BM_NEG(1, 64, 256, 256);
+BM_NEG(1, 64, 512, 512);
+BM_NEG(1, 128, 56, 56);
+BM_NEG(1, 128, 256, 256);
+BM_NEG(1, 256, 14, 14);
+BM_NEG(1, 512, 14, 14);
+BM_NEG(1, 1024, 7, 7);
+BM_NEG(32, 1, 256, 256);
+BM_NEG(32, 3, 256, 256);
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/neg_test.cc b/mace/ops/neg_test.cc
new file mode 100644
index 00000000..c7ae15f6
--- /dev/null
+++ b/mace/ops/neg_test.cc
@@ -0,0 +1,61 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/core/operator.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+class NegOpTest : public OpsTestBase {};
+
+template <DeviceType D>
+void NegSimple() {
+  OpsTestNet net;
+
+  // Add input data
+  net.AddInputFromArray<D, float>("Input", {1, 6, 2, 1},
+                                  {5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
+
+  if (D == DeviceType::OPENCL) {
+    BufferToImage<D, float>(&net, "Input", "InputImage",
+                            kernels::BufferType::IN_OUT_CHANNEL);
+
+    OpDefBuilder("Neg", "NegTest")
+        .Input("InputImage")
+        .Output("OutputImage")
+        .Finalize(net.NewOperatorDef());
+    // Run
+    net.RunOp(D);
+
+    // Transfer output
+    ImageToBuffer<D, float>(&net, "OutputImage", "Output",
+                            kernels::BufferType::IN_OUT_CHANNEL);
+  } else {
+    OpDefBuilder("Neg", "NegTest")
+        .Input("Input")
+        .Output("Output")
+        .Finalize(net.NewOperatorDef());
+    // Run
+    net.RunOp(D);
+  }
+
+  // Check
+  auto expected = CreateTensor<float>(
+      {1, 6, 2, 1},
+      {-5, -5, -7, -7, -9, -9, -11, -11, -13, -13, -15, -15});
+
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-8);
+}
+
+TEST_F(NegOpTest, NegSimpleCPU) { NegSimple<DeviceType::CPU>(); }
+
+TEST_F(NegOpTest, NegSimpleOPENCL) {
+  NegSimple<DeviceType::OPENCL>();
+}
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/scalar_math.cc b/mace/ops/scalar_math.cc
new file mode 100644
index 00000000..9891994f
--- /dev/null
+++ b/mace/ops/scalar_math.cc
@@ -0,0 +1,31 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/ops/scalar_math.h"
+
+namespace mace {
+namespace ops {
+
+void Register_ScalarMath(OperatorRegistry *op_registry) {
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("ScalarMath")
+                                     .Device(DeviceType::CPU)
+                                     .TypeConstraint<float>("T")
+                                     .Build(),
+                    ScalarMathOp<DeviceType::CPU, float>);
+
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("ScalarMath")
+                                     .Device(DeviceType::OPENCL)
+                                     .TypeConstraint<float>("T")
+                                     .Build(),
+                    ScalarMathOp<DeviceType::OPENCL, float>);
+
+  REGISTER_OPERATOR(op_registry, OpKeyBuilder("ScalarMath")
+                                     .Device(DeviceType::OPENCL)
+                                     .TypeConstraint<half>("T")
+                                     .Build(),
+                    ScalarMathOp<DeviceType::OPENCL, half>);
+}
+
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/scalar_math.h b/mace/ops/scalar_math.h
new file mode 100644
index 00000000..2f0f4394
--- /dev/null
+++ b/mace/ops/scalar_math.h
@@ -0,0 +1,49 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_OPS_SCALAR_MATH_H_
+#define MACE_OPS_SCALAR_MATH_H_
+
+#include <memory>
+
+#include "mace/core/operator.h"
+#include "mace/kernels/scalar_math.h"
+
+namespace mace {
+namespace ops {
+
+template <DeviceType D, typename T>
+class ScalarMathOp : public Operator<D, T> {
+ public:
+  ScalarMathOp(const OperatorDef &operator_def, Workspace *ws)
+      : Operator<D, T>(operator_def, ws),
+        x_(OperatorBase::GetSingleArgument<float>("x", 1.0)),
+        functor_(static_cast<kernels::ScalarMathType>(
+                     OperatorBase::GetSingleArgument<int>(
+                         "type", static_cast<int>(
+                             kernels::ScalarMathType::ADD))),
+                 this->x_) {}
+
+  bool Run(StatsFuture *future) override {
+    const Tensor *input_tensor = this->Input(INPUT);
+    Tensor *output_tensor = this->Output(OUTPUT);
+    output_tensor->ResizeLike(input_tensor);
+
+    functor_(input_tensor, output_tensor, future);
+    return true;
+  }
+
+ protected:
+  const float x_;
+  OP_INPUT_TAGS(INPUT);
+  OP_OUTPUT_TAGS(OUTPUT);
+
+ private:
+  kernels::ScalarMathFunctor<D, T> functor_;
+};
+
+}  // namespace ops
+}  // namespace mace
+
+#endif  // MACE_OPS_SCALAR_MATH_H_
diff --git a/mace/ops/scalar_math_benchmark.cc b/mace/ops/scalar_math_benchmark.cc
new file mode 100644
index 00000000..90351326
--- /dev/null
+++ b/mace/ops/scalar_math_benchmark.cc
@@ -0,0 +1,88 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/core/operator.h"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
+#include "mace/core/testing/test_benchmark.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+template <DeviceType D, typename T>
+static void ScalarMath(int iters, int batch, int channels,
+                       int height, int width, float x, int type) {
+  mace::testing::StopTiming();
+
+  OpsTestNet net;
+
+  // Add input data
+  net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
+
+  if (D == DeviceType::OPENCL) {
+    BufferToImage<D, T>(&net, "Input", "InputImage",
+                        kernels::BufferType::IN_OUT_CHANNEL);
+    OpDefBuilder("ScalarMath", "ScalarMathBM")
+        .Input("InputImage")
+        .Output("Output")
+        .AddIntArg("type", type)
+        .AddFloatArg("x", x)
+        .Finalize(net.NewOperatorDef());
+  } else {
+    OpDefBuilder("ScalarMath", "ScalarMathBM")
+        .Input("Input")
+        .Output("Output")
+        .AddIntArg("type", type)
+        .AddFloatArg("x", x)
+        .Finalize(net.NewOperatorDef());
+  }
+
+  // Warm-up
+  for (int i = 0; i < 5; ++i) {
+    net.RunOp(D);
+  }
+  net.Sync();
+
+  mace::testing::StartTiming();
+  while (iters--) {
+    net.RunOp(D);
+  }
+  net.Sync();
+}
+
+#define BM_SCALAR_MATH_MACRO(N, C, H, W, X, G, TYPE, DEVICE)                  \
+  static void                                                                 \
+      BM_SCALAR_MATH_##N##_##C##_##H##_##W##_##X##_##G##_##TYPE##_##DEVICE(   \
+          int iters) {                                                        \
+    const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W;          \
+    mace::testing::MaccProcessed(tot);                                        \
+    mace::testing::BytesProcessed(tot *(sizeof(TYPE)));                       \
+    ScalarMath<DEVICE, TYPE>(iters, N, C, H, W, X, G);                        \
+  }                                                                           \
+  BENCHMARK(                                                                  \
+      BM_SCALAR_MATH_##N##_##C##_##H##_##W##_##X##_##G##_##TYPE##_##DEVICE)
+
+#define BM_SCALAR_MATH(N, C, H, W, X, G)                 \
+  BM_SCALAR_MATH_MACRO(N, C, H, W, X, G, float, CPU);    \
+  BM_SCALAR_MATH_MACRO(N, C, H, W, X, G, float, OPENCL); \
+  BM_SCALAR_MATH_MACRO(N, C, H, W, X, G, half, OPENCL);
+
+BM_SCALAR_MATH(1, 1, 512, 512, 2, 0);
+BM_SCALAR_MATH(1, 3, 128, 128, 2, 1);
+BM_SCALAR_MATH(1, 3, 512, 512, 2, 2);
+BM_SCALAR_MATH(1, 32, 112, 112, 2, 3);
+BM_SCALAR_MATH(1, 64, 256, 256, 3, 0);
+BM_SCALAR_MATH(1, 64, 512, 512, 3, 1);
+BM_SCALAR_MATH(1, 128, 56, 56, 3, 2);
+BM_SCALAR_MATH(1, 128, 256, 256, 3, 3);
+BM_SCALAR_MATH(1, 256, 14, 14, 3, 0);
+BM_SCALAR_MATH(1, 512, 14, 14, 3, 1);
+BM_SCALAR_MATH(1, 1024, 7, 7, 3, 2);
+BM_SCALAR_MATH(32, 1, 256, 256, 3, 3);
+BM_SCALAR_MATH(32, 3, 256, 256, 3, 2);
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
diff --git a/mace/ops/scalar_math_test.cc b/mace/ops/scalar_math_test.cc
new file mode 100644
index 00000000..3da6176d
--- /dev/null
+++ b/mace/ops/scalar_math_test.cc
@@ -0,0 +1,160 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/core/operator.h"
+#include "mace/kernels/scalar_math.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace ops {
+namespace test {
+
+class ScalarMathOpTest : public OpsTestBase {};
+
+template <DeviceType D>
+void Simple(const kernels::ScalarMathType type,
+            const std::vector<index_t> &shape,
+            const std::vector<float> &input0,
+            const float x,
+            const std::vector<float> &output) {
+  // Construct graph
+  OpsTestNet net;
+
+  // Add input data
+  net.AddInputFromArray<D, float>("Input1", shape, input0);
+
+  if (D == DeviceType::CPU) {
+    OpDefBuilder("ScalarMath", "ScalarMathTest")
+        .Input("Input1")
+        .AddIntArg("type", static_cast<int>(type))
+        .AddFloatArg("x", x)
+        .Output("Output")
+        .Finalize(net.NewOperatorDef());
+
+    // Run
+    net.RunOp(D);
+  } else {
+    BufferToImage<D, float>(&net, "Input1", "InputImg1",
+                            kernels::BufferType::IN_OUT_CHANNEL);
+    OpDefBuilder("ScalarMath", "ScalarMathTest")
+        .Input("InputImg1")
+        .AddIntArg("type", static_cast<int>(type))
+        .AddFloatArg("x", x)
+        .Output("OutputImg")
+        .Finalize(net.NewOperatorDef());
+
+    // Run
+    net.RunOp(D);
+
+    ImageToBuffer<D, float>(&net, "OutputImg", "Output",
+                            kernels::BufferType::IN_OUT_CHANNEL);
+  }
+
+  auto expected = CreateTensor<float>(shape, output);
+
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-3);
+}
+
+TEST_F(ScalarMathOpTest, CPUSimple) {
+  Simple<DeviceType::CPU>(kernels::ScalarMathType::MUL, {1, 1, 2, 3},
+                          {1, 2, 3, 4, 5, 6}, 0.1, {0.1, 0.2, .3, .4, .5, .6});
+
+  Simple<DeviceType::CPU>(kernels::ScalarMathType::ADD, {1, 1, 2, 3},
+                          {1, 2, 3, 4, 5, 6}, 2.0, {3, 4, 5, 6, 7, 8});
+
+  Simple<DeviceType::CPU>(kernels::ScalarMathType::DIV, {1, 1, 2, 3},
+                          {1, 2, 3, 4, 5, 6}, 0.1, {10, 20, 30, 40, 50, 60});
+
+  Simple<DeviceType::CPU>(kernels::ScalarMathType::SUB, {1, 1, 2, 3},
+                          {1, 2, 3, 4, 5, 6}, 2.0, {-1, 0, 1, 2, 3, 4});
+}
+
+TEST_F(ScalarMathOpTest, GPUSimple) {
+  Simple<DeviceType::OPENCL>(kernels::ScalarMathType::MUL, {1, 1, 2, 3},
+                             {1, 2, 3, 4, 5, 6}, 0.1,
+                             {0.1, 0.2, .3, .4, .5, .6});
+
+  Simple<DeviceType::OPENCL>(kernels::ScalarMathType::ADD, {1, 1, 2, 3},
+                             {1, 2, 3, 4, 5, 6}, 2.0, {3, 4, 5, 6, 7, 8});
+
+  Simple<DeviceType::OPENCL>(kernels::ScalarMathType::DIV, {1, 1, 2, 3},
+                             {1, 2, 3, 4, 5, 6}, 0.1,
+                             {10, 20, 30, 40, 50, 60});
+
+  Simple<DeviceType::OPENCL>(kernels::ScalarMathType::SUB, {1, 1, 2, 3},
+                             {1, 2, 3, 4, 5, 6}, 2.0, {-1, 0, 1, 2, 3, 4});
+}
+
+template <DeviceType D, typename T>
+void RandomTest(const kernels::ScalarMathType type,
+                const std::vector<index_t> &shape) {
+  testing::internal::LogToStderr();
+  srand(time(NULL));
+
+  // Construct graph
+  OpsTestNet net;
+
+  // Add input data
+  net.AddRandomInput<D, float>("Input1", shape);
+
+  OpDefBuilder("ScalarMath", "ScalarMathTest")
+      .Input("Input1")
+      .AddIntArg("type", static_cast<int>(type))
+      .AddFloatArg("x", 1.2)
+      .Output("Output")
+      .Finalize(net.NewOperatorDef());
+
+  // Run on CPU
+  net.RunOp();
+
+  BufferToImage<D, T>(&net, "Input1", "InputImg1",
+                      kernels::BufferType::IN_OUT_CHANNEL);
+
+  OpDefBuilder("ScalarMath", "ScalarMathTest")
+      .Input("InputImg1")
+      .AddIntArg("type", static_cast<int>(type))
+      .AddFloatArg("x", 1.2)
+      .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+      .Output("OutputImg")
+      .Finalize(net.NewOperatorDef());
+
+  // Run on device
+  net.RunOp(D);
+
+  ImageToBuffer<D, float>(&net, "OutputImg", "OPENCLOutput",
+                          kernels::BufferType::IN_OUT_CHANNEL);
+
+  if (DataTypeToEnum<T>::value == DT_FLOAT) {
+    ExpectTensorNear<float>(*net.GetTensor("Output"),
+                            *net.GetOutput("OPENCLOutput"), 1e-3);
+  } else {
+    ExpectTensorNear<float>(*net.GetTensor("Output"),
+                            *net.GetOutput("OPENCLOutput"), 1e-1);
+  }
+}
+
+TEST_F(ScalarMathOpTest, OPENCLRandomFloat) {
+  RandomTest<DeviceType::OPENCL, float>(kernels::ScalarMathType::MUL,
+                                        {3, 23, 37, 19});
+  RandomTest<DeviceType::OPENCL, float>(kernels::ScalarMathType::ADD,
+                                        {13, 32, 32, 64});
+  RandomTest<DeviceType::OPENCL, float>(kernels::ScalarMathType::SUB,
+                                        {3, 32, 32, 64});
+  RandomTest<DeviceType::OPENCL, float>(kernels::ScalarMathType::DIV,
+                                        {13, 32, 32, 64});
+}
+
+TEST_F(ScalarMathOpTest, OPENCLRandomHalf) {
+  RandomTest<DeviceType::OPENCL, half>(kernels::ScalarMathType::MUL,
+                                       {3, 23, 37, 19});
+  RandomTest<DeviceType::OPENCL, half>(kernels::ScalarMathType::ADD,
+                                       {13, 32, 32, 64});
+  RandomTest<DeviceType::OPENCL, half>(kernels::ScalarMathType::SUB,
+                                       {3, 32, 32, 64});
+  RandomTest<DeviceType::OPENCL, half>(kernels::ScalarMathType::DIV,
+                                       {13, 32, 32, 64});
+}
+
+}  // namespace test
+}  // namespace ops
+}  // namespace mace
-- 
GitLab
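
Usage note (not part of the patch): a minimal sketch of how the new ScalarMath and Neg ops are expected to be driven, reusing the OpsTestNet / OpDefBuilder helpers from "mace/ops/ops_test_util.h" that the tests above already use. The tensor names, the input shape, the 0.5 scalar, and the function name are illustrative only; the integer "type" argument maps to kernels::ScalarMathType (MUL = 0, ADD = 1, SUB = 4, DIV = 5), matching the SCALAR_MATH_TYPE branches in scalar_math.cl.

  #include "mace/kernels/scalar_math.h"
  #include "mace/ops/ops_test_util.h"

  namespace mace {
  namespace ops {
  namespace test {

  // Sketch: computes y = -(0.5 * x) on CPU as a two-op net.
  void ScalarMulThenNegExample() {
    OpsTestNet net;
    // NHWC input filled with random data by the test helper.
    net.AddRandomInput<DeviceType::CPU, float>("Input", {1, 32, 32, 3});

    // ScalarMath: "type" selects kernels::ScalarMathType, "x" is the scalar operand.
    OpDefBuilder("ScalarMath", "ScaleByHalf")
        .Input("Input")
        .Output("Scaled")
        .AddIntArg("type", static_cast<int>(kernels::ScalarMathType::MUL))
        .AddFloatArg("x", 0.5f)
        .Finalize(net.NewOperatorDef());

    // Neg consumes the intermediate tensor produced above.
    OpDefBuilder("Neg", "Negate")
        .Input("Scaled")
        .Output("Output")
        .Finalize(net.NewOperatorDef());

    // Runs the whole net on CPU; passing DeviceType::OPENCL (after the usual
    // BufferToImage conversions, as in the tests above) exercises the GPU kernels.
    net.RunOp();
  }

  }  // namespace test
  }  // namespace ops
  }  // namespace mace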