diff --git a/mace/core/operator.cc b/mace/core/operator.cc index 68de7908587c9324df2b38848145f8fe63fcc152..ad5cccf2a013df7b383cf5849cbaf2c92efe1cbc 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -83,7 +83,6 @@ extern void Register_BufferToImage(OperatorRegistry *op_registry); extern void Register_ChannelShuffle(OperatorRegistry *op_registry); extern void Register_Concat(OperatorRegistry *op_registry); extern void Register_Conv2D(OperatorRegistry *op_registry); -extern void Register_CWise(OperatorRegistry *op_registry); extern void Register_DepthToSpace(OperatorRegistry *op_registry); extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry); extern void Register_Dequantize(OperatorRegistry *op_registry); @@ -123,7 +122,6 @@ OperatorRegistry::OperatorRegistry() { ops::Register_ChannelShuffle(this); ops::Register_Concat(this); ops::Register_Conv2D(this); - ops::Register_CWise(this); ops::Register_DepthToSpace(this); ops::Register_DepthwiseConv2d(this); ops::Register_Dequantize(this); diff --git a/mace/kernels/cwise.h b/mace/kernels/cwise.h deleted file mode 100644 index 997410a36829b86a95b21451f1846892a28c27b8..0000000000000000000000000000000000000000 --- a/mace/kernels/cwise.h +++ /dev/null @@ -1,137 +0,0 @@ -// Copyright 2018 Xiaomi, Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#ifndef MACE_KERNELS_CWISE_H_ -#define MACE_KERNELS_CWISE_H_ - -#include -#include -#include - -#include "mace/core/future.h" -#include "mace/core/runtime/opencl/cl2_header.h" -#include "mace/core/tensor.h" - -namespace mace { -namespace kernels { - -enum CWiseType { - MUL = 0, - ADD = 1, - MAX = 2, - MIN = 3, - SUB = 4, - DIV = 5, - NEG = 6, - ABS = 7, -}; - -struct CWiseFunctorBase { - CWiseFunctorBase(const CWiseType type, const float coeff) - : type_(type), coeff_(coeff) {} - - CWiseType type_; - float coeff_; -}; - -template -struct CWiseFunctor : CWiseFunctorBase { - CWiseFunctor(const CWiseType type, const float coeff) - : CWiseFunctorBase(type, coeff) {} - - void operator()(const Tensor *input, - Tensor *output, - StatsFuture *future) { - Tensor::MappingGuard input_guard(input); - Tensor::MappingGuard output_guard(output); - - const T *input_ptr = input->data(); - T *output_ptr = output->mutable_data(); - const index_t size = input->size(); - - switch (type_) { - case MUL: -#pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - output_ptr[i] = coeff_ * input_ptr[i]; - } - break; - case ADD: -#pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - output_ptr[i] = coeff_ + input_ptr[i]; - } - break; - case MAX: -#pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - output_ptr[i] = std::max(input_ptr[i], coeff_); - } - break; - case MIN: -#pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - output_ptr[i] = std::min(input_ptr[i], coeff_); - } - break; - case SUB: -#pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - output_ptr[i] = input_ptr[i] - coeff_; - } - break; - case DIV: -#pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - output_ptr[i] = input_ptr[i] / coeff_; - } - break; - case NEG: -#pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - output_ptr[i] = 0 - input_ptr[i]; - } - break; - case ABS: -#pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - T val = input_ptr[i]; - output_ptr[i] = (val > 0)? 
val : 0 - val; - } - break; - default: - LOG(FATAL) << "CWise op not support type " << type_; - } - } -}; - -template -struct CWiseFunctor : CWiseFunctorBase { - CWiseFunctor(const CWiseType type, const float coeff) - : CWiseFunctorBase(type, coeff) {} - - void operator()(const Tensor *input, - Tensor *output, - StatsFuture *future); - - cl::Kernel kernel_; - uint32_t kwg_size_; - std::unique_ptr kernel_error_; - std::vector input_shape_; -}; - -} // namespace kernels -} // namespace mace - -#endif // MACE_KERNELS_CWISE_H_ diff --git a/mace/kernels/eltwise.h b/mace/kernels/eltwise.h index 94c37cdb4ae1cedde48a6504331ecead3c01caee..672e03ddedfc46811d8e317b5accd0354af75b7b 100644 --- a/mace/kernels/eltwise.h +++ b/mace/kernels/eltwise.h @@ -32,10 +32,15 @@ enum EltwiseType { MAX = 2, MIN = 3, SUB = 4, + DIV = 5, + NEG = 6, + ABS = 7, + SQR_DIFF = 8, }; struct EltwiseFunctorBase { - EltwiseFunctorBase(const EltwiseType type, const std::vector &coeff) + EltwiseFunctorBase(const EltwiseType type, + const std::vector &coeff) : type_(type), coeff_(coeff) {} EltwiseType type_; @@ -44,74 +49,211 @@ struct EltwiseFunctorBase { template struct EltwiseFunctor : EltwiseFunctorBase { - EltwiseFunctor(const EltwiseType type, const std::vector &coeff) + EltwiseFunctor(const EltwiseType type, + const std::vector &coeff) : EltwiseFunctorBase(type, coeff) {} void operator()(const Tensor *input0, const Tensor *input1, + const index_t start_axis, + const bool is_scaler, + const float value, + const bool swap, Tensor *output, StatsFuture *future) { - Tensor::MappingGuard input0_guard(input0); - Tensor::MappingGuard input1_guard(input1); - Tensor::MappingGuard output_guard(output); + if (is_scaler) { + Tensor::MappingGuard input0_guard(input0); + Tensor::MappingGuard output_guard(output); - const T *input0_ptr = input0->data(); - const T *input1_ptr = input1->data(); - T *output_ptr = output->mutable_data(); - const index_t size = input0->size(); - - switch (type_) { - case PROD: + const T *input0_ptr = input0->data(); + T *output_ptr = output->mutable_data(); + const index_t num = input0->size(); + switch (type_) { + case PROD: +#pragma omp parallel for + for (index_t i = 0; i < num; ++i) { + output_ptr[i] = input0_ptr[i] * value; + } + break; + case SUM: + if (coeff_.empty()) { #pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - output_ptr[i] = input0_ptr[i] * input1_ptr[i]; - } - break; - case SUM: - if (coeff_.empty()) { + for (index_t i = 0; i < num; ++i) { + output_ptr[i] = input0_ptr[i] + value; + } + } else { + const float coeff_0 = swap ? coeff_[1] : coeff_[0]; + const float coeff_1 = swap ? 
coeff_[0] : coeff_[1]; #pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - output_ptr[i] = input0_ptr[i] + input1_ptr[i]; + for (index_t i = 0; i < num; ++i) { + output_ptr[i] = coeff_0 * input0_ptr[i] + + coeff_1 * value; + } } - } else { + break; + case MAX: #pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - output_ptr[i] = - coeff_[0] * input0_ptr[i] + coeff_[1] * input1_ptr[i]; + for (index_t i = 0; i < num; ++i) { + output_ptr[i] = std::max(input0_ptr[i], value); } - } - break; - case MAX: + break; + case MIN: #pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - output_ptr[i] = std::max(input0_ptr[i], input1_ptr[i]); - } - break; - case MIN: + for (index_t i = 0; i < num; ++i) { + output_ptr[i] = std::min(input0_ptr[i], value); + } + break; + case SUB: #pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - output_ptr[i] = std::min(input0_ptr[i], input1_ptr[i]); - } - break; - case SUB: + for (index_t i = 0; i < num; ++i) { + output_ptr[i] = swap ? value - input0_ptr[i] : + input0_ptr[i] - value; + } + break; + case DIV: + if (!swap) { + MACE_CHECK(fabs(value) > 1e-6, "cannot be divided by 0."); #pragma omp parallel for - for (index_t i = 0; i < size; ++i) { - output_ptr[i] = input0_ptr[i] - input1_ptr[i]; - } - break; - default: - LOG(FATAL) << "Eltwise op not support type " << type_; + for (index_t i = 0; i < num; ++i) { + output_ptr[i] = input0_ptr[i] / value; + } + } else { +#pragma omp parallel for + for (index_t i = 0; i < num; ++i) { + MACE_CHECK(fabs(input0_ptr[i]) > 1e-6, "cannot be divided by 0."); + output_ptr[i] = value / input0_ptr[i]; + } + } + break; + case SQR_DIFF: +#pragma omp parallel for + for (index_t i = 0; i < num; ++i) { + const float tmp = input0_ptr[i] - value; + output_ptr[i] = tmp * tmp; + } + break; + default: + LOG(FATAL) << "Eltwise op not support type " << type_; + } + } else { + MACE_CHECK_NOTNULL(input0); + MACE_CHECK_NOTNULL(input1); + Tensor::MappingGuard input0_guard(input0); + Tensor::MappingGuard input1_guard(input1); + Tensor::MappingGuard output_guard(output); + + const T *input0_ptr = input0->data(); + const T *input1_ptr = input1->data(); + T *output_ptr = output->mutable_data(); + const index_t size0 = input0->size(); + const index_t size1 = input1->size(); + + const index_t num = size0 / size1; + switch (type_) { + case PROD: +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < num; ++i) { + for (index_t j = 0; j < size1; ++j) { + output_ptr[i * size1 + j] = + input0_ptr[i * size1 + j] * input1_ptr[j]; + } + } + break; + case SUM: + if (coeff_.empty()) { +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < num; ++i) { + for (index_t j = 0; j < size1; ++j) { + output_ptr[i * size1 + j] = + input0_ptr[i * size1 + j] + input1_ptr[j]; + } + } + } else { + const float coeff_0 = swap ? coeff_[1] : coeff_[0]; + const float coeff_1 = swap ?
coeff_[0] : coeff_[1]; +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < num; ++i) { + for (index_t j = 0; j < size1; ++j) { + output_ptr[i * size1 + j] = + coeff_0 * input0_ptr[i * size1 + j] + + coeff_1 * input1_ptr[j]; + } + } + } + break; + case MAX: +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < num; ++i) { + for (index_t j = 0; j < size1; ++j) { + output_ptr[i * size1 + j] = + std::max(input0_ptr[i * size1 + j], input1_ptr[j]); + } + } + break; + case MIN: +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < num; ++i) { + for (index_t j = 0; j < size1; ++j) { + output_ptr[i * size1 + j] = + std::min(input0_ptr[i * size1 + j], input1_ptr[j]); + } + } + break; + case SUB: +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < num; ++i) { + for (index_t j = 0; j < size1; ++j) { + output_ptr[i * size1 + j] = swap ? + input1_ptr[j] - input0_ptr[i * size1 + j] : + input0_ptr[i * size1 + j] - input1_ptr[j]; + } + } + break; + case DIV: +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < num; ++i) { + for (index_t j = 0; j < size1; ++j) { + if (!swap) { + MACE_CHECK(fabs(input1_ptr[j]) > 1e-6, "cannot be divided by 0."); + output_ptr[i * size1 + j] = + input0_ptr[i * size1 + j] / input1_ptr[j]; + } else { + MACE_CHECK(fabs(input0_ptr[i * size1 + j]) > 1e-6, + "cannot be divided by 0."); + output_ptr[i * size1 + j] = + input1_ptr[j] / input0_ptr[i * size1 + j]; + } + } + } + break; + case SQR_DIFF: +#pragma omp parallel for collapse(2) + for (index_t i = 0; i < num; ++i) { + for (index_t j = 0; j < size1; ++j) { + const T tmp = input0_ptr[i * size1 + j] - input1_ptr[j]; + output_ptr[i * size1 + j] = tmp * tmp; + } + } + break; + default: + LOG(FATAL) << "Eltwise op not support type " << type_; + } } } }; template struct EltwiseFunctor : EltwiseFunctorBase { - EltwiseFunctor(const EltwiseType type, const std::vector &coeff) + EltwiseFunctor(const EltwiseType type, + const std::vector &coeff) : EltwiseFunctorBase(type, coeff) {} void operator()(const Tensor *input0, const Tensor *input1, + const index_t start_axis, + const bool is_scaler, + const float value, + const bool swap, Tensor *output, StatsFuture *future); diff --git a/mace/kernels/opencl/cl/cwise.cl b/mace/kernels/opencl/cl/cwise.cl deleted file mode 100644 index 2d3f3105cbddb0dfd9d8b3b208bf400772f60fb4..0000000000000000000000000000000000000000 --- a/mace/kernels/opencl/cl/cwise.cl +++ /dev/null @@ -1,56 +0,0 @@ -#include - -__kernel void cwise(KERNEL_ERROR_PARAMS - GLOBAL_WORK_GROUP_SIZE_DIM2 - __read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ - __private const int width, - __private const int channel, - __private const float value, - __write_only image2d_t output) { - const int w = get_global_id(0); - const int hb = get_global_id(1); - -#ifndef NON_UNIFORM_WORK_GROUP - if (w >= global_size_dim0 || hb >= global_size_dim1) return; -#endif - - const int remain_chan = channel - mul24((w / width), 4); - - DATA_TYPE4 in0 = READ_IMAGET(input, SAMPLER, (int2)(w, hb)); - DATA_TYPE4 in1 = (DATA_TYPE4){value, value, value, value}; - DATA_TYPE4 out; - -#if CWISE_TYPE == 0 - out = in0 * in1; -#elif CWISE_TYPE == 1 - out = in0 + in1; -#elif CWISE_TYPE == 2 - out = fmax(in0, in1); -#elif CWISE_TYPE == 3 - out = fmin(in0, in1); -#elif CWISE_TYPE == 4 - out = in0 - in1; -#elif CWISE_TYPE == 5 - out = in0 / in1; -#elif CWISE_TYPE == 6 - in1 = (DATA_TYPE4)(0, 0, 0, 0); - out = in1 - in0; -#elif CWISE_TYPE == 7 - out = fabs(in0); -#endif - -#if CWISE_TYPE == 1 || CWISE_TYPE == 2 ||
CWISE_TYPE == 3 || CWISE_TYPE == 4 - if (remain_chan < 4) { - switch (remain_chan) { - case 1: - out.y = 0; - case 2: - out.z = 0; - case 3: - out.w = 0; - } - } -#endif - - WRITE_IMAGET(output, (int2)(w, hb), out); -} diff --git a/mace/kernels/opencl/cl/eltwise.cl b/mace/kernels/opencl/cl/eltwise.cl index 58838a7d29aad87345706cb66ecea0d86d4c22a4..b2ebebeccd93c43c686f98f83997152305dd8a1f 100644 --- a/mace/kernels/opencl/cl/eltwise.cl +++ b/mace/kernels/opencl/cl/eltwise.cl @@ -1,30 +1,62 @@ #include __kernel void eltwise(KERNEL_ERROR_PARAMS - GLOBAL_WORK_GROUP_SIZE_DIM2 - __read_only image2d_t input0, /* [c%4 * w * c/4, h * b] */ + GLOBAL_WORK_GROUP_SIZE_DIM3 + __read_only image2d_t input0, __read_only image2d_t input1, + __private const float value, + __private const int height, + __private const int width, + __private const int channel, #ifdef COEFF_SUM __private const float coeff0, __private const float coeff1, #endif __write_only image2d_t output) { - const int w = get_global_id(0); - const int hb = get_global_id(1); + const int c = get_global_id(0); + const int w = get_global_id(1); + const int hb = get_global_id(2); #ifndef NON_UNIFORM_WORK_GROUP - if (w >= global_size_dim0 || hb >= global_size_dim1) return; + if (c >= global_size_dim0 || w >= global_size_dim1 || hb >= global_size_dim2) + return; #endif - DATA_TYPE4 in0 = READ_IMAGET(input0, SAMPLER, (int2)(w, hb)); - DATA_TYPE4 in1 = READ_IMAGET(input1, SAMPLER, (int2)(w, hb)); + int pos_w; + int pos_h; +#if START_AXIS == 0 + pos_w = mad24(c, width, w); + pos_h = hb; +#elif START_AXIS == 1 + pos_w = mad24(c, width, w); + pos_h = hb % height; +#elif START_AXIS == 2 + pos_w = mad24(c, width, w); + pos_h = 0; +#elif START_AXIS == 3 + pos_w = c; + pos_h = 0; +#endif + const int pos = mad24(c, width, w); + const int remain_channel = channel - 4 * c; + DATA_TYPE4 in0 = READ_IMAGET(input0, SAMPLER, (int2)(pos, hb)); + DATA_TYPE4 in1; +#if IS_SCALER == 1 + in1 = (DATA_TYPE4){value, value, value, value}; +#else + in1 = READ_IMAGET(input1, SAMPLER, (int2)(pos_w, pos_h)); +#endif DATA_TYPE4 out; #if ELTWISE_TYPE == 0 out = in0 * in1; #elif ELTWISE_TYPE == 1 #ifdef COEFF_SUM - out = mad(coeff0, in0, mad(coeff1, in1, 0)); + #if NEEDSWAP == 0 + out = mad(coeff0, in0, mad(coeff1, in1, 0)); + #else + out = mad(coeff1, in0, mad(coeff0, in1, 0)); + #endif #else out = in0 + in1; #endif @@ -34,8 +66,49 @@ __kernel void eltwise(KERNEL_ERROR_PARAMS #elif ELTWISE_TYPE == 3 out = fmin(in0, in1); #elif ELTWISE_TYPE == 4 - out = in0 - in1; + #if NEEDSWAP == 0 + out = in0 - in1; + #else + out = in1 - in0; + #endif +#elif ELTWISE_TYPE == 5 + #if NEEDSWAP == 0 + if (fabs(in1.x) > 0.000001f) + out.x = in0.x / in1.x; + if (fabs(in1.y) > 0.000001f) + out.y = in0.y / in1.y; + if (fabs(in1.z) > 0.000001f) + out.z = in0.z / in1.z; + if (fabs(in1.w) > 0.000001f) + out.w = in0.w / in1.w; + #else + if (fabs(in0.x) > 0.000001f) + out.x = in1.x / in0.x; + if (fabs(in0.y) > 0.000001f) + out.y = in1.y / in0.y; + if (fabs(in0.z) > 0.000001f) + out.z = in1.z / in0.z; + if (fabs(in0.w) > 0.000001f) + out.w = in1.w / in0.w; + #endif +#elif ELTWISE_TYPE == 8 + DATA_TYPE4 diff = in0 - in1; + out = diff * diff; +#endif + +#if ELTWISE_TYPE == 1 || ELTWISE_TYPE == 2 || ELTWISE_TYPE == 3 \ + || ELTWISE_TYPE == 4 || ELTWISE_TYPE == 8 + if (remain_channel < 4) { + switch (remain_channel) { + case 1: + out.y = 0; + case 2: + out.z = 0; + case 3: + out.w = 0; + } + } #endif - WRITE_IMAGET(output, (int2)(w, hb), out); + WRITE_IMAGET(output, (int2)(pos, hb), out); } diff --git
a/mace/kernels/opencl/cwise_opencl.cc b/mace/kernels/opencl/cwise_opencl.cc deleted file mode 100644 index a9565a3d41c41a6f1d1975c6c744aafa5eb5a6e8..0000000000000000000000000000000000000000 --- a/mace/kernels/opencl/cwise_opencl.cc +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2018 Xiaomi, Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mace/kernels/cwise.h" -#include "mace/core/runtime/opencl/opencl_runtime.h" -#include "mace/kernels/opencl/helper.h" -#include "mace/utils/tuner.h" - -namespace mace { -namespace kernels { - -template -void CWiseFunctor::operator()(const Tensor *input, - Tensor *output, - StatsFuture *future) { - const index_t batch = input->dim(0); - const index_t height = input->dim(1); - const index_t width = input->dim(2); - const index_t channels = input->dim(3); - - const index_t channel_blocks = RoundUpDiv4(channels); - const index_t width_pixels = channel_blocks * width; - const index_t batch_height_pixels = batch * height; - - auto runtime = OpenCLRuntime::Global(); - const uint32_t gws[2] = {static_cast(width_pixels), - static_cast(batch_height_pixels)}; - if (kernel_.get() == nullptr) { - std::set built_options; - auto dt = DataTypeToEnum::value; - std::string kernel_name = MACE_OBFUSCATE_SYMBOL("cwise"); - built_options.emplace("-Dcwise=" + kernel_name); - built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); - built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); - built_options.emplace(MakeString("-DCWISE_TYPE=", type_)); - if (runtime->IsOutOfRangeCheckEnabled()) { - built_options.emplace("-DOUT_OF_RANGE_CHECK"); - kernel_error_ = std::move(std::unique_ptr( - new Buffer(GetDeviceAllocator(DeviceType::OPENCL), 1))); - kernel_error_->Map(nullptr); - *(kernel_error_->mutable_data()) = 0; - kernel_error_->UnMap(); - } - if (runtime->IsNonUniformWorkgroupsSupported()) { - built_options.emplace("-DNON_UNIFORM_WORK_GROUP"); - } - kernel_ = runtime->BuildKernel("cwise", kernel_name, built_options); - - kwg_size_ = - static_cast(runtime->GetKernelMaxWorkGroupSize(kernel_)); - } - if (!IsVecEqual(input_shape_, input->shape())) { - uint32_t idx = 0; - if (runtime->IsOutOfRangeCheckEnabled()) { - kernel_.setArg(idx++, - *(static_cast(kernel_error_->buffer()))); - } - if (!runtime->IsNonUniformWorkgroupsSupported()) { - kernel_.setArg(idx++, gws[0]); - kernel_.setArg(idx++, gws[1]); - } - kernel_.setArg(idx++, *(input->opencl_image())); - kernel_.setArg(idx++, static_cast(width)); - kernel_.setArg(idx++, static_cast(channels)); - kernel_.setArg(idx++, static_cast(coeff_)); - kernel_.setArg(idx++, *(output->opencl_image())); - input_shape_ = input->shape(); - } - - const std::vector lws = {kwg_size_ / 16, 16, 0}; - std::stringstream ss; - ss << "cwise_opencl_kernel_" << output->dim(0) << "_" << output->dim(1) - << "_" << output->dim(2) << "_" << output->dim(3); - TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future); - - if (runtime->IsOutOfRangeCheckEnabled()) { - kernel_error_->Map(nullptr); - char 
*kerror_code = kernel_error_->mutable_data(); - MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code; - kernel_error_->UnMap(); - } -} - -template struct CWiseFunctor; -template struct CWiseFunctor; -} // namespace kernels -} // namespace mace diff --git a/mace/kernels/opencl/eltwise_opencl.cc b/mace/kernels/opencl/eltwise_opencl.cc index 629ba89045b043f2b1f7965eefc875a32a78b8ae..0ec4a1e59e4925f8e94e1221360df9d1ac38fd50 100644 --- a/mace/kernels/opencl/eltwise_opencl.cc +++ b/mace/kernels/opencl/eltwise_opencl.cc @@ -23,6 +23,10 @@ namespace kernels { template void EltwiseFunctor::operator()(const Tensor *input0, const Tensor *input1, + const index_t start_axis, + const bool is_scaler, + const float value, + const bool swap, Tensor *output, StatsFuture *future) { const index_t batch = input0->dim(0); @@ -31,14 +35,15 @@ void EltwiseFunctor::operator()(const Tensor *input0, const index_t channels = input0->dim(3); const index_t channel_blocks = RoundUpDiv4(channels); - const index_t width_pixels = channel_blocks * width; const index_t batch_height_pixels = batch * height; - const uint32_t gws[2] = {static_cast(width_pixels), + const uint32_t gws[3] = {static_cast(channel_blocks), + static_cast(width), static_cast(batch_height_pixels)}; + const int scaler = is_scaler ? 1 : 0; + const int need_swap = swap ? 1 : 0; auto runtime = OpenCLRuntime::Global(); - if (kernel_.get() == nullptr) { std::set built_options; auto dt = DataTypeToEnum::value; @@ -47,6 +52,9 @@ void EltwiseFunctor::operator()(const Tensor *input0, built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); built_options.emplace(MakeString("-DELTWISE_TYPE=", type_)); + built_options.emplace(MakeString("-DSTART_AXIS=", start_axis)); + built_options.emplace(MakeString("-DIS_SCALER=", scaler)); + built_options.emplace(MakeString("-DNEEDSWAP=", need_swap)); if (runtime->IsOutOfRangeCheckEnabled()) { built_options.emplace("-DOUT_OF_RANGE_CHECK"); kernel_error_ = std::move(std::unique_ptr( @@ -73,9 +81,14 @@ void EltwiseFunctor::operator()(const Tensor *input0, if (!runtime->IsNonUniformWorkgroupsSupported()) { kernel_.setArg(idx++, gws[0]); kernel_.setArg(idx++, gws[1]); + kernel_.setArg(idx++, gws[2]); } kernel_.setArg(idx++, *(input0->opencl_image())); kernel_.setArg(idx++, *(input1->opencl_image())); + kernel_.setArg(idx++, value); + kernel_.setArg(idx++, static_cast(height)); + kernel_.setArg(idx++, static_cast(width)); + kernel_.setArg(idx++, static_cast(channels)); if (!coeff_.empty()) { kernel_.setArg(idx++, coeff_[0]); kernel_.setArg(idx++, coeff_[1]); @@ -85,11 +98,11 @@ void EltwiseFunctor::operator()(const Tensor *input0, input_shape_ = input0->shape(); } - const std::vector lws = {kwg_size_ / 16, 16, 0}; + const std::vector lws = {8, kwg_size_ / 64, 8, 0}; std::stringstream ss; ss << "eltwise_opencl_kernel_" << output->dim(0) << "_" << output->dim(1) << "_" << output->dim(2) << "_" << output->dim(3); - TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future); + TuningOrRun3DKernel(kernel_, ss.str(), gws, lws, future); if (runtime->IsOutOfRangeCheckEnabled()) { kernel_error_->Map(nullptr); char *kerror_code = kernel_error_->mutable_data(); diff --git a/mace/ops/cwise.cc b/mace/ops/cwise.cc deleted file mode 100644 index 4f98b63dc4c161f4f807c7d1149c3f7bd7410722..0000000000000000000000000000000000000000 --- a/mace/ops/cwise.cc +++ /dev/null @@ -1,41 +0,0 @@ -// Copyright 2018 Xiaomi, Inc. All rights reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mace/ops/cwise.h" - -namespace mace { -namespace ops { - -void Register_CWise(OperatorRegistry *op_registry) { - REGISTER_OPERATOR(op_registry, OpKeyBuilder("CWise") - .Device(DeviceType::CPU) - .TypeConstraint("T") - .Build(), - CWiseOp); - - REGISTER_OPERATOR(op_registry, OpKeyBuilder("CWise") - .Device(DeviceType::OPENCL) - .TypeConstraint("T") - .Build(), - CWiseOp); - - REGISTER_OPERATOR(op_registry, OpKeyBuilder("CWise") - .Device(DeviceType::OPENCL) - .TypeConstraint("T") - .Build(), - CWiseOp); -} - -} // namespace ops -} // namespace mace diff --git a/mace/ops/cwise.h b/mace/ops/cwise.h deleted file mode 100644 index 8cef0e10814cc35cbe09677c606c406aceeefc21..0000000000000000000000000000000000000000 --- a/mace/ops/cwise.h +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2018 Xiaomi, Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef MACE_OPS_CWISE_H_ -#define MACE_OPS_CWISE_H_ - -#include - -#include "mace/core/operator.h" -#include "mace/kernels/cwise.h" - -namespace mace { -namespace ops { - -template -class CWiseOp : public Operator { - public: - CWiseOp(const OperatorDef &operator_def, Workspace *ws) - : Operator(operator_def, ws), - x_(OperatorBase::GetSingleArgument("x", 1.0)), - functor_(static_cast( - OperatorBase::GetSingleArgument( - "type", static_cast( - kernels::CWiseType::ADD))), - this->x_) {} - - bool Run(StatsFuture *future) override { - const Tensor *input_tensor = this->Input(INPUT); - Tensor *output_tensor = this->Output(OUTPUT); - output_tensor->ResizeLike(input_tensor); - - functor_(input_tensor, output_tensor, future); - return true; - } - - protected: - const float x_; - OP_INPUT_TAGS(INPUT); - OP_OUTPUT_TAGS(OUTPUT); - - private: - kernels::CWiseFunctor functor_; -}; - -} // namespace ops -} // namespace mace - -#endif // MACE_OPS_CWISE_H_ diff --git a/mace/ops/cwise_benchmark.cc b/mace/ops/cwise_benchmark.cc deleted file mode 100644 index 8d41d85dd5a9704490a17bf0e446a9691bf62b5c..0000000000000000000000000000000000000000 --- a/mace/ops/cwise_benchmark.cc +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright 2018 Xiaomi, Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mace/core/operator.h" -#include "mace/core/runtime/opencl/opencl_runtime.h" -#include "mace/core/testing/test_benchmark.h" -#include "mace/ops/ops_test_util.h" - -namespace mace { -namespace ops { -namespace test { - -namespace { -template -void CWise(int iters, int batch, int channels, - int height, int width, float x, int type) { - mace::testing::StopTiming(); - - OpsTestNet net; - - // Add input data - net.AddRandomInput("Input", {batch, height, width, channels}); - - if (D == DeviceType::OPENCL) { - BufferToImage(&net, "Input", "InputImage", - kernels::BufferType::IN_OUT_CHANNEL); - OpDefBuilder("CWise", "CWiseBM") - .Input("InputImage") - .Output("Output") - .AddIntArg("type", type) - .AddFloatArg("x", x) - .Finalize(net.NewOperatorDef()); - } else { - OpDefBuilder("CWise", "CWiseBM") - .Input("Input") - .Output("Output") - .AddIntArg("type", type) - .AddFloatArg("x", x) - .Finalize(net.NewOperatorDef()); - } - - // Warm-up - for (int i = 0; i < 5; ++i) { - net.RunOp(D); - } - net.Sync(); - - mace::testing::StartTiming(); - while (iters--) { - net.RunOp(D); - } - net.Sync(); -} -} // namespace - -#define BM_CWISE_MACRO(N, C, H, W, X, G, TYPE, DEVICE) \ - static void \ - BM_CWISE_##N##_##C##_##H##_##W##_##X##_##G##_##TYPE##_##DEVICE( \ - int iters) { \ - const int64_t tot = static_cast(iters) * N * C * H * W; \ - mace::testing::MaccProcessed(tot); \ - mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ - CWise(iters, N, C, H, W, X, G); \ - } \ - BENCHMARK( \ - BM_CWISE_##N##_##C##_##H##_##W##_##X##_##G##_##TYPE##_##DEVICE) - -#define BM_CWISE(N, C, H, W, X, G) \ - BM_CWISE_MACRO(N, C, H, W, X, G, float, CPU); \ - BM_CWISE_MACRO(N, C, H, W, X, G, float, OPENCL); \ - BM_CWISE_MACRO(N, C, H, W, X, G, half, OPENCL); - -BM_CWISE(1, 1, 512, 512, 2, 0); -BM_CWISE(1, 3, 128, 128, 2, 1); -BM_CWISE(1, 3, 512, 512, 2, 4); -BM_CWISE(1, 32, 112, 112, 2, 5); -BM_CWISE(1, 32, 112, 112, 2, 6); -BM_CWISE(1, 32, 112, 112, 2, 7); -BM_CWISE(1, 64, 256, 256, 3, 0); -BM_CWISE(1, 64, 512, 512, 3, 1); -BM_CWISE(1, 128, 56, 56, 3, 4); -BM_CWISE(1, 128, 256, 256, 3, 5); -BM_CWISE(1, 64, 512, 512, 3, 6); -BM_CWISE(1, 64, 512, 512, 3, 7); -BM_CWISE(1, 256, 14, 14, 3, 0); -BM_CWISE(1, 512, 14, 14, 3, 1); -BM_CWISE(1, 1024, 7, 7, 3, 4); -BM_CWISE(32, 1, 256, 256, 3, 5); -BM_CWISE(32, 1, 256, 256, 3, 6); -BM_CWISE(32, 1, 256, 256, 3, 7); - -} // namespace test -} // namespace ops -} // namespace mace diff --git a/mace/ops/cwise_test.cc b/mace/ops/cwise_test.cc deleted file mode 100644 index e5510106ee71b02ef58edfc4894ef4e7b4a9973d..0000000000000000000000000000000000000000 --- a/mace/ops/cwise_test.cc +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright 2018 Xiaomi, Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mace/core/operator.h" -#include "mace/ops/ops_test_util.h" -#include "../kernels/cwise.h" - -namespace mace { -namespace ops { -namespace test { - -class CWiseOpTest : public OpsTestBase {}; - -namespace { -template -void Simple(const kernels::CWiseType type, - const std::vector &shape, - const std::vector &input0, - const float x, - const std::vector &output) { - // Construct graph - OpsTestNet net; - - // Add input data - net.AddInputFromArray("Input1", shape, input0); - - if (D == DeviceType::CPU) { - OpDefBuilder("CWise", "CWiseTest") - .Input("Input1") - .AddIntArg("type", static_cast(type)) - .AddFloatArg("x", x) - .Output("Output") - .Finalize(net.NewOperatorDef()); - - // Run - net.RunOp(D); - } else { - BufferToImage(&net, "Input1", "InputImg1", - kernels::BufferType::IN_OUT_CHANNEL); - OpDefBuilder("CWise", "CWiseTest") - .Input("InputImg1") - .AddIntArg("type", static_cast(type)) - .AddFloatArg("x", x) - .Output("OutputImg") - .Finalize(net.NewOperatorDef()); - - // Run - net.RunOp(D); - - ImageToBuffer(&net, "OutputImg", "Output", - kernels::BufferType::IN_OUT_CHANNEL); - } - - auto expected = CreateTensor(shape, output); - - ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5, 1e-3); -} -} // namespace - -TEST_F(CWiseOpTest, CPUSimple) { - Simple(kernels::CWiseType::MUL, {1, 1, 2, 3}, - {1, 2, 3, 4, 5, 6}, 0.1, {0.1, 0.2, .3, .4, .5, .6}); - - Simple(kernels::CWiseType::ADD, {1, 1, 2, 3}, - {1, 2, 3, 4, 5, 6}, 2.0, {3, 4, 5, 6, 7, 8}); - - Simple(kernels::CWiseType::DIV, {1, 1, 2, 3}, - {1, 2, 3, 4, 5, 6}, 0.1, {10, 20, 30, 40, 50, 60}); - - Simple(kernels::CWiseType::SUB, {1, 1, 2, 3}, - {1, 2, 3, 4, 5, 6}, 2.0, {-1, 0, 1, 2, 3, 4}); - - Simple(kernels::CWiseType::NEG, {1, 1, 2, 3}, - {1, 2, 3, 4, 5, 6}, 2.0, {-1, -2, -3, -4, -5, -6}); - - Simple(kernels::CWiseType::ABS, {1, 1, 2, 3}, - {1, -2, -0.0001, 4, 5, 6}, 2.0, {1, 2, 0.0001, 4, 5, 6}); -} - -TEST_F(CWiseOpTest, GPUSimple) { - Simple(kernels::CWiseType::MUL, {1, 1, 2, 3}, - {1, 2, 3, 4, 5, 6}, 0.1, {0.1, 0.2, .3, .4, .5, .6}); - - Simple(kernels::CWiseType::ADD, {1, 1, 2, 3}, - {1, 2, 3, 4, 5, 6}, 2.0, {3, 4, 5, 6, 7, 8}); - - Simple(kernels::CWiseType::DIV, {1, 1, 2, 3}, - {1, 2, 3, 4, 5, 6}, 0.1, {10, 20, 30, 40, 50, 60}); - - Simple(kernels::CWiseType::SUB, {1, 1, 2, 3}, - {1, 2, 3, 4, 5, 6}, 2.0, {-1, 0, 1, 2, 3, 4}); - - Simple(kernels::CWiseType::NEG, {1, 1, 2, 3}, - {1, 2, 3, 4, 5, 6}, 2.0, {-1, -2, -3, -4, -5, -6}); - - Simple(kernels::CWiseType::ABS, {1, 1, 2, 3}, - {1, -2, -0.0001, 4, 5, 6}, 2.0, {1, 2, 0.0001, 4, 5, 6}); -} - -namespace { -template -void RandomTest(const kernels::CWiseType type, - const std::vector &shape) { - testing::internal::LogToStderr(); - srand(time(NULL)); - - // Construct graph - OpsTestNet net; - - // Add input data - net.AddRandomInput("Input1", shape); - - OpDefBuilder("CWise", "CWiseTest") - .Input("Input1") - .AddIntArg("type", static_cast(type)) - .AddFloatArg("x", 1.2) - .Output("Output") - .Finalize(net.NewOperatorDef()); - - // Run - net.RunOp(); - - BufferToImage(&net, "Input1", "InputImg1", - kernels::BufferType::IN_OUT_CHANNEL); - - 
OpDefBuilder("CWise", "CWiseTest") - .Input("InputImg1") - .AddIntArg("type", static_cast(type)) - .AddFloatArg("x", 1.2) - .AddIntArg("T", static_cast(DataTypeToEnum::value)) - .Output("OutputImg") - .Finalize(net.NewOperatorDef()); - - // Run - net.RunOp(D); - - ImageToBuffer(&net, "OutputImg", "OPENCLOutput", - kernels::BufferType::IN_OUT_CHANNEL); - - if (DataTypeToEnum::value == DT_FLOAT) { - ExpectTensorNear(*net.GetTensor("Output"), - *net.GetOutput("OPENCLOutput"), 1e-5, 1e-4); - } else { - ExpectTensorNear(*net.GetTensor("Output"), - *net.GetOutput("OPENCLOutput"), 1e-2, 1e-2); - } -} -} // namespace - -TEST_F(CWiseOpTest, OPENCLRandomFloat) { - RandomTest(kernels::CWiseType::MUL, - {3, 23, 37, 19}); - RandomTest(kernels::CWiseType::ADD, - {13, 32, 32, 64}); - RandomTest(kernels::CWiseType::SUB, - {3, 32, 32, 64}); - RandomTest(kernels::CWiseType::DIV, - {13, 32, 32, 64}); - RandomTest(kernels::CWiseType::NEG, - {13, 32, 32, 64}); -} - -TEST_F(CWiseOpTest, OPENCLRandomHalf) { - RandomTest(kernels::CWiseType::MUL, - {3, 23, 37, 19}); - RandomTest(kernels::CWiseType::ADD, - {13, 32, 32, 64}); - RandomTest(kernels::CWiseType::SUB, - {3, 32, 32, 64}); - RandomTest(kernels::CWiseType::DIV, - {13, 32, 32, 64}); - RandomTest(kernels::CWiseType::NEG, - {13, 32, 32, 64}); -} - -} // namespace test -} // namespace ops -} // namespace mace diff --git a/mace/ops/eltwise.h b/mace/ops/eltwise.h index 818fa5e533d219256f3528f01c5434e99d390a40..2972a83ad20f2a7c62dbfc28d1ee54ce79f27c9b 100644 --- a/mace/ops/eltwise.h +++ b/mace/ops/eltwise.h @@ -32,24 +32,53 @@ class EltwiseOp : public Operator { OperatorBase::GetRepeatedArgument("coeff")) {} bool Run(StatsFuture *future) override { - const Tensor *input0 = this->Input(0); - const Tensor *input1 = this->Input(1); - Tensor *output = this->Output(OUTPUT); - MACE_CHECK(input0->dim_size() == input1->dim_size()) + if (this->InputSize() == 1) { + const Tensor* input = this->Input(0); + Tensor *output = this->Output(OUTPUT); + start_axis_ = input->dim_size() - 1; + is_scaler_ = true; + output->ResizeLike(input); + const float x = OperatorBase::GetSingleArgument("x", 1.0); + functor_(input, nullptr, start_axis_, + is_scaler_, x, false, output, future); + } else { + const index_t size0 = this->Input(0)->size(); + const index_t size1 = this->Input(1)->size(); + const bool swap = (size0 < size1); + const Tensor *input0 = swap ? this->Input(1) : this->Input(0); + const Tensor *input1 = swap ? 
this->Input(0) : this->Input(1); + + Tensor *output = this->Output(OUTPUT); + MACE_CHECK(input0->dim_size() == input1->dim_size()) << "Inputs of Eltwise op must be same shape"; - for (int i = 0; i < input0->dim_size(); ++i) { - MACE_CHECK(input0->dim(i) == input1->dim(i)) - << "Inputs of Eltwise op must be same shape"; + start_axis_ = input0->dim_size() - 1; + is_scaler_ = (input1->size() == 1); + uint32_t compared_size = 1; + if (!is_scaler_) { + while (start_axis_ >= 0) { + MACE_CHECK(input0->dim(start_axis_) == input1->dim(start_axis_), + "Invalid inputs dimension at axis: ") << start_axis_ + << "input 0: " << input0->dim(start_axis_) + << "input 1: " << input1->dim(start_axis_); + compared_size *= input1->dim(start_axis_); + if (compared_size == input1->size()) { + break; + } + start_axis_--; + } + } + output->ResizeLike(input0); + const float x = OperatorBase::GetSingleArgument("x", 1.0); + functor_(input0, input1, start_axis_, + is_scaler_, x, swap, output, future); } - - output->ResizeLike(input0); - - functor_(input0, input1, output, future); return true; } private: kernels::EltwiseFunctor functor_; + index_t start_axis_; + bool is_scaler_; private: OP_OUTPUT_TAGS(OUTPUT); diff --git a/mace/ops/eltwise_test.cc b/mace/ops/eltwise_test.cc index ca24242b04bef9e55031a57e54cf76e79f06b7ff..6dd3b33da7f96da77d6e5dc26458e072042e6fcc 100644 --- a/mace/ops/eltwise_test.cc +++ b/mace/ops/eltwise_test.cc @@ -25,23 +25,26 @@ class EltwiseOpTest : public OpsTestBase {}; namespace { template void Simple(const kernels::EltwiseType type, - const std::vector &shape, + const std::vector &shape0, + const std::vector &shape1, const std::vector &input0, const std::vector &input1, const std::vector &output, + const float x = 1.f, const std::vector coeff = {}) { // Construct graph OpsTestNet net; // Add input data - net.AddInputFromArray("Input1", shape, input0); - net.AddInputFromArray("Input2", shape, input1); + net.AddInputFromArray("Input1", shape0, input0); + net.AddInputFromArray("Input2", shape1, input1); if (D == DeviceType::CPU) { OpDefBuilder("Eltwise", "EltwiseTest") .Input("Input1") .Input("Input2") .AddIntArg("type", static_cast(type)) + .AddFloatArg("x", x) .AddFloatsArg("coeff", coeff) .Output("Output") .Finalize(net.NewOperatorDef()); @@ -57,6 +60,7 @@ void Simple(const kernels::EltwiseType type, .Input("InputImg1") .Input("InputImg2") .AddIntArg("type", static_cast(type)) + .AddFloatArg("x", x) .AddFloatsArg("coeff", coeff) .Output("OutputImg") .Finalize(net.NewOperatorDef()); @@ -68,7 +72,7 @@ void Simple(const kernels::EltwiseType type, kernels::BufferType::IN_OUT_CHANNEL); } - auto expected = CreateTensor(shape, output); + auto expected = CreateTensor(shape0, output); ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); } @@ -76,53 +80,200 @@ void Simple(const kernels::EltwiseType type, TEST_F(EltwiseOpTest, CPUSimple) { Simple(kernels::EltwiseType::PROD, {1, 1, 2, 3}, + {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, {1, 4, 9, 16, 25, 36}); Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, + {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, {2, 4, 6, 8, 10, 12}); Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, + {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, - {3, 6, 9, 12, 15, 18}, {2, 1}); + {3, 6, 9, 12, 15, 18}, 1., {2, 1}); Simple(kernels::EltwiseType::MAX, {1, 1, 2, 3}, + {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6}, {1, 2, 3, 4, 6, 6}); Simple(kernels::EltwiseType::MIN, {1, 1, 2, 3}, + {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6}, {1, 1, 3, 
3, 5, 6}); + Simple(kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6}, + {0, 1, 0, 1, 1, 0}); + Simple(kernels::EltwiseType::DIV, {1, 1, 2, 3}, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3, 2, 10, 24}, + {1, 2, 1, 2, 0.5, 0.25}); + + Simple(kernels::EltwiseType::PROD, {1, 1, 2, 3}, + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 2, 3}, + {1, 4, 9, 4, 10, 18}); + Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 2, 3}, + {2, 4, 6, 5, 7, 9}); + Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 2, 3}, + {3, 6, 9, 9, 12, 15}, 1., {2, 1}); + Simple(kernels::EltwiseType::MAX, {1, 1, 2, 3}, + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3}, + {1, 2, 3, 4, 5, 6}); + Simple(kernels::EltwiseType::MIN, {1, 1, 2, 3}, + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3}, + {1, 1, 3, 1, 1, 3}); + Simple(kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3}, + {0, 1, 0, 9, 16, 9}); + Simple(kernels::EltwiseType::DIV, {1, 1, 2, 3}, + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3}, + {1, 2, 1, 4, 5, 2}); + + Simple(kernels::EltwiseType::PROD, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {2}, + {2, 4, 6, 8, 10, 12}, 2); + Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {2}, + {3, 4, 5, 6, 7, 8}, 2); + Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {2}, + {4, 6, 8, 10, 12, 14}, 2, {2, 1}); + Simple(kernels::EltwiseType::MAX, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {3}, + {3, 3, 3, 4, 5, 6}, 3); + Simple(kernels::EltwiseType::MIN, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {3}, + {1, 2, 3, 3, 3, 3}, 3); + Simple(kernels::EltwiseType::DIV, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {0.5}, + {2, 4, 6, 8, 10, 12}, 0.5); + Simple(kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {3}, + {4, 1, 0, 1, 4, 9}, 3); } TEST_F(EltwiseOpTest, GPUSimple) { Simple(kernels::EltwiseType::PROD, {1, 1, 2, 3}, - {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, - {1, 4, 9, 16, 25, 36}); + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, + {1, 4, 9, 16, 25, 36}); + Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, + {2, 4, 6, 8, 10, 12}); + Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, + {3, 6, 9, 12, 15, 18}, 1., {2, 1}); + Simple(kernels::EltwiseType::MAX, {1, 1, 2, 3}, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6}, + {1, 2, 3, 4, 6, 6}); + Simple(kernels::EltwiseType::MIN, {1, 1, 2, 3}, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6}, + {1, 1, 3, 3, 5, 6}); + Simple(kernels::EltwiseType::DIV, {1, 1, 2, 3}, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3, 2, 10, 24}, + {1, 2, 1, 2, 0.5, 0.25}); + Simple(kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, + {1, 1, 2, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6}, + {0, 1, 0, 1, 1, 0}); + + Simple(kernels::EltwiseType::PROD, {1, 1, 2, 3}, + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 2, 3}, + {1, 4, 9, 4, 10, 18}); Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, - {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, - {2, 4, 6, 8, 10, 12}); + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 2, 3}, + {2, 4, 6, 5, 7, 9}); Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, - {1, 2, 3, 4, 5, 6}, {1, 2, 3, 4, 5, 6}, - {3, 6, 9, 12, 15, 18}, {2, 1}); + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, 
{1, 2, 3}, + {3, 6, 9, 9, 12, 15}, 1., {2, 1}); Simple(kernels::EltwiseType::MAX, {1, 1, 2, 3}, - {1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6}, - {1, 2, 3, 4, 6, 6}); + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3}, + {1, 2, 3, 4, 5, 6}); Simple(kernels::EltwiseType::MIN, {1, 1, 2, 3}, - {1, 2, 3, 4, 5, 6}, {1, 1, 3, 3, 6, 6}, - {1, 1, 3, 3, 5, 6}); + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3}, + {1, 1, 3, 1, 1, 3}); + Simple(kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3}, + {0, 1, 0, 9, 16, 9}); + Simple(kernels::EltwiseType::DIV, {1, 1, 2, 3}, + {1, 1, 1, 3}, + {1, 2, 3, 4, 5, 6}, {1, 1, 3}, + {1, 2, 1, 4, 5, 2}); + + Simple(kernels::EltwiseType::PROD, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {2}, + {2, 4, 6, 8, 10, 12}, 2); + Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {2}, + {3, 4, 5, 6, 7, 8}, 2); + Simple(kernels::EltwiseType::SUM, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {2}, + {4, 6, 8, 10, 12, 14}, 2, {2, 1}); + Simple(kernels::EltwiseType::MAX, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {3}, + {3, 3, 3, 4, 5, 6}, 3); + Simple(kernels::EltwiseType::MIN, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {3}, + {1, 2, 3, 3, 3, 3}, 3); + Simple(kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {3}, + {4, 1, 0, 1, 4, 9}, 3); + Simple(kernels::EltwiseType::DIV, {1, 1, 2, 3}, + {1, 1, 1, 1}, + {1, 2, 3, 4, 5, 6}, {0.5}, + {2, 4, 6, 8, 10, 12}, 0.5); } namespace { template void RandomTest(const kernels::EltwiseType type, - const std::vector &shape) { + const std::vector &shape1, + const std::vector &shape2) { testing::internal::LogToStderr(); srand(time(NULL)); // Construct graph OpsTestNet net; + bool is_divide = (type == kernels::EltwiseType::DIV); + // Add input data - net.AddRandomInput("Input1", shape); - net.AddRandomInput("Input2", shape); + net.AddRandomInput("Input1", shape1, true, is_divide); + net.AddRandomInput("Input2", shape2, true, is_divide); + + OpDefBuilder("Eltwise", "EltwiseTest") .Input("Input1") @@ -166,24 +317,110 @@ void RandomTest(const kernels::EltwiseType type, TEST_F(EltwiseOpTest, OPENCLRandomFloat) { RandomTest(kernels::EltwiseType::PROD, + {3, 23, 37, 19}, {3, 23, 37, 19}); RandomTest(kernels::EltwiseType::SUM, + {13, 32, 32, 64}, {13, 32, 32, 64}); RandomTest(kernels::EltwiseType::MAX, + {3, 32, 32, 64}, {3, 32, 32, 64}); RandomTest(kernels::EltwiseType::MIN, + {13, 32, 32, 64}, + {13, 32, 32, 64}); + RandomTest(kernels::EltwiseType::DIV, + {13, 32, 32, 64}, + {13, 32, 32, 64}); + RandomTest(kernels::EltwiseType::SQR_DIFF, + {13, 32, 32, 64}, {13, 32, 32, 64}); + RandomTest(kernels::EltwiseType::PROD, + {3, 23, 37, 19}, + {1, 1, 37, 19}); + RandomTest(kernels::EltwiseType::SUM, + {13, 32, 32, 64}, + {1, 1, 32, 64}); + RandomTest(kernels::EltwiseType::MAX, + {3, 32, 32, 64}, + {1, 1, 32, 64}); + RandomTest(kernels::EltwiseType::MIN, + {13, 32, 32, 64}, + {1, 1, 32, 64}); + RandomTest(kernels::EltwiseType::DIV, + {13, 32, 32, 63}, + {1, 1, 32, 63}); + RandomTest(kernels::EltwiseType::SQR_DIFF, + {13, 32, 32, 64}, + {1, 1, 32, 64}); + RandomTest(kernels::EltwiseType::PROD, + {3, 23, 37, 19}, + {1, 1, 1, 19}); + RandomTest(kernels::EltwiseType::SUM, + {13, 32, 32, 64}, + {1, 1, 1, 64}); + RandomTest(kernels::EltwiseType::MAX, + {3, 32, 32, 64}, + {1, 1, 1, 64}); + RandomTest(kernels::EltwiseType::MIN, + {13, 32, 32, 64}, + {1, 1, 1, 64}); + RandomTest(kernels::EltwiseType::DIV, + {13, 32, 32, 64}, + {1, 1, 1, 
64}); + RandomTest(kernels::EltwiseType::SQR_DIFF, + {13, 32, 32, 64}, + {1, 1, 1, 64}); } TEST_F(EltwiseOpTest, OPENCLRandomHalf) { RandomTest(kernels::EltwiseType::PROD, + {3, 23, 37, 19}, {3, 23, 37, 19}); + RandomTest(kernels::EltwiseType::PROD, + {3, 23, 37, 19}, + {1, 23, 37, 19}); + RandomTest(kernels::EltwiseType::PROD, + {3, 23, 37, 19}, + {1, 1, 37, 19}); + RandomTest(kernels::EltwiseType::PROD, + {3, 23, 37, 19}, + {1, 1, 1, 19}); RandomTest(kernels::EltwiseType::SUM, - {13, 32, 32, 64}); + {13, 32, 32, 64}, + {1, 1, 1, 1}); + RandomTest(kernels::EltwiseType::SUM, + {13, 32, 32, 64}, + {1, 1, 1, 64}); + RandomTest(kernels::EltwiseType::SUM, + {13, 32, 32, 64}, + {1, 1, 32, 64}); RandomTest(kernels::EltwiseType::MAX, + {3, 32, 32, 64}, {3, 32, 32, 64}); + RandomTest(kernels::EltwiseType::MAX, + {3, 32, 32, 64}, + {1, 1, 32, 64}); RandomTest(kernels::EltwiseType::MIN, + {13, 32, 32, 64}, + {13, 32, 32, 64}); + RandomTest(kernels::EltwiseType::SQR_DIFF, + {13, 32, 32, 64}, + {13, 32, 32, 64}); + RandomTest(kernels::EltwiseType::SQR_DIFF, + {13, 32, 32, 64}, + {1, 1, 1, 64}); + RandomTest(kernels::EltwiseType::SQR_DIFF, + {13, 32, 32, 64}, + {1, 1, 32, 64}); + RandomTest(kernels::EltwiseType::DIV, + {13, 32, 32, 64}, {13, 32, 32, 64}); + RandomTest(kernels::EltwiseType::DIV, + {13, 32, 32, 64}, + {1, 1, 1, 64}); + RandomTest(kernels::EltwiseType::DIV, + {13, 32, 32, 64}, + {1, 1, 32, 64}); } } // namespace test diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h index 162435705bef37254bcb1a68654c16b296777104..1439bf08c8adefca8524e68b6d34d74bec2deceb 100644 --- a/mace/ops/ops_test_util.h +++ b/mace/ops/ops_test_util.h @@ -150,7 +150,8 @@ class OpsTestNet { template void AddRandomInput(const std::string &name, const std::vector &shape, - bool positive = true) { + bool positive = true, + bool truncate = false) { Tensor *input = ws_.CreateTensor(name, GetDeviceAllocator(D), DataTypeToEnum::v()); input->Resize(shape); @@ -162,14 +163,24 @@ class OpsTestNet { std::normal_distribution nd(0, 1); if (DataTypeToEnum::value == DT_HALF) { std::generate( - input_data, input_data + input->size(), [&gen, &nd, positive] { - return half_float::half_cast(positive ? std::abs(nd(gen)) - : nd(gen)); + input_data, input_data + input->size(), + [&gen, &nd, positive, truncate] { + float d = nd(gen); + if (truncate) { + if (std::abs(d) > 100.f) d = 100.f; + if (std::abs(d) < 0.001f) d = 0.001f; + } + return half_float::half_cast(positive ?std::abs(d) : d); }); } else { std::generate(input_data, input_data + input->size(), - [&gen, &nd, positive] { - return positive ? 
std::abs(nd(gen)) : nd(gen); + [&gen, &nd, positive, truncate] { + float d = nd(gen); + if (truncate) { + if (std::abs(d) > 100.f) d = 100.f; + if (std::abs(d) < 0.001f) d = 0.001f; + } + return (positive ?std::abs(d) : d); }); } } diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py index 0f0ca20b4be8f83aebca776e742dc8d65e33ffdd..fc057dd0f29b8509ca403ea158f5f228066eca28 100644 --- a/mace/python/tools/tf_converter_lib.py +++ b/mace/python/tools/tf_converter_lib.py @@ -829,37 +829,25 @@ class TFConverter(object): self.resolved_ops[op.name] = 1 self.unused_tensor.add(get_input_tensor(op, 1).name) - def convert_math(self, op, math_type): + def convert_eltwise(self, op, math_type): op_def = self.net_def.op.add() arg = op_def.arg.add() arg.name = 'T' arg.i = self.dt op_def.name = op.name - - if len(op.inputs) == 1: - op_def.type = "CWise" - op_def.input.extend([input.name for input in op.inputs]) - x_arg = op_def.arg.add() - x_arg.name = 'x' - x_arg.f = 0 - elif len(op.inputs) >= 2: + op_def.type = "Eltwise" + op_def.input.extend([input.name for input in op.inputs]) + x_value = op.get_attr('x') + if len(op.inputs) >= 2: input_tensor0 = get_input_tensor(op, 0) input_tensor1 = get_input_tensor(op, 1) - if input_tensor0.shape == input_tensor1.shape: - op_def.type = "Eltwise" - op_def.input.extend([input.name for input in op.inputs]) - else: - op_def.type = "CWise" - x_value = 0 - if len(input_tensor1.shape) == 4: - op_def.input.extend([op.inputs[1].name]) - x_value = get_input_tensor(op, 0).eval().astype(np.float32) - else: - op_def.input.extend([op.inputs[0].name]) - x_value = get_input_tensor(op, 1).eval().astype(np.float32) - x_arg = op_def.arg.add() - x_arg.name = 'x' - x_arg.f = x_value + if len(input_tensor0) == 1: + x_value = input_tensor0.eval().astype(np.float32) + elif len(input_tensor1) == 1: + x_value = input_tensor1.eval().astype(np.float32) + x_arg = op_def.arg.add() + x_arg.name = 'x' + x_arg.f = x_value type_arg = op_def.arg.add() type_arg.name = 'type' type_arg.i = math_type_mode[math_type] @@ -1156,11 +1144,11 @@ class TFConverter(object): elif op.type == 'SpaceToDepth': self.convert_depth_to_space(op, False) elif op.type in ['Neg', 'neg', 'Negative', 'negative']: - self.convert_math(op, 'NEG') + self.convert_eltwise(op, 'NEG') elif op.type == 'Mul': - self.convert_math(op, 'MUL') + self.convert_eltwise(op, 'MUL') elif op.type == 'Sub': - self.convert_math(op, 'SUB') + self.convert_eltwise(op, 'SUB') elif self.is_softmax(op): self.convert_softmax(op) elif op.type in ['Relu', 'Sigmoid', 'Tanh']: