From 13bca3d11715a2134cb6d690cf49f4fe2726a5eb Mon Sep 17 00:00:00 2001 From: Unknown Date: Thu, 29 Mar 2018 14:16:34 +0800 Subject: [PATCH] optimize code for cwise --- mace/core/operator.cc | 6 +- mace/kernels/{scalar_math.h => cwise.h} | 55 +++++++++---- mace/kernels/eltwise.h | 12 +-- mace/kernels/negative.h | 69 ---------------- mace/kernels/opencl/cl/cwise.cl | 42 ++++++++++ mace/kernels/opencl/cl/depth_to_space.cl | 30 +++++-- mace/kernels/opencl/cl/neg.cl | 14 ---- mace/kernels/opencl/cl/scalar_math.cl | 27 ------ ...{scalar_math_opencl.cc => cwise_opencl.cc} | 18 ++-- mace/kernels/opencl/depth_to_space_opencl.cc | 15 ++-- mace/kernels/opencl/eltwise_opencl.cc | 2 +- mace/kernels/opencl/negative_opencl.cc | 65 --------------- mace/ops/conv_2d.h | 1 - mace/ops/{neg.cc => cwise.cc} | 16 ++-- mace/ops/{scalar_math.h => cwise.h} | 18 ++-- ...r_math_benchmark.cc => cwise_benchmark.cc} | 54 ++++++------ .../{scalar_math_test.cc => cwise_test.cc} | 72 +++++++++------- mace/ops/neg.h | 39 --------- mace/ops/neg_benchmark.cc | 82 ------------------- mace/ops/neg_test.cc | 61 -------------- mace/ops/scalar_math.cc | 31 ------- mace/python/tools/tf_converter_lib.py | 70 ++++++++-------- 22 files changed, 259 insertions(+), 540 deletions(-) rename mace/kernels/{scalar_math.h => cwise.h} (56%) delete mode 100644 mace/kernels/negative.h create mode 100644 mace/kernels/opencl/cl/cwise.cl delete mode 100644 mace/kernels/opencl/cl/neg.cl delete mode 100644 mace/kernels/opencl/cl/scalar_math.cl rename mace/kernels/opencl/{scalar_math_opencl.cc => cwise_opencl.cc} (73%) delete mode 100644 mace/kernels/opencl/negative_opencl.cc rename mace/ops/{neg.cc => cwise.cc} (61%) rename mace/ops/{scalar_math.h => cwise.h} (67%) rename mace/ops/{scalar_math_benchmark.cc => cwise_benchmark.cc} (59%) rename mace/ops/{scalar_math_test.cc => cwise_test.cc} (61%) delete mode 100644 mace/ops/neg.h delete mode 100644 mace/ops/neg_benchmark.cc delete mode 100644 mace/ops/neg_test.cc delete mode 100644 mace/ops/scalar_math.cc diff --git a/mace/core/operator.cc b/mace/core/operator.cc index d98b999f..60eabfda 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -73,6 +73,7 @@ extern void Register_BufferToImage(OperatorRegistry *op_registry); extern void Register_ChannelShuffle(OperatorRegistry *op_registry); extern void Register_Concat(OperatorRegistry *op_registry); extern void Register_Conv2D(OperatorRegistry *op_registry); +extern void Register_CWise(OperatorRegistry *op_registry); extern void Register_DepthToSpace(OperatorRegistry *op_registry); extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry); extern void Register_Eltwise(OperatorRegistry *op_registry); @@ -82,14 +83,12 @@ extern void Register_FusedConv2D(OperatorRegistry *op_registry); extern void Register_GlobalAvgPooling(OperatorRegistry *op_registry); extern void Register_ImageToBuffer(OperatorRegistry *op_registry); extern void Register_MatMul(OperatorRegistry *op_registry); -extern void Register_Neg(OperatorRegistry *op_registry); extern void Register_Pooling(OperatorRegistry *op_registry); extern void Register_Proposal(OperatorRegistry *op_registry); extern void Register_PSROIAlign(OperatorRegistry *op_registry); extern void Register_ReOrganize(OperatorRegistry *op_registry); extern void Register_Reshape(OperatorRegistry *op_registry); extern void Register_ResizeBilinear(OperatorRegistry *op_registry); -extern void Register_ScalarMath(OperatorRegistry *op_registry); extern void Register_Slice(OperatorRegistry *op_registry); extern 
void Register_Softmax(OperatorRegistry *op_registry); extern void Register_SpaceToBatchND(OperatorRegistry *op_registry); @@ -111,6 +110,7 @@ OperatorRegistry::OperatorRegistry() { ops::Register_ChannelShuffle(this); ops::Register_Concat(this); ops::Register_Conv2D(this); + ops::Register_CWise(this); ops::Register_DepthToSpace(this); ops::Register_DepthwiseConv2d(this); ops::Register_Eltwise(this); @@ -120,14 +120,12 @@ OperatorRegistry::OperatorRegistry() { ops::Register_GlobalAvgPooling(this); ops::Register_ImageToBuffer(this); ops::Register_MatMul(this); - ops::Register_Neg(this); ops::Register_Pooling(this); ops::Register_Proposal(this); ops::Register_PSROIAlign(this); ops::Register_ReOrganize(this); ops::Register_Reshape(this); ops::Register_ResizeBilinear(this); - ops::Register_ScalarMath(this); ops::Register_Slice(this); ops::Register_Softmax(this); ops::Register_SpaceToBatchND(this); diff --git a/mace/kernels/scalar_math.h b/mace/kernels/cwise.h similarity index 56% rename from mace/kernels/scalar_math.h rename to mace/kernels/cwise.h index 75c52867..073f5c48 100644 --- a/mace/kernels/scalar_math.h +++ b/mace/kernels/cwise.h @@ -1,8 +1,8 @@ // // Copyright (c) 2017 XiaoMi All rights reserved. // -#ifndef MACE_KERNELS_SCALAR_MATH_H_ -#define MACE_KERNELS_SCALAR_MATH_H_ +#ifndef MACE_KERNELS_CWISE_H_ +#define MACE_KERNELS_CWISE_H_ #include #include @@ -14,27 +14,29 @@ namespace mace { namespace kernels { -enum ScalarMathType { +enum CWiseType { MUL = 0, ADD = 1, MAX = 2, MIN = 3, SUB = 4, DIV = 5, + NEG = 6, + ABS = 7, }; -struct ScalarMathFunctorBase { - ScalarMathFunctorBase(const ScalarMathType type, const float coeff) +struct CWiseFunctorBase { + CWiseFunctorBase(const CWiseType type, const float coeff) : type_(type), coeff_(coeff) {} - ScalarMathType type_; + CWiseType type_; float coeff_; }; template -struct ScalarMathFunctor : ScalarMathFunctorBase { - ScalarMathFunctor(const ScalarMathType type, const float coeff) - : ScalarMathFunctorBase(type, coeff) {} +struct CWiseFunctor : CWiseFunctorBase { + CWiseFunctor(const CWiseType type, const float coeff) + : CWiseFunctorBase(type, coeff) {} void operator()(const Tensor *input, Tensor *output, @@ -59,6 +61,18 @@ struct ScalarMathFunctor : ScalarMathFunctorBase { output_ptr[i] = coeff_ + input_ptr[i]; } break; + case MAX: +#pragma omp parallel for + for (index_t i = 0; i < size; ++i) { + output_ptr[i] = std::max(input_ptr[i], coeff_); + } + break; + case MIN: +#pragma omp parallel for + for (index_t i = 0; i < size; ++i) { + output_ptr[i] = std::min(input_ptr[i], coeff_); + } + break; case SUB: #pragma omp parallel for for (index_t i = 0; i < size; ++i) { @@ -71,16 +85,29 @@ struct ScalarMathFunctor : ScalarMathFunctorBase { output_ptr[i] = input_ptr[i] / coeff_; } break; + case NEG: +#pragma omp parallel for + for (index_t i = 0; i < size; ++i) { + output_ptr[i] = 0 - input_ptr[i]; + } + break; + case ABS: +#pragma omp parallel for + for (index_t i = 0; i < size; ++i) { + T val = input_ptr[i]; + output_ptr[i] = (val > 0)? 
val : 0 - val; + } + break; default: - LOG(FATAL) << "ScalarMath op not support type " << type_; + LOG(FATAL) << "CWise op not support type " << type_; } } }; template -struct ScalarMathFunctor : ScalarMathFunctorBase { - ScalarMathFunctor(const ScalarMathType type, const float coeff) - : ScalarMathFunctorBase(type, coeff) {} +struct CWiseFunctor : CWiseFunctorBase { + CWiseFunctor(const CWiseType type, const float coeff) + : CWiseFunctorBase(type, coeff) {} void operator()(const Tensor *input, Tensor *output, @@ -93,4 +120,4 @@ struct ScalarMathFunctor : ScalarMathFunctorBase { } // namespace kernels } // namespace mace -#endif // MACE_KERNELS_SCALAR_MATH_H_ +#endif // MACE_KERNELS_CWISE_H_ diff --git a/mace/kernels/eltwise.h b/mace/kernels/eltwise.h index 09e3d9ea..15f81a3c 100644 --- a/mace/kernels/eltwise.h +++ b/mace/kernels/eltwise.h @@ -41,7 +41,7 @@ struct EltwiseFunctor : EltwiseFunctorBase { StatsFuture *future) { Tensor::MappingGuard input0_guard(input0); Tensor::MappingGuard input1_guard(input1); - Tensor::MappingGuard output_guard(output); + Tensor::MappingGuard output_guard(output); const T *input0_ptr = input0->data(); const T *input1_ptr = input1->data(); @@ -56,12 +56,12 @@ struct EltwiseFunctor : EltwiseFunctorBase { } break; case SUM: - if (coeff_.empty()) { + if (coeff_.empty()) { #pragma omp parallel for - for (index_t i = 0; i < size; ++i) { + for (index_t i = 0; i < size; ++i) { output_ptr[i] = input0_ptr[i] + input1_ptr[i]; } - } else { + } else { #pragma omp parallel for for (index_t i = 0; i < size; ++i) { output_ptr[i] = @@ -69,13 +69,13 @@ struct EltwiseFunctor : EltwiseFunctorBase { } } break; - case MAX: + case MAX: #pragma omp parallel for for (index_t i = 0; i < size; ++i) { output_ptr[i] = std::max(input0_ptr[i], input1_ptr[i]); } break; - case MIN: + case MIN: #pragma omp parallel for for (index_t i = 0; i < size; ++i) { output_ptr[i] = std::min(input0_ptr[i], input1_ptr[i]); diff --git a/mace/kernels/negative.h b/mace/kernels/negative.h deleted file mode 100644 index 544a1854..00000000 --- a/mace/kernels/negative.h +++ /dev/null @@ -1,69 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. 
-// - -#ifndef MACE_KERNELS_NEGATIVE_H_ -#define MACE_KERNELS_NEGATIVE_H_ - -#include - -#include "mace/core/future.h" -#include "mace/core/runtime/opencl/cl2_header.h" -#include "mace/core/tensor.h" -#include "mace/public/mace.h" - -namespace mace { -namespace kernels { - -template -struct NegFunctor { - void operator()(const Tensor *input, - Tensor *output, - StatsFuture *future) { - const index_t batch = input->dim(0); - const index_t height = input->dim(1); - const index_t width = input->dim(2); - const index_t channels = input->dim(3); - - Tensor::MappingGuard input_mapper(input); - Tensor::MappingGuard output_mapper(output); - - const T *input_ptr = input->data(); - T *output_ptr = output->mutable_data(); - -#pragma omp parallel for collapse(4) - for (index_t n = 0; n < batch; ++n) { - for (index_t h = 0; h < height; ++h) { - for (index_t w = 0; w < width; ++w) { - for (index_t c = 0; c < channels; ++c) { - index_t pos = (((n * height) + h) * width + w) * channels + c; - output_ptr[pos] = 0 - input_ptr[pos]; - } - } - } - } - } -}; - -/* -template <> -void NegFunctor::operator()( - const Tensor *input, - const Tensor *bias, - Tensor *output, - StatsFuture *future); -*/ - -template -struct NegFunctor { - void operator()(const Tensor *input, - Tensor *output, - StatsFuture *future); - cl::Kernel kernel_; - std::vector input_shape_; -}; - -} // namespace kernels -} // namespace mace - -#endif // MACE_KERNELS_NEGATIVE_H_ diff --git a/mace/kernels/opencl/cl/cwise.cl b/mace/kernels/opencl/cl/cwise.cl new file mode 100644 index 00000000..16f1f085 --- /dev/null +++ b/mace/kernels/opencl/cl/cwise.cl @@ -0,0 +1,42 @@ +#include + +__kernel void cwise(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ + __private const float value, + __write_only image2d_t output) { + const int w = get_global_id(0); + const int hb = get_global_id(1); + + DATA_TYPE4 in0 = READ_IMAGET(input, SAMPLER, (int2)(w, hb)); + DATA_TYPE4 in1 = (DATA_TYPE4){value, value, value, value}; + DATA_TYPE4 out; + +#if CWISE_TYPE == 0 + out = in0 * in1; +#elif CWISE_TYPE == 1 + out = in0 + in1; +#elif CWISE_TYPE == 2 + out.x = fmax(in0.x, value); + out.y = fmax(in0.y, value); + out.z = fmax(in0.z, value); + out.w = fmax(in0.w, value); +#elif CWISE_TYPE == 3 + out.x = fmin(in0.x, value); + out.y = fmin(in0.y, value); + out.z = fmin(in0.z, value); + out.w = fmin(in0.w, value); +#elif CWISE_TYPE == 4 + out = in0 - in1; +#elif CWISE_TYPE == 5 + out = in0 / in1; +#elif CWISE_TYPE == 6 + in1 = (DATA_TYPE4)(0, 0, 0, 0); + out = in1 - in0; +#elif CWISE_TYPE == 7 + out.x = fabs(in0.x); + out.y = fabs(in0.y); + out.z = fabs(in0.z); + out.w = fabs(in0.w); +#endif + + WRITE_IMAGET(output, (int2)(w, hb), out); +} diff --git a/mace/kernels/opencl/cl/depth_to_space.cl b/mace/kernels/opencl/cl/depth_to_space.cl index a9e58be5..3be46e1e 100644 --- a/mace/kernels/opencl/cl/depth_to_space.cl +++ b/mace/kernels/opencl/cl/depth_to_space.cl @@ -4,7 +4,12 @@ __kernel void depth_to_space( UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3 __read_only image2d_t input, __private const int block_size, - __private const int output_depth, + __private const int input_height, + __private const int input_width, + __private const int input_depth_blocks, + __private const int output_height, + __private const int output_width, + __private const int output_depth_blocks, __write_only image2d_t output) { const int out_d = get_global_id(0); const int out_w = get_global_id(1); @@ -20,16 +25,21 @@ __kernel void depth_to_space( const int output_width = get_global_size(1); 
#endif + if (out_d >= output_depth_blocks || out_h >= output_height || out_w >= output_width) + return; + const int out_pos = mad24(out_d, output_width, out_w); - const int input_width = output_width / block_size; const int in_h = out_h / block_size; const int offset_h = out_h % block_size; const int in_w = out_w / block_size; const int offset_w = out_w % block_size; - const int offset_d = (offset_h * block_size + offset_w) * output_depth; + const int offset_d = (offset_h * block_size + offset_w) * output_depth_blocks; const int in_d = out_d + offset_d; + if (in_h >= input_height || in_w >= input_width || in_d >= input_depth_blocks) + return; + const int in_pos = mad24(in_d, input_width, in_w); DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, in_h)); WRITE_IMAGET(output, (int2)(out_pos, out_h), in_data); @@ -39,7 +49,12 @@ __kernel void space_to_depth( UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3 __read_only image2d_t input, __private const int block_size, - __private const int input_depth, + __private const int input_height, + __private const int input_width, + __private const int input_depth_blocks, + __private const int output_height, + __private const int output_width, + __private const int output_depth_blocks, __write_only image2d_t output) { const int d = get_global_id(0); @@ -57,14 +72,17 @@ __kernel void space_to_depth( #endif const int in_pos = mad24(d, input_width, w); - const int output_width = input_width / block_size; const int out_h = h / block_size; const int offset_h = h % block_size; const int out_w = w / block_size; const int offset_w = w % block_size; - const int offset_d = (offset_h * block_size + offset_w) * input_depth; + const int offset_d = (offset_h * block_size + offset_w) * input_depth_blocks; const int out_d = d + offset_d; + + if (out_d >= output_depth_blocks || out_h >= output_height || out_w >= output_width) + return; + const int out_pos = mad24(out_d, output_width, out_w); DATA_TYPE4 in_data = READ_IMAGET(input, SAMPLER, (int2)(in_pos, h)); diff --git a/mace/kernels/opencl/cl/neg.cl b/mace/kernels/opencl/cl/neg.cl deleted file mode 100644 index 7b539dda..00000000 --- a/mace/kernels/opencl/cl/neg.cl +++ /dev/null @@ -1,14 +0,0 @@ -#include -// Supported data types: half/float -__kernel void neg(__read_only image2d_t input, - __write_only image2d_t output) { - const int ch_blk = get_global_id(0); - const int w = get_global_id(1); - const int hb = get_global_id(2); - const int width = get_global_size(1); - - const int pos = mad24(ch_blk, width, w); - DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb)); - DATA_TYPE4 out = 0 - in; - WRITE_IMAGET(output, (int2)(pos, hb), out); -} diff --git a/mace/kernels/opencl/cl/scalar_math.cl b/mace/kernels/opencl/cl/scalar_math.cl deleted file mode 100644 index 19678b08..00000000 --- a/mace/kernels/opencl/cl/scalar_math.cl +++ /dev/null @@ -1,27 +0,0 @@ -#include - -__kernel void scalar_math(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ - __private const float scalar, - __write_only image2d_t output) { - const int w = get_global_id(0); - const int hb = get_global_id(1); - - DATA_TYPE4 in0 = READ_IMAGET(input, SAMPLER, (int2)(w, hb)); - DATA_TYPE4 in1; - in1.x = scalar; - in1.y = scalar; - in1.z = scalar; - in1.w = scalar; - DATA_TYPE4 out; -#if SCALAR_MATH_TYPE == 1 - out = in0 + in1; -#elif SCALAR_MATH_TYPE == 4 - out = in0 - in1; -#elif SCALAR_MATH_TYPE == 0 - out = in0 * in1; -#elif SCALAR_MATH_TYPE == 5 - out = in0 / in1; -#endif - - WRITE_IMAGET(output, (int2)(w, hb), out); -} diff --git 
a/mace/kernels/opencl/scalar_math_opencl.cc b/mace/kernels/opencl/cwise_opencl.cc similarity index 73% rename from mace/kernels/opencl/scalar_math_opencl.cc rename to mace/kernels/opencl/cwise_opencl.cc index 42ad7518..325f2f41 100644 --- a/mace/kernels/opencl/scalar_math_opencl.cc +++ b/mace/kernels/opencl/cwise_opencl.cc @@ -2,7 +2,7 @@ // Copyright (c) 2017 XiaoMi All rights reserved. // -#include "mace/kernels/scalar_math.h" +#include "mace/kernels/cwise.h" #include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/kernels/opencl/helper.h" #include "mace/utils/tuner.h" @@ -11,7 +11,7 @@ namespace mace { namespace kernels { template -void ScalarMathFunctor::operator()(const Tensor *input, +void CWiseFunctor::operator()(const Tensor *input, Tensor *output, StatsFuture *future) { const index_t batch = input->dim(0); @@ -27,12 +27,12 @@ void ScalarMathFunctor::operator()(const Tensor *input, auto runtime = OpenCLRuntime::Global(); std::set built_options; auto dt = DataTypeToEnum::value; - std::string kernel_name = MACE_OBFUSCATE_SYMBOL("scalar_math"); - built_options.emplace("-Dscalar_math=" + kernel_name); + std::string kernel_name = MACE_OBFUSCATE_SYMBOL("cwise"); + built_options.emplace("-Dcwise=" + kernel_name); built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); - built_options.emplace(MakeString("-DSCALAR_MATH_TYPE=", type_)); - kernel_ = runtime->BuildKernel("scalar_math", kernel_name, built_options); + built_options.emplace(MakeString("-DCWISE_TYPE=", type_)); + kernel_ = runtime->BuildKernel("cwise", kernel_name, built_options); } if (!IsVecEqual(input_shape_, input->shape())) { uint32_t idx = 0; @@ -46,12 +46,12 @@ void ScalarMathFunctor::operator()(const Tensor *input, static_cast(batch_height_pixels)}; const std::vector lws = {64, 16, 1}; std::stringstream ss; - ss << "eltwise_opencl_kernel_" << output->dim(0) << "_" << output->dim(1) + ss << "cwise_opencl_kernel_" << output->dim(0) << "_" << output->dim(1) << "_" << output->dim(2) << "_" << output->dim(3); TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future); } -template struct ScalarMathFunctor; -template struct ScalarMathFunctor; +template struct CWiseFunctor; +template struct CWiseFunctor; } // namespace kernels } // namespace mace diff --git a/mace/kernels/opencl/depth_to_space_opencl.cc b/mace/kernels/opencl/depth_to_space_opencl.cc index 7d7329f0..c22a6e4c 100644 --- a/mace/kernels/opencl/depth_to_space_opencl.cc +++ b/mace/kernels/opencl/depth_to_space_opencl.cc @@ -20,26 +20,22 @@ void DepthToSpaceOpFunctor::operator()( const index_t input_width = input->dim(2); const index_t input_depth = input->dim(3); - int depth_blocks = 1; const char *kernel_name = nullptr; - index_t kernel_width = input_width; index_t output_height, output_width, output_depth; if (d2s_) { output_height = input_height * block_size_; output_width = input_width * block_size_; output_depth = input_depth / (block_size_ * block_size_); - depth_blocks = RoundUpDiv4(output_depth); kernel_name = "depth_to_space"; - kernel_width = output_width; } else { output_height = input_height / block_size_; output_width = input_width / block_size_; output_depth = input_depth * block_size_ * block_size_; - depth_blocks = RoundUpDiv4(input_depth); kernel_name = "space_to_depth"; - kernel_width = input_width; } + const index_t input_depth_blocks = RoundUpDiv4(input_depth); + const index_t output_depth_blocks = RoundUpDiv4(output_depth); std::vector output_shape = {batch, 
output_height, output_width, output_depth}; @@ -94,7 +90,12 @@ void DepthToSpaceOpFunctor::operator()( } kernel_.setArg(idx++, *(input->opencl_image())); kernel_.setArg(idx++, static_cast(block_size_)); - kernel_.setArg(idx++, static_cast(depth_blocks)); + kernel_.setArg(idx++, static_cast(input_height)); + kernel_.setArg(idx++, static_cast(input_width)); + kernel_.setArg(idx++, static_cast(input_depth_blocks)); + kernel_.setArg(idx++, static_cast(output_height)); + kernel_.setArg(idx++, static_cast(output_width)); + kernel_.setArg(idx++, static_cast(output_depth_blocks)); kernel_.setArg(idx++, *(output->opencl_image())); input_shape_ = input->shape(); diff --git a/mace/kernels/opencl/eltwise_opencl.cc b/mace/kernels/opencl/eltwise_opencl.cc index 132549a3..c23534bb 100644 --- a/mace/kernels/opencl/eltwise_opencl.cc +++ b/mace/kernels/opencl/eltwise_opencl.cc @@ -22,7 +22,7 @@ void EltwiseFunctor::operator()(const Tensor *input0, const index_t channel_blocks = RoundUpDiv4(channels); const index_t width_pixels = channel_blocks * width; - const index_t batch_height_pixels = batch * height; + const index_t batch_height_pixels = batch * height; const uint32_t gws[2] = {static_cast(width_pixels), static_cast(batch_height_pixels)}; diff --git a/mace/kernels/opencl/negative_opencl.cc b/mace/kernels/opencl/negative_opencl.cc deleted file mode 100644 index 70f866d8..00000000 --- a/mace/kernels/opencl/negative_opencl.cc +++ /dev/null @@ -1,65 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// - -#include "mace/kernels/negative.h" -#include "mace/core/runtime/opencl/cl2_header.h" -#include "mace/core/runtime/opencl/opencl_runtime.h" -#include "mace/kernels/opencl/helper.h" -#include "mace/utils/utils.h" - -namespace mace { -namespace kernels { - -template -void NegFunctor::operator()(const Tensor *input, - Tensor *output, - StatsFuture *future) { - const index_t batch = input->dim(0); - const index_t height = input->dim(1); - const index_t width = input->dim(2); - const index_t channels = input->dim(3); - - const index_t channel_blocks = RoundUpDiv4(channels); - - auto runtime = OpenCLRuntime::Global(); - if (kernel_.get() == nullptr) { - std::set built_options; - auto dt = DataTypeToEnum::value; - std::string kernel_name = MACE_OBFUSCATE_SYMBOL("neg"); - built_options.emplace("-Dneg=" + kernel_name); - built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); - built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); - kernel_ = runtime->BuildKernel("neg", kernel_name, built_options); - } - if (!IsVecEqual(input_shape_, input->shape())) { - uint32_t idx = 0; - kernel_.setArg(idx++, *(input->opencl_image())); - kernel_.setArg(idx++, *(output->opencl_image())); - input_shape_ = input->shape(); - } - - const uint32_t gws[3] = {static_cast(channel_blocks), - static_cast(width), - static_cast(height * batch)}; - const std::vector lws = {8, 16, 8}; - - cl::Event event; - cl_int error = runtime->command_queue().enqueueNDRangeKernel( - kernel_, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]), - cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event); - MACE_CHECK(error == CL_SUCCESS); - if (future != nullptr) { - future->wait_fn = [runtime, event](CallStats *stats) { - event.wait(); - if (stats != nullptr) { - runtime->GetCallStats(event, stats); - } - }; - } -} - -template struct NegFunctor; -template struct NegFunctor; -} // namespace kernels -} // namespace mace diff --git a/mace/ops/conv_2d.h b/mace/ops/conv_2d.h index cf58cc9c..528f1e1f 100644 --- a/mace/ops/conv_2d.h 
+++ b/mace/ops/conv_2d.h @@ -31,7 +31,6 @@ class Conv2dOp : public ConvPool2dOpBase { const Tensor *filter = this->Input(FILTER); const Tensor *bias = this->InputSize() >= 3 ? this->Input(BIAS) : nullptr; Tensor *output = this->Output(OUTPUT); - functor_(input, filter, bias, output, future); return true; diff --git a/mace/ops/neg.cc b/mace/ops/cwise.cc similarity index 61% rename from mace/ops/neg.cc rename to mace/ops/cwise.cc index c4dffe70..42439c3e 100644 --- a/mace/ops/neg.cc +++ b/mace/ops/cwise.cc @@ -2,29 +2,29 @@ // Copyright (c) 2017 XiaoMi All rights reserved. // -#include "mace/ops/neg.h" +#include "mace/ops/cwise.h" namespace mace { namespace ops { -void Register_Neg(OperatorRegistry *op_registry) { - REGISTER_OPERATOR(op_registry, OpKeyBuilder("Neg") +void Register_CWise(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("CWise") .Device(DeviceType::CPU) .TypeConstraint("T") .Build(), - NegOp); + CWiseOp); - REGISTER_OPERATOR(op_registry, OpKeyBuilder("Neg") + REGISTER_OPERATOR(op_registry, OpKeyBuilder("CWise") .Device(DeviceType::OPENCL) .TypeConstraint("T") .Build(), - NegOp); + CWiseOp); - REGISTER_OPERATOR(op_registry, OpKeyBuilder("Neg") + REGISTER_OPERATOR(op_registry, OpKeyBuilder("CWise") .Device(DeviceType::OPENCL) .TypeConstraint("T") .Build(), - NegOp); + CWiseOp); } } // namespace ops diff --git a/mace/ops/scalar_math.h b/mace/ops/cwise.h similarity index 67% rename from mace/ops/scalar_math.h rename to mace/ops/cwise.h index 2f0f4394..75430183 100644 --- a/mace/ops/scalar_math.h +++ b/mace/ops/cwise.h @@ -2,27 +2,27 @@ // Copyright (c) 2017 XiaoMi All rights reserved. // -#ifndef MACE_OPS_SCALAR_MATH_H_ -#define MACE_OPS_SCALAR_MATH_H_ +#ifndef MACE_OPS_CWISE_H_ +#define MACE_OPS_CWISE_H_ #include #include "mace/core/operator.h" -#include "mace/kernels/scalar_math.h" +#include "mace/kernels/cwise.h" namespace mace { namespace ops { template -class ScalarMathOp : public Operator { +class CWiseOp : public Operator { public: - ScalarMathOp(const OperatorDef &operator_def, Workspace *ws) + CWiseOp(const OperatorDef &operator_def, Workspace *ws) : Operator(operator_def, ws), x_(OperatorBase::GetSingleArgument("x", 1.0)), - functor_(static_cast( + functor_(static_cast( OperatorBase::GetSingleArgument( "type", static_cast( - kernels::ScalarMathType::ADD))), + kernels::CWiseType::ADD))), this->x_) {} bool Run(StatsFuture *future) override { @@ -40,10 +40,10 @@ class ScalarMathOp : public Operator { OP_OUTPUT_TAGS(OUTPUT); private: - kernels::ScalarMathFunctor functor_; + kernels::CWiseFunctor functor_; }; } // namespace ops } // namespace mace -#endif // MACE_OPS_SCALAR_MATH_H_ +#endif // MACE_OPS_CWISE_H_ diff --git a/mace/ops/scalar_math_benchmark.cc b/mace/ops/cwise_benchmark.cc similarity index 59% rename from mace/ops/scalar_math_benchmark.cc rename to mace/ops/cwise_benchmark.cc index e01d39e0..6ab6aa54 100644 --- a/mace/ops/scalar_math_benchmark.cc +++ b/mace/ops/cwise_benchmark.cc @@ -12,7 +12,7 @@ namespace ops { namespace test { template -static void ScalarMath(int iters, int batch, int channels, +static void CWise(int iters, int batch, int channels, int height, int width, float x, int type) { mace::testing::StopTiming(); @@ -24,14 +24,14 @@ static void ScalarMath(int iters, int batch, int channels, if (D == DeviceType::OPENCL) { BufferToImage(&net, "Input", "InputImage", kernels::BufferType::IN_OUT_CHANNEL); - OpDefBuilder("ScalarMath", "ScalarMathBM") + OpDefBuilder("CWise", "CWiseBM") .Input("InputImage") .Output("Output") 
.AddIntArg("type", type) .AddFloatArg("x", x) .Finalize(net.NewOperatorDef()); } else { - OpDefBuilder("ScalarMath", "ScalarMathBM") + OpDefBuilder("CWise", "CWiseBM") .Input("Input") .Output("Output") .AddIntArg("type", type) @@ -52,35 +52,41 @@ static void ScalarMath(int iters, int batch, int channels, net.Sync(); } -#define BM_SCALAR_MATH_MACRO(N, C, H, W, X, G, TYPE, DEVICE) \ +#define BM_CWISE_MACRO(N, C, H, W, X, G, TYPE, DEVICE) \ static void \ - BM_SCALAR_MATH_##N##_##C##_##H##_##W##_##X##_##G##_##TYPE##_##DEVICE( \ + BM_CWISE_##N##_##C##_##H##_##W##_##X##_##G##_##TYPE##_##DEVICE( \ int iters) { \ const int64_t tot = static_cast(iters) * N * C * H * W; \ mace::testing::MaccProcessed(tot); \ mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ - ScalarMath(iters, N, C, H, W, X, G); \ + CWise(iters, N, C, H, W, X, G); \ } \ - BENCHMARK( \ - BM_SCALAR_MATH_##N##_##C##_##H##_##W##_##X##_##G##_##TYPE##_##DEVICE) + BENCHMARK( \ + BM_CWISE_##N##_##C##_##H##_##W##_##X##_##G##_##TYPE##_##DEVICE) -#define BM_SCALAR_MATH(N, C, H, W, X, G) \ - BM_SCALAR_MATH_MACRO(N, C, H, W, X, G, float, CPU); \ - BM_SCALAR_MATH_MACRO(N, C, H, W, X, G, float, OPENCL); \ - BM_SCALAR_MATH_MACRO(N, C, H, W, X, G, half, OPENCL); +#define BM_CWISE(N, C, H, W, X, G) \ + BM_CWISE_MACRO(N, C, H, W, X, G, float, CPU); \ + BM_CWISE_MACRO(N, C, H, W, X, G, float, OPENCL); \ + BM_CWISE_MACRO(N, C, H, W, X, G, half, OPENCL); -BM_SCALAR_MATH(1, 1, 512, 512, 2, 0); -BM_SCALAR_MATH(1, 3, 128, 128, 2, 1); -BM_SCALAR_MATH(1, 3, 512, 512, 2, 4); -BM_SCALAR_MATH(1, 32, 112, 112, 2, 5); -BM_SCALAR_MATH(1, 64, 256, 256, 3, 0); -BM_SCALAR_MATH(1, 64, 512, 512, 3, 1); -BM_SCALAR_MATH(1, 128, 56, 56, 3, 4); -BM_SCALAR_MATH(1, 128, 256, 256, 3, 5); -BM_SCALAR_MATH(1, 256, 14, 14, 3, 0); -BM_SCALAR_MATH(1, 512, 14, 14, 3, 1); -BM_SCALAR_MATH(1, 1024, 7, 7, 3, 4); -BM_SCALAR_MATH(32, 1, 256, 256, 3, 5); +BM_CWISE(1, 1, 512, 512, 2, 0); +BM_CWISE(1, 3, 128, 128, 2, 1); +BM_CWISE(1, 3, 512, 512, 2, 4); +BM_CWISE(1, 32, 112, 112, 2, 5); +BM_CWISE(1, 32, 112, 112, 2, 6); +BM_CWISE(1, 32, 112, 112, 2, 7); +BM_CWISE(1, 64, 256, 256, 3, 0); +BM_CWISE(1, 64, 512, 512, 3, 1); +BM_CWISE(1, 128, 56, 56, 3, 4); +BM_CWISE(1, 128, 256, 256, 3, 5); +BM_CWISE(1, 64, 512, 512, 3, 6); +BM_CWISE(1, 64, 512, 512, 3, 7); +BM_CWISE(1, 256, 14, 14, 3, 0); +BM_CWISE(1, 512, 14, 14, 3, 1); +BM_CWISE(1, 1024, 7, 7, 3, 4); +BM_CWISE(32, 1, 256, 256, 3, 5); +BM_CWISE(32, 1, 256, 256, 3, 6); +BM_CWISE(32, 1, 256, 256, 3, 7); } // namespace test } // namespace ops diff --git a/mace/ops/scalar_math_test.cc b/mace/ops/cwise_test.cc similarity index 61% rename from mace/ops/scalar_math_test.cc rename to mace/ops/cwise_test.cc index 3da6176d..7bd934f8 100644 --- a/mace/ops/scalar_math_test.cc +++ b/mace/ops/cwise_test.cc @@ -4,17 +4,17 @@ #include "mace/core/operator.h" #include "mace/ops/ops_test_util.h" -#include "../kernels/scalar_math.h" +#include "../kernels/cwise.h" namespace mace { namespace ops { namespace test { -class ScalarMathOpTest : public OpsTestBase {}; +class CWiseOpTest : public OpsTestBase {}; template -void Simple(const kernels::ScalarMathType type, +void Simple(const kernels::CWiseType type, const std::vector &shape, const std::vector &input0, const float x, @@ -26,7 +26,7 @@ void Simple(const kernels::ScalarMathType type, net.AddInputFromArray("Input1", shape, input0); if (D == DeviceType::CPU) { - OpDefBuilder("ScalarMath", "ScalarMathTest") + OpDefBuilder("CWise", "CWiseTest") .Input("Input1") .AddIntArg("type", static_cast(type)) 
.AddFloatArg("x", x) @@ -38,7 +38,7 @@ void Simple(const kernels::ScalarMathType type, } else { BufferToImage(&net, "Input1", "InputImg1", kernels::BufferType::IN_OUT_CHANNEL); - OpDefBuilder("ScalarMath", "ScalarMathTest") + OpDefBuilder("CWise", "CWiseTest") .Input("InputImg1") .AddIntArg("type", static_cast(type)) .AddFloatArg("x", x) @@ -57,36 +57,48 @@ void Simple(const kernels::ScalarMathType type, ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-3); } -TEST_F(ScalarMathOpTest, CPUSimple) { - Simple(kernels::ScalarMathType::MUL, {1, 1, 2, 3}, +TEST_F(CWiseOpTest, CPUSimple) { + Simple(kernels::CWiseType::MUL, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 0.1, {0.1, 0.2, .3, .4, .5, .6}); - Simple(kernels::ScalarMathType::ADD, {1, 1, 2, 3}, + Simple(kernels::CWiseType::ADD, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 2.0, {3, 4, 5, 6, 7, 8}); - Simple(kernels::ScalarMathType::DIV, {1, 1, 2, 3}, + Simple(kernels::CWiseType::DIV, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 0.1, {10, 20, 30, 40, 50, 60}); - Simple(kernels::ScalarMathType::SUB, {1, 1, 2, 3}, + Simple(kernels::CWiseType::SUB, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 2.0, {-1, 0, 1, 2, 3, 4}); + + Simple(kernels::CWiseType::NEG, {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, 2.0, {-1, -2, -3, -4, -5, -6}); + + Simple(kernels::CWiseType::ABS, {1, 1, 2, 3}, + {1, -2, -0.0001, 4, 5, 6}, 2.0, {1, 2, 0.0001, 4, 5, 6}); } -TEST_F(ScalarMathOpTest, GPUSimple) { - Simple(kernels::ScalarMathType::MUL, {1, 1, 2, 3}, +TEST_F(CWiseOpTest, GPUSimple) { + Simple(kernels::CWiseType::MUL, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 0.1, {0.1, 0.2, .3, .4, .5, .6}); - Simple(kernels::ScalarMathType::ADD, {1, 1, 2, 3}, + Simple(kernels::CWiseType::ADD, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 2.0, {3, 4, 5, 6, 7, 8}); - Simple(kernels::ScalarMathType::DIV, {1, 1, 2, 3}, + Simple(kernels::CWiseType::DIV, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 0.1, {10, 20, 30, 40, 50, 60}); - Simple(kernels::ScalarMathType::SUB, {1, 1, 2, 3}, + Simple(kernels::CWiseType::SUB, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 2.0, {-1, 0, 1, 2, 3, 4}); + + Simple(kernels::CWiseType::NEG, {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, 2.0, {-1, -2, -3, -4, -5, -6}); + + Simple(kernels::CWiseType::ABS, {1, 1, 2, 3}, + {1, -2, -0.0001, 4, 5, 6}, 2.0, {1, 2, 0.0001, 4, 5, 6}); } template -void RandomTest(const kernels::ScalarMathType type, +void RandomTest(const kernels::CWiseType type, const std::vector &shape) { testing::internal::LogToStderr(); srand(time(NULL)); @@ -97,7 +109,7 @@ void RandomTest(const kernels::ScalarMathType type, // Add input data net.AddRandomInput("Input1", shape); - OpDefBuilder("ScalarMath", "ScalarMathTest") + OpDefBuilder("CWise", "CWiseTest") .Input("Input1") .AddIntArg("type", static_cast(type)) .AddFloatArg("x", 1.2) @@ -110,7 +122,7 @@ void RandomTest(const kernels::ScalarMathType type, BufferToImage(&net, "Input1", "InputImg1", kernels::BufferType::IN_OUT_CHANNEL); - OpDefBuilder("ScalarMath", "ScalarMathTest") + OpDefBuilder("CWise", "CWiseTest") .Input("InputImg1") .AddIntArg("type", static_cast(type)) .AddFloatArg("x", 1.2) @@ -133,25 +145,29 @@ void RandomTest(const kernels::ScalarMathType type, } } -TEST_F(ScalarMathOpTest, OPENCLRandomFloat) { - RandomTest(kernels::ScalarMathType::MUL, +TEST_F(CWiseOpTest, OPENCLRandomFloat) { + RandomTest(kernels::CWiseType::MUL, {3, 23, 37, 19}); - RandomTest(kernels::ScalarMathType::ADD, + RandomTest(kernels::CWiseType::ADD, {13, 32, 32, 64}); - RandomTest(kernels::ScalarMathType::SUB, + RandomTest(kernels::CWiseType::SUB, {3, 32, 32, 64}); - 
RandomTest(kernels::ScalarMathType::DIV, + RandomTest(kernels::CWiseType::DIV, + {13, 32, 32, 64}); + RandomTest(kernels::CWiseType::NEG, {13, 32, 32, 64}); } -TEST_F(ScalarMathOpTest, OPENCLRandomHalf) { - RandomTest(kernels::ScalarMathType::MUL, +TEST_F(CWiseOpTest, OPENCLRandomHalf) { + RandomTest(kernels::CWiseType::MUL, {3, 23, 37, 19}); - RandomTest(kernels::ScalarMathType::ADD, + RandomTest(kernels::CWiseType::ADD, {13, 32, 32, 64}); - RandomTest(kernels::ScalarMathType::SUB, + RandomTest(kernels::CWiseType::SUB, {3, 32, 32, 64}); - RandomTest(kernels::ScalarMathType::DIV, + RandomTest(kernels::CWiseType::DIV, + {13, 32, 32, 64}); + RandomTest(kernels::CWiseType::NEG, {13, 32, 32, 64}); } diff --git a/mace/ops/neg.h b/mace/ops/neg.h deleted file mode 100644 index 0e3be04c..00000000 --- a/mace/ops/neg.h +++ /dev/null @@ -1,39 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// - -#ifndef MACE_OPS_NEG_H_ -#define MACE_OPS_NEG_H_ - -#include - -#include "mace/core/operator.h" -#include "mace/kernels/negative.h" - -namespace mace { -namespace ops { - -template -class NegOp : public Operator { - public: - NegOp(const OperatorDef &operator_def, Workspace *ws) - : Operator(operator_def, ws), - functor_() {} - - bool Run(StatsFuture *future) override { - const Tensor *input_tensor = this->Input(0); - Tensor *output_tensor = this->outputs_[0]; - output_tensor->ResizeLike(input_tensor); - - functor_(input_tensor, output_tensor, future); - return true; - } - - private: - kernels::NegFunctor functor_; -}; - -} // namespace ops -} // namespace mace - -#endif // MACE_OPS_NEGATIVE_H_ diff --git a/mace/ops/neg_benchmark.cc b/mace/ops/neg_benchmark.cc deleted file mode 100644 index 4d8f6cce..00000000 --- a/mace/ops/neg_benchmark.cc +++ /dev/null @@ -1,82 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. 
-// - -#include "mace/core/operator.h" -#include "mace/core/runtime/opencl/opencl_runtime.h" -#include "mace/core/testing/test_benchmark.h" -#include "mace/ops/ops_test_util.h" - -namespace mace { -namespace ops { -namespace test { - -template -static void Neg(int iters, int batch, int channels, int height, int width) { - mace::testing::StopTiming(); - - OpsTestNet net; - - // Add input data - net.AddRandomInput("Input", {batch, height, width, channels}); - - if (D == DeviceType::OPENCL) { - BufferToImage(&net, "Input", "InputImage", - kernels::BufferType::IN_OUT_CHANNEL); - - OpDefBuilder("Neg", "NegBM") - .Input("InputImage") - .Output("Output") - .Finalize(net.NewOperatorDef()); - } else { - OpDefBuilder("Neg", "NegBM") - .Input("Input") - .Output("Output") - .Finalize(net.NewOperatorDef()); - } - - // Warm-up - for (int i = 0; i < 5; ++i) { - net.RunOp(D); - } - net.Sync(); - - mace::testing::StartTiming(); - while (iters--) { - net.RunOp(D); - } - net.Sync(); -} - -#define BM_NEG_MACRO(N, C, H, W, TYPE, DEVICE) \ - static void BM_NEG_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \ - int iters) { \ - const int64_t tot = static_cast(iters) * N * C * H * W; \ - mace::testing::MaccProcessed(tot); \ - mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ - Neg(iters, N, C, H, W); \ - } \ - BENCHMARK(BM_NEG_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) - -#define BM_NEG(N, C, H, W) \ - BM_NEG_MACRO(N, C, H, W, float, CPU); \ - BM_NEG_MACRO(N, C, H, W, float, OPENCL); \ - BM_NEG_MACRO(N, C, H, W, half, OPENCL); - -BM_NEG(1, 1, 512, 512); -BM_NEG(1, 3, 128, 128); -BM_NEG(1, 3, 512, 512); -BM_NEG(1, 32, 112, 112); -BM_NEG(1, 64, 256, 256); -BM_NEG(1, 64, 512, 512); -BM_NEG(1, 128, 56, 56); -BM_NEG(1, 128, 256, 256); -BM_NEG(1, 256, 14, 14); -BM_NEG(1, 512, 14, 14); -BM_NEG(1, 1024, 7, 7); -BM_NEG(32, 1, 256, 256); -BM_NEG(32, 3, 256, 256); - -} // namespace test -} // namespace ops -} // namespace mace diff --git a/mace/ops/neg_test.cc b/mace/ops/neg_test.cc deleted file mode 100644 index c7ae15f6..00000000 --- a/mace/ops/neg_test.cc +++ /dev/null @@ -1,61 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. 
-// - -#include "mace/core/operator.h" -#include "mace/ops/ops_test_util.h" - -namespace mace { -namespace ops { -namespace test { - -class NegOpTest : public OpsTestBase {}; - -template -void NegSimple() { - OpsTestNet net; - - // Add input data - net.AddInputFromArray("Input", {1, 6, 2, 1}, - {5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15}); - - if (D == DeviceType::OPENCL) { - BufferToImage(&net, "Input", "InputImage", - kernels::BufferType::IN_OUT_CHANNEL); - - OpDefBuilder("Neg", "NegTest") - .Input("InputImage") - .Output("OutputImage") - .Finalize(net.NewOperatorDef()); - // Run - net.RunOp(D); - - // Transfer output - ImageToBuffer(&net, "OutputImage", "Output", - kernels::BufferType::IN_OUT_CHANNEL); - } else { - OpDefBuilder("Neg", "NegTest") - .Input("Input") - .Output("Output") - .Finalize(net.NewOperatorDef()); - // Run - net.RunOp(D); - } - - // Check - auto expected = CreateTensor( - {1, 6, 2, 1}, - {-5, -5, -7, -7, -9, -9, -11, -11, -13, -13, -15, -15}); - - ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-8); -} - -TEST_F(NegOpTest, NegSimpleCPU) { NegSimple(); } - -TEST_F(NegOpTest, NegSimpleOPENCL) { - NegSimple(); -} - -} // namespace test -} // namespace ops -} // namespace mace diff --git a/mace/ops/scalar_math.cc b/mace/ops/scalar_math.cc deleted file mode 100644 index 9891994f..00000000 --- a/mace/ops/scalar_math.cc +++ /dev/null @@ -1,31 +0,0 @@ -// -// Copyright (c) 2017 XiaoMi All rights reserved. -// - -#include "mace/ops/scalar_math.h" - -namespace mace { -namespace ops { - -void Register_ScalarMath(OperatorRegistry *op_registry) { - REGISTER_OPERATOR(op_registry, OpKeyBuilder("ScalarMath") - .Device(DeviceType::CPU) - .TypeConstraint("T") - .Build(), - ScalarMathOp); - - REGISTER_OPERATOR(op_registry, OpKeyBuilder("ScalarMath") - .Device(DeviceType::OPENCL) - .TypeConstraint("T") - .Build(), - ScalarMathOp); - - REGISTER_OPERATOR(op_registry, OpKeyBuilder("ScalarMath") - .Device(DeviceType::OPENCL) - .TypeConstraint("T") - .Build(), - ScalarMathOp); -} - -} // namespace ops -} // namespace mace diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py index 01a4a540..93fbea92 100644 --- a/mace/python/tools/tf_converter_lib.py +++ b/mace/python/tools/tf_converter_lib.py @@ -19,14 +19,19 @@ pooling_type_mode = { 'MaxPool': 2 } -# the order should be the same as eltwise type's order +# the order must match the eltwise type order defined +# in mace/kernels/eltwise.h and the CWiseType order +# in mace/kernels/cwise.h, because these integer codes +# are passed straight through to the "Eltwise" and "CWise" ops math_type_mode = { 'MUL': 0, 'ADD': 1, 'MAX': 2, 'MIN': 3, 'SUB': 4, - 'DIV': 5 + 'DIV': 5, + 'NEG': 6, + 'ABS': 7 } buffer_type_map = { @@ -632,18 +637,6 @@ class TFConverter(object): self.add_output_shape(op.outputs, op_def) self.resolved_ops[op.name] = 1 self.unused_tensor.add(get_input_tensor(op, 1).name) - - def convert_neg(self, op): - op_def = self.net_def.op.add() - arg = op_def.arg.add() - arg.name = 'T' - arg.i = self.dt - op_def.name = op.name - op_def.type = "Neg" - op_def.input.extend([input.name for input in op.inputs]) - op_def.output.extend([output.name for output in op.outputs]) - self.add_output_shape(op.outputs, op_def) - self.resolved_ops[op.name] = 1 def convert_math(self, op, math_type): op_def = self.net_def.op.add() arg = op_def.arg.add() arg.name = 'T' arg.i = self.dt op_def.name = op.name - input_tensor0 = get_input_tensor(op, 0) - input_tensor1 = get_input_tensor(op, 1) - - if input_tensor0.shape == 
input_tensor1.shape: - op_def.type = "Eltwise" + + if len(op.inputs) == 1: + op_def.type = "CWise" op_def.input.extend([input.name for input in op.inputs]) - else: - op_def.type = "ScalarMath" - x_value = 0 - if len(input_tensor1.shape)==4: - op_def.input.extend([op.inputs[1].name]) - x_value = get_input_tensor(op, 0).eval().astype(np.float32) - else: - op_def.input.extend([op.inputs[0].name]) - x_value = get_input_tensor(op, 1).eval().astype(np.float32) x_arg = op_def.arg.add() x_arg.name = 'x' - x_arg.f = x_value + x_arg.f = 0 + elif len(op.inputs) >= 2: + input_tensor0 = get_input_tensor(op, 0) + input_tensor1 = get_input_tensor(op, 1) + if input_tensor0.shape == input_tensor1.shape: + op_def.type = "Eltwise" + op_def.input.extend([input.name for input in op.inputs]) + else: + op_def.type = "CWise" + x_value = 0 + if len(input_tensor1.shape)==4: + op_def.input.extend([op.inputs[1].name]) + x_value = get_input_tensor(op, 0).eval().astype(np.float32) + else: + op_def.input.extend([op.inputs[0].name]) + x_value = get_input_tensor(op, 1).eval().astype(np.float32) + x_arg = op_def.arg.add() + x_arg.name = 'x' + x_arg.f = x_value type_arg = op_def.arg.add() type_arg.name = 'type' type_arg.i = math_type_mode[math_type] @@ -919,15 +919,15 @@ class TFConverter(object): elif op.type == 'BatchToSpaceND': self.convert_space_to_batch(op, True) elif op.type == 'DepthToSpace': - self.convert_depth_to_space(op, True) + self.convert_depth_to_space(op, True) elif op.type == 'SpaceToDepth': - self.convert_depth_to_space(op, False) - elif op.type == 'Neg': - self.convert_neg(op) + self.convert_depth_to_space(op, False) + elif op.type in ['Neg', 'neg', 'Negative', 'negative']: + self.convert_math(op, 'NEG') elif op.type == 'Mul': - self.convert_math(op, 'MUL') + self.convert_math(op, 'MUL') elif op.type == 'Sub': - self.convert_math(op, 'SUB') + self.convert_math(op, 'SUB') elif self.is_softmax(op): self.convert_softmax(op) elif op.type in ['Relu', 'Sigmoid', 'Tanh']: -- GitLab
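
Note on the converter change above: convert_math now routes a TensorFlow math op either to the binary "Eltwise" op (two inputs of identical shape) or to the new "CWise" op (a single input such as Neg/Abs, or a tensor combined with a scalar constant, with the scalar folded into the 'x' argument). The sketch below is a minimal, standalone restatement of that selection rule for illustration only; the function name select_math_op and the example shapes are hypothetical and do not appear in tf_converter_lib.py.

def select_math_op(shape0, shape1):
    # Mirrors the rule added to convert_math: unary ops and tensor-scalar
    # ops become "CWise" (scalar carried in the 'x' arg), while two
    # same-shaped tensors stay on the "Eltwise" kernel.
    if shape1 is None:          # unary op, e.g. Neg or Abs
        return 'CWise'
    if shape0 == shape1:        # two tensors of identical shape
        return 'Eltwise'
    return 'CWise'              # tensor-scalar case

# Examples with hypothetical shapes:
assert select_math_op((1, 32, 32, 16), (1, 32, 32, 16)) == 'Eltwise'  # a + b
assert select_math_op((1, 32, 32, 16), ()) == 'CWise'                 # x * 2.0
assert select_math_op((1, 32, 32, 16), None) == 'CWise'               # -x

In both cases the 'type' argument carries the math_type_mode code, which is why that table has to stay aligned with the type enums in mace/kernels/eltwise.h and mace/kernels/cwise.h.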