diff --git a/docs/getting_started/op_lists.rst b/docs/getting_started/op_lists.rst
index 9a1d5425b1a768a0759a06f18e2f4a2d27f4e1b2..7cc22d87858ac8e788f1052dd2db43afdb5fe7af 100644
--- a/docs/getting_started/op_lists.rst
+++ b/docs/getting_started/op_lists.rst
@@ -7,22 +7,25 @@ Operator lists
     :header: "Operator","Android NN","Supported","Remark"

     "AVERAGE_POOL_2D","Y","Y",""
+    "ARGMAX","","Y","Only CPU runtime and tensorflow model are supported"
     "BATCH_NORM","","Y","Fusion with activation is supported"
     "BATCH_TO_SPACE_ND","Y","Y",""
     "BIAS_ADD","","Y",""
+    "CAST","","Y","Only CPU runtime and tensorflow model are supported"
     "CHANNEL_SHUFFLE","","Y",""
     "CONCATENATION","Y","Y","Only support channel axis concatenation"
     "CONV_2D","Y","Y","Fusion with BN and activation layer is supported"
-    "DECONV_2D","N","Y","Only tensorflow model is supported"
+    "DECONV_2D","","Y","Only tensorflow model is supported"
     "DEPTHWISE_CONV_2D","Y","Y","Only multiplier = 1 is supported; Fusion is supported"
     "DEPTH_TO_SPACE","Y","Y",""
     "DEQUANTIZE","Y","Y","Model quantization will be supported later"
-    "ELEMENT_WISE","Y","Y","ADD/MUL/DIV/MIN/MAX/NEG/ABS/SQR_DIFF/POW"
-    "EMBEDDING_LOOKUP","Y","",""
+    "ELEMENT_WISE","Y","Y","ADD/MUL/DIV/MIN/MAX/NEG/ABS/SQR_DIFF/POW/RSQRT/EQUAL"
+    "EMBEDDING_LOOKUP","Y","Y","Only support channel axis concatenation"
     "FLOOR","Y","",""
     "FULLY_CONNECTED","Y","Y",""
     "GROUP_CONV_2D","","","Caffe model with group count = channel count is supported"
     "HASHTABLE_LOOKUP","Y","",""
+    "IDENTITY","","Y","Only tensorflow model is supported"
     "L2_NORMALIZATION","Y","",""
     "L2_POOL_2D","Y","",""
     "LOCAL_RESPONSE_NORMALIZATION","Y","Y",""
@@ -31,9 +34,10 @@ Operator lists
     "LSTM","Y","",""
     "MATMUL","","Y",""
     "MAX_POOL_2D","Y","Y",""
-    "PAD", "N","Y",""
+    "PAD", "Y","Y",""
     "PSROI_ALIGN","","Y",""
     "PRELU","","Y","Only caffe model is supported"
+    "REDUCE_MEAN","Y","Y","Only tensorflow model is supported"
     "RELU","Y","Y",""
     "RELU1","Y","Y",""
     "RELU6","Y","Y",""
@@ -42,9 +46,14 @@ Operator lists
     "RESIZE_BILINEAR","Y","Y",""
     "RNN","Y","",""
     "RPN_PROPOSAL_LAYER","","Y",""
-    "SLICE","N","Y","Only support channel axis slice"
+    "SHAPE","","Y","Only CPU runtime and tensorflow model are supported"
+    "STACK","","Y","Only CPU runtime and tensorflow model are supported"
+    "STRIDEDSLICE","Y","Y","Only CPU runtime and tensorflow model are supported"
+    "SLICE","","Y","In tensorflow, this op is equivalent to SPLIT; only channel axis slice is supported"
     "SOFTMAX","Y","Y",""
     "SPACE_TO_BATCH_ND","Y", "Y",""
     "SPACE_TO_DEPTH","Y","Y",""
+    "SQUEEZE","Y","Y","Only CPU runtime and tensorflow model are supported"
     "SVDF","Y","",""
     "TANH","Y","Y",""
+    "TRANSPOSE","Y","Y","Only CPU runtime and tensorflow model are supported"
diff --git a/mace/core/mace.cc b/mace/core/mace.cc
index 000c85b74c7c996a9f297adeb924521aeaa32eb6..ff78d369a2d091ca2c84d7d3a134c2de6a77b82e 100644
--- a/mace/core/mace.cc
+++ b/mace/core/mace.cc
@@ -264,7 +264,6 @@ MaceStatus MaceEngine::Impl::Run(
     auto shape = output_tensor->shape();
     int64_t output_size = std::accumulate(shape.begin(), shape.end(), 1,
                                           std::multiplies<int64_t>());
-    MACE_CHECK(!shape.empty()) << "Output's shape must greater than 0";
     MACE_CHECK(shape == output.second.shape())
         << "Output shape mismatch: "
         << MakeString(output.second.shape())
diff --git a/mace/core/operator.cc b/mace/core/operator.cc
index 9add3e1a08e9235301058d9460955a099155334d..6389d1172210fd65db9d62ce029e95be0909e0b3 100644
--- a/mace/core/operator.cc
+++ b/mace/core/operator.cc
@@ -76,6 +76,7 @@ namespace ops {
 // Keep in lexicographical order
 extern void Register_Activation(OperatorRegistry *op_registry);
 extern void Register_AddN(OperatorRegistry *op_registry);
+extern void Register_ArgMax(OperatorRegistry *op_registry);
 extern void Register_BatchNorm(OperatorRegistry *op_registry);
 extern void Register_BatchToSpaceND(OperatorRegistry *op_registry);
 extern void Register_BiasAdd(OperatorRegistry *op_registry);
@@ -124,6 +125,7 @@ OperatorRegistry::OperatorRegistry() {
   // Keep in lexicographical order
   ops::Register_Activation(this);
   ops::Register_AddN(this);
+  ops::Register_ArgMax(this);
   ops::Register_BatchNorm(this);
   ops::Register_BatchToSpaceND(this);
   ops::Register_BiasAdd(this);
diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc
index f0efde525367447cfe19d9f15e75b221a73c8d9d..3d03345d5949028d246020fa8f43e26dafa0fe08 100644
--- a/mace/core/workspace.cc
+++ b/mace/core/workspace.cc
@@ -157,6 +157,8 @@ MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
     }
   }
   MACE_CHECK(dtype != DataType::DT_INVALID, "data type is invalid.");
+  // TODO(liyin): memory blocks should not have a concept of type, but to be
+  // consistent with the GPU, all memory blocks use float/half as the unit
   for (auto &mem_block : net_def.mem_arena().mem_block()) {
     if (device_type == DeviceType::GPU) {
       // TODO(liuqi): refactor based on PB
@@ -191,8 +193,15 @@ MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
       auto mem_ids = op.mem_id();
       int count = mem_ids.size();
       for (int i = 0; i < count; ++i) {
+        DataType output_type;
+        if (i < op.output_type_size()) {
+          output_type = op.output_type(i);
+        } else {
+          output_type = dtype;
+        }
         std::unique_ptr<Tensor> tensor
-            (new Tensor(preallocated_allocator_.GetBuffer(mem_ids[i]), dtype));
+            (new Tensor(preallocated_allocator_.GetBuffer(mem_ids[i]),
+                        output_type));
         tensor->SetSourceOpName(op.name());
         if (device_type == DeviceType::GPU) {
           VLOG(3) << "Tensor: " << op.name() << "(" << op.type() << ")"
diff --git a/mace/kernels/argmax.h b/mace/kernels/argmax.h
new file mode 100644
index 0000000000000000000000000000000000000000..54edc3ee7b718a69b7b7136dbba587f07d654997
--- /dev/null
+++ b/mace/kernels/argmax.h
@@ -0,0 +1,85 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#ifndef MACE_KERNELS_ARGMAX_H_ +#define MACE_KERNELS_ARGMAX_H_ + +#include +#include +#include +#include +#include + +#include "mace/core/future.h" +#include "mace/core/tensor.h" +#include "mace/public/mace.h" +#include "mace/utils/utils.h" + +namespace mace { +namespace kernels { + +template +struct ArgMaxFunctor { + MaceStatus operator()(const Tensor *input, + const Tensor *axis, + Tensor *output, + StatsFuture *future) { + MACE_UNUSED(future); + + MACE_CHECK(input->dim_size() > 0, "ArgMax input should not be a scalar"); + MACE_CHECK(axis->dim_size() == 0, "Mace argmax only supports scalar axis"); + Tensor::MappingGuard axis_guard(axis); + int axis_value = axis->data()[0]; + if (axis_value < 0) { + axis_value += input->dim_size(); + } + MACE_CHECK(axis_value == input->dim_size() - 1, + "Mace argmax only supports last dimension as axis"); + + std::vector output_shape(input->dim_size() - 1); + for (index_t d = 0; d < input->dim_size() - 1; ++d) { + output_shape[d] = input->dim(d < axis_value ? d : d + 1); + } + MACE_RETURN_IF_ERROR(output->Resize(output_shape)); + + Tensor::MappingGuard input_guard(input); + Tensor::MappingGuard output_guard(output); + auto input_data = input->data(); + auto output_data = output->mutable_data(); + + index_t outer_size = output->size(); + index_t inner_size = input->dim(axis_value); + +#pragma omp parallel for + for (index_t i = 0; i < outer_size; ++i) { + int idx = 0; + T max_value = std::numeric_limits::lowest(); + const T *input_ptr = input_data + i * inner_size; + for (index_t j = 0; j < inner_size; ++j) { + if (input_ptr[j] > max_value) { + max_value = input_ptr[j]; + idx = j; + } + } + output_data[i] = idx; + } + + return MACE_SUCCESS; + } +}; + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_ARGMAX_H_ diff --git a/mace/kernels/eltwise.h b/mace/kernels/eltwise.h index 05695fd8ed9aef808c3e7e10211e0a2f4591c41e..a246846b6d7283d3e8de01a452d7ad000da00c99 100644 --- a/mace/kernels/eltwise.h +++ b/mace/kernels/eltwise.h @@ -16,10 +16,10 @@ #define MACE_KERNELS_ELTWISE_H_ #include +#include #include #include #include -#include #include "mace/core/future.h" #include "mace/core/tensor.h" @@ -42,9 +42,12 @@ enum EltwiseType { ABS = 7, SQR_DIFF = 8, POW = 9, - NONE = 10, + EQUAL = 10, + NONE = 11, }; +static bool IsLogicalType(EltwiseType type) { return type == EQUAL; } + inline index_t GetIndex(const std::vector &shape, const std::vector &index) { index_t idx = 0; @@ -68,22 +71,19 @@ inline void IncreaseIndex(const std::vector &shape, } } -inline void TensorGeneralBroadcastEltwise(const EltwiseType type, - const float *input0, - const float *input1, - const std::vector &coeff, - const bool swapped, - const std::vector - &input0_shape, - const std::vector - &input1_shape, - const std::vector - &output_shape, - float *output) { - const index_t output_size = std::accumulate(output_shape.begin(), - output_shape.end(), - 1, - std::multiplies()); +template +inline void TensorGeneralBroadcastEltwise( + const EltwiseType type, + const T *input0, + const T *input1, + const std::vector &coeff, + const bool swapped, + const std::vector &input0_shape, + const std::vector &input1_shape, + const std::vector &output_shape, + DstType *output) { + const index_t output_size = std::accumulate( + output_shape.begin(), output_shape.end(), 1, std::multiplies()); std::vector out_index(output_shape.size(), 0); switch (type) { case SUM: @@ -191,19 +191,28 @@ inline void TensorGeneralBroadcastEltwise(const EltwiseType type, } } break; + case EQUAL: + for 
(index_t i = 0; i < output_size; ++i) { + const index_t idx0 = GetIndex(input0_shape, out_index); + const index_t idx1 = GetIndex(input1_shape, out_index); + output[i] = input1[idx1] == input0[idx0]; + IncreaseIndex(output_shape, &out_index); + } + break; default: LOG(FATAL) << "Eltwise op not support type " << type; } } +template inline void TensorBroadcastEltwise(const EltwiseType type, - const float *input0, - const float *input1, + const T *input0, + const T *input1, const std::vector &coeff, const index_t diff_size, const index_t common_size, const bool swapped, - float *output) { + DstType *output) { switch (type) { case SUM: if (coeff.empty()) { @@ -333,19 +342,29 @@ inline void TensorBroadcastEltwise(const EltwiseType type, output[i] = std::fabs(input0[i]); } break; + case EQUAL: +#pragma omp parallel for collapse(2) + for (index_t d = 0; d < diff_size; ++d) { + for (index_t i = 0; i < common_size; ++i) { + output[i + d * common_size] = + input0[i + d * common_size] == input1[i]; + } + } + break; default: LOG(FATAL) << "Eltwise op not support type " << type; } } // Multiplication is costly, so we specialize the following case. +template inline void TensorEltwise(const EltwiseType type, - const float *input0, - const float *input1, + const T *input0, + const T *input1, const std::vector &coeff, const index_t size, const bool swapped, - float *output) { + DstType *output) { switch (type) { case SUM: if (coeff.empty()) { @@ -445,19 +464,26 @@ inline void TensorEltwise(const EltwiseType type, output[i] = std::fabs(input0[i]); } break; + case EQUAL: +#pragma omp parallel for + for (index_t i = 0; i < size; ++i) { + output[i] = input0[i] == input1[i]; + } + break; default: LOG(FATAL) << "Eltwise op not support type " << type; } } // Multiplication is costly, so we specialize the following case. +template inline void TensorScalarEltwise(const EltwiseType type, - const float *input0, - const float input1, + const T *input0, + const T input1, const std::vector &coeff, const index_t size, const bool swapped, - float *output) { + DstType *output) { switch (type) { case SUM: if (coeff.empty()) { @@ -556,31 +582,39 @@ inline void TensorScalarEltwise(const EltwiseType type, for (index_t i = 0; i < size; ++i) { output[i] = std::fabs(input0[i]); } + break; + case EQUAL: +#pragma omp parallel for + for (index_t i = 0; i < size; ++i) { + output[i] = input0[i] == input1; + } + break; default: LOG(FATAL) << "Eltwise op not support type " << type; } } +template inline void TensorEltwisePerChannel(const EltwiseType type, - const float *input0, - const float *input1, + const T *input0, + const T *input1, const std::vector &coeff, const index_t batch0, const index_t batch1, const index_t channel, const index_t image_size, const bool swapped, - float *output) { + DstType *output) { switch (type) { case SUM: if (coeff.empty()) { #pragma omp parallel for collapse(2) for (index_t b = 0; b < batch0; ++b) { for (index_t c = 0; c < channel; ++c) { - const float *in0_ptr = input0 + ((b * channel) + c) * image_size; - const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); - float *out_ptr = output + ((b * channel) + c) * image_size; + const T *in0_ptr = input0 + ((b * channel) + c) * image_size; + const T *in1_ptr = input1 + (batch1 > 1 ? 
b * channel : 0); + DstType *out_ptr = output + ((b * channel) + c) * image_size; for (index_t i = 0; i < image_size; ++i) { out_ptr[i] = in0_ptr[i] + in1_ptr[c]; } @@ -594,9 +628,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, #pragma omp parallel for collapse(2) for (index_t b = 0; b < batch0; ++b) { for (index_t c = 0; c < channel; ++c) { - const float *in0_ptr = input0 + ((b * channel) + c) * image_size; - const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); - float *out_ptr = output + ((b * channel) + c) * image_size; + const T *in0_ptr = input0 + ((b * channel) + c) * image_size; + const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); + DstType *out_ptr = output + ((b * channel) + c) * image_size; for (index_t i = 0; i < image_size; ++i) { out_ptr[i] = in0_ptr[i] * coeff_copy[0] + in1_ptr[c] * coeff_copy[1]; @@ -610,9 +644,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, #pragma omp parallel for collapse(2) for (index_t b = 0; b < batch0; ++b) { for (index_t c = 0; c < channel; ++c) { - const float *in0_ptr = input0 + ((b * channel) + c) * image_size; - const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); - float *out_ptr = output + ((b * channel) + c) * image_size; + const T *in0_ptr = input0 + ((b * channel) + c) * image_size; + const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); + DstType *out_ptr = output + ((b * channel) + c) * image_size; for (index_t i = 0; i < image_size; ++i) { out_ptr[i] = in0_ptr[i] - in1_ptr[c]; } @@ -622,9 +656,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, #pragma omp parallel for collapse(2) for (index_t b = 0; b < batch0; ++b) { for (index_t c = 0; c < channel; ++c) { - const float *in0_ptr = input0 + ((b * channel) + c) * image_size; - const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); - float *out_ptr = output + ((b * channel) + c) * image_size; + const T *in0_ptr = input0 + ((b * channel) + c) * image_size; + const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); + DstType *out_ptr = output + ((b * channel) + c) * image_size; for (index_t i = 0; i < image_size; ++i) { out_ptr[i] = in1_ptr[c] - in0_ptr[i]; } @@ -636,9 +670,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, #pragma omp parallel for collapse(2) for (index_t b = 0; b < batch0; ++b) { for (index_t c = 0; c < channel; ++c) { - const float *in0_ptr = input0 + ((b * channel) + c) * image_size; - const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); - float *out_ptr = output + ((b * channel) + c) * image_size; + const T *in0_ptr = input0 + ((b * channel) + c) * image_size; + const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); + DstType *out_ptr = output + ((b * channel) + c) * image_size; for (index_t i = 0; i < image_size; ++i) { out_ptr[i] = in0_ptr[i] * in1_ptr[c]; } @@ -650,9 +684,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, #pragma omp parallel for collapse(2) for (index_t b = 0; b < batch0; ++b) { for (index_t c = 0; c < channel; ++c) { - const float *in0_ptr = input0 + ((b * channel) + c) * image_size; - const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); - float *out_ptr = output + ((b * channel) + c) * image_size; + const T *in0_ptr = input0 + ((b * channel) + c) * image_size; + const T *in1_ptr = input1 + (batch1 > 1 ? 
b * channel : 0); + DstType *out_ptr = output + ((b * channel) + c) * image_size; for (index_t i = 0; i < image_size; ++i) { out_ptr[i] = in0_ptr[i] / in1_ptr[c]; } @@ -662,9 +696,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, #pragma omp parallel for collapse(2) for (index_t b = 0; b < batch0; ++b) { for (index_t c = 0; c < channel; ++c) { - const float *in0_ptr = input0 + ((b * channel) + c) * image_size; - const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); - float *out_ptr = output + ((b * channel) + c) * image_size; + const T *in0_ptr = input0 + ((b * channel) + c) * image_size; + const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); + DstType *out_ptr = output + ((b * channel) + c) * image_size; for (index_t i = 0; i < image_size; ++i) { out_ptr[i] = in1_ptr[c] / in0_ptr[i]; } @@ -676,9 +710,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, #pragma omp parallel for collapse(2) for (index_t b = 0; b < batch0; ++b) { for (index_t c = 0; c < channel; ++c) { - const float *in0_ptr = input0 + ((b * channel) + c) * image_size; - const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); - float *out_ptr = output + ((b * channel) + c) * image_size; + const T *in0_ptr = input0 + ((b * channel) + c) * image_size; + const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); + DstType *out_ptr = output + ((b * channel) + c) * image_size; for (index_t i = 0; i < image_size; ++i) { out_ptr[i] = std::min(in0_ptr[i], in1_ptr[c]); } @@ -689,9 +723,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, #pragma omp parallel for collapse(2) for (index_t b = 0; b < batch0; ++b) { for (index_t c = 0; c < channel; ++c) { - const float *in0_ptr = input0 + ((b * channel) + c) * image_size; - const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); - float *out_ptr = output + ((b * channel) + c) * image_size; + const T *in0_ptr = input0 + ((b * channel) + c) * image_size; + const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); + DstType *out_ptr = output + ((b * channel) + c) * image_size; for (index_t i = 0; i < image_size; ++i) { out_ptr[i] = std::max(in0_ptr[i], in1_ptr[c]); } @@ -702,9 +736,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, #pragma omp parallel for collapse(2) for (index_t b = 0; b < batch0; ++b) { for (index_t c = 0; c < channel; ++c) { - const float *in0_ptr = input0 + ((b * channel) + c) * image_size; - const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); - float *out_ptr = output + ((b * channel) + c) * image_size; + const T *in0_ptr = input0 + ((b * channel) + c) * image_size; + const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); + DstType *out_ptr = output + ((b * channel) + c) * image_size; for (index_t i = 0; i < image_size; ++i) { out_ptr[i] = std::pow(in0_ptr[i] - in1_ptr[c], 2.f); } @@ -716,9 +750,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, #pragma omp parallel for collapse(2) for (index_t b = 0; b < batch0; ++b) { for (index_t c = 0; c < channel; ++c) { - const float *in0_ptr = input0 + ((b * channel) + c) * image_size; - const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); - float *out_ptr = output + ((b * channel) + c) * image_size; + const T *in0_ptr = input0 + ((b * channel) + c) * image_size; + const T *in1_ptr = input1 + (batch1 > 1 ? 
b * channel : 0); + DstType *out_ptr = output + ((b * channel) + c) * image_size; for (index_t i = 0; i < image_size; ++i) { out_ptr[i] = std::pow(in0_ptr[i], in1_ptr[c]); } @@ -728,9 +762,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, #pragma omp parallel for collapse(2) for (index_t b = 0; b < batch0; ++b) { for (index_t c = 0; c < channel; ++c) { - const float *in0_ptr = input0 + ((b * channel) + c) * image_size; - const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); - float *out_ptr = output + ((b * channel) + c) * image_size; + const T *in0_ptr = input0 + ((b * channel) + c) * image_size; + const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); + DstType *out_ptr = output + ((b * channel) + c) * image_size; for (index_t i = 0; i < image_size; ++i) { out_ptr[i] = std::pow(in1_ptr[c], in0_ptr[i]); } @@ -750,6 +784,19 @@ inline void TensorEltwisePerChannel(const EltwiseType type, output[i] = std::fabs(input0[i]); } break; + case EQUAL: +#pragma omp parallel for collapse(2) + for (index_t b = 0; b < batch0; ++b) { + for (index_t c = 0; c < channel; ++c) { + const T *in0_ptr = input0 + ((b * channel) + c) * image_size; + const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); + DstType *out_ptr = output + ((b * channel) + c) * image_size; + for (index_t i = 0; i < image_size; ++i) { + out_ptr[i] = in0_ptr[i] == in1_ptr[c]; + } + } + } + break; default: LOG(FATAL) << "Eltwise op not support type " << type; } @@ -769,30 +816,17 @@ struct EltwiseFunctorBase { }; template -struct EltwiseFunctor; - -template <> -struct EltwiseFunctor : EltwiseFunctorBase { +struct EltwiseFunctor : EltwiseFunctorBase { EltwiseFunctor(const EltwiseType type, const std::vector &coeff, - const float value, + const float value, // keep it float as it comes from arg const DataFormat data_format) : EltwiseFunctorBase(type, coeff, value, data_format) {} - MaceStatus operator()(const Tensor *input0, - const Tensor *input1, - Tensor *output, - StatsFuture *future) { - MACE_UNUSED(future); - - if (input1 == nullptr) { - scalar_tensor_.Resize({}); - Tensor::MappingGuard guard(&scalar_tensor_); - auto scalar_data = scalar_tensor_.mutable_data(); - scalar_data[0] = value_; - input1 = &scalar_tensor_; - } - + template + MaceStatus DoEltwise(const Tensor *input0, + const Tensor *input1, + Tensor *output) { bool swapped = false; if (input0->size() < input1->size()) { std::swap(input0, input1); @@ -804,20 +838,16 @@ struct EltwiseFunctor : EltwiseFunctorBase { static_cast(input0->dim_size() - input1->dim_size()); if (data_format_ == NCHW) { MACE_CHECK( - (input0->dim_size() == 4) - && ((input1->dim_size() == 0) - || (input1->dim_size() == 4 - && input1->dim(1) == input0->dim(1) - && (input1->dim(0) == input0->dim(0) - || input1->dim(0) == 1)) - || (input1->dim_size() == 1 - && input1->dim(0) == input0->dim(1))), + (input0->dim_size() == 4) && + ((input1->dim_size() == 0) || + (input1->dim_size() == 4 && input1->dim(1) == input0->dim(1) && + (input1->dim(0) == input0->dim(0) || input1->dim(0) == 1)) || + (input1->dim_size() == 1 && input1->dim(0) == input0->dim(1))), "only support broadcast channel dimension"); } else { for (uint32_t i = 0; i < input1->dim_size(); ++i) { - MACE_CHECK(input0->dim(rank_diff + i) == 1 - || input1->dim(i) == 1 - || input0->dim(rank_diff + i) == input1->dim(i), + MACE_CHECK(input0->dim(rank_diff + i) == 1 || input1->dim(i) == 1 || + input0->dim(rank_diff + i) == input1->dim(i), "Element-Wise op only support tail dimensions broadcast"); } } @@ -825,14 +855,14 @@ 
struct EltwiseFunctor : EltwiseFunctorBase { Tensor::MappingGuard input0_guard(input0); Tensor::MappingGuard input1_guard(input1); - const float *input0_ptr = input0->data(); - const float *input1_ptr = input1->data(); + const T *input0_ptr = input0->data(); + const T *input1_ptr = input1->data(); if (data_format_ == NCHW && input1->dim_size() > 0 && input1->size() < input0->size()) { MACE_RETURN_IF_ERROR(output->ResizeLike(input0)); Tensor::MappingGuard output_guard(output); - float *output_ptr = output->mutable_data(); + DstType *output_ptr = output->mutable_data(); TensorEltwisePerChannel( type_, input0_ptr, input1_ptr, coeff_, input0->dim(0), input1->dim_size() == 1 ? 1 : input1->dim(0), input0->dim(1), @@ -841,8 +871,7 @@ struct EltwiseFunctor : EltwiseFunctorBase { } else { const std::vector &input0_shape = input0->shape(); std::vector input1_shape(rank_diff, 1); - input1_shape.insert(input1_shape.end(), - input1->shape().begin(), + input1_shape.insert(input1_shape.end(), input1->shape().begin(), input1->shape().end()); std::vector output_shape(input0->dim_size(), 0); @@ -851,27 +880,21 @@ struct EltwiseFunctor : EltwiseFunctorBase { } MACE_RETURN_IF_ERROR(output->Resize(output_shape)); Tensor::MappingGuard output_guard(output); - float *output_ptr = output->mutable_data(); + DstType *output_ptr = output->mutable_data(); bool need_general_broadcast = false; for (uint32_t i = 0; i < input1->dim_size(); ++i) { - if ((input0->dim(rank_diff + i) == 1 && input1->dim(i) > 1) - || (input0->dim(rank_diff + i) > 1 && input1->dim(i) == 1)) { + if ((input0->dim(rank_diff + i) == 1 && input1->dim(i) > 1) || + (input0->dim(rank_diff + i) > 1 && input1->dim(i) == 1)) { need_general_broadcast = true; break; } } if (need_general_broadcast) { - TensorGeneralBroadcastEltwise(type_, - input0_ptr, - input1_ptr, - coeff_, - swapped, - input0_shape, - input1_shape, - output_shape, - output_ptr); + TensorGeneralBroadcastEltwise(type_, input0_ptr, input1_ptr, coeff_, + swapped, input0_shape, input1_shape, + output_shape, output_ptr); } else if (input1->size() == input0->size()) { TensorEltwise(type_, input0_ptr, input1_ptr, coeff_, input0->size(), swapped, output_ptr); @@ -891,6 +914,28 @@ struct EltwiseFunctor : EltwiseFunctorBase { return MACE_SUCCESS; } + MaceStatus operator()(const Tensor *input0, + const Tensor *input1, + Tensor *output, + StatsFuture *future) { + MACE_UNUSED(future); + + if (input1 == nullptr) { + scalar_tensor_.Resize({}); + Tensor::MappingGuard guard(&scalar_tensor_); + auto scalar_data = scalar_tensor_.mutable_data(); + scalar_data[0] = static_cast(value_); + input1 = &scalar_tensor_; + } + + if (IsLogicalType(type_)) { + // as we do not have bool-type tensor, we use int type + return DoEltwise(input0, input1, output); + } else { + return DoEltwise(input0, input1, output); + } + } + Tensor scalar_tensor_; }; diff --git a/mace/kernels/strided_slice.h b/mace/kernels/strided_slice.h index 8f09610b98ea374567b25524b957593b194bdc86..eab4a4d5441ed36c6d8f779209127187ed7a6d5a 100644 --- a/mace/kernels/strided_slice.h +++ b/mace/kernels/strided_slice.h @@ -67,6 +67,26 @@ struct StridedSliceFunctor { const T *input_data = input->data(); const int32_t *begin_indices_data = begin_indices->data(); const int32_t *end_indices_data = end_indices->data(); + const int32_t *strides_data = strides->data(); + + std::vector pad_begin_indices(input->dim_size(), 0); + std::vector pad_end_indices(input->dim_size(), 0); + std::vector pad_strides_indices(input->dim_size(), 1); + + if (begin_indices->size() < 
input->dim_size()) { + for (index_t i = 0; i < begin_indices->size(); ++i) { + pad_begin_indices[i] = begin_indices_data[i]; + pad_end_indices[i] = end_indices_data[i]; + pad_strides_indices[i] = strides_data[i]; + } + for (index_t i = begin_indices->size(); i < input->dim_size(); ++i) { + pad_end_indices[i] = input->dim(i); + } + begin_indices_data = pad_begin_indices.data(); + end_indices_data = pad_end_indices.data(); + strides_data = pad_strides_indices.data(); + } + std::vector slice_end_data; if (is_slice_) { // if this op is slice, the end_indices_data is size actually @@ -80,7 +100,6 @@ struct StridedSliceFunctor { } end_indices_data = slice_end_data.data(); } - const int32_t *strides_data = strides->data(); std::vector output_shape; std::vector real_begin_indices(input->dim_size(), 0); diff --git a/mace/ops/argmax.cc b/mace/ops/argmax.cc new file mode 100644 index 0000000000000000000000000000000000000000..977cbbc6b238b1f909ca4e5ce06c5c81cc9ea36f --- /dev/null +++ b/mace/ops/argmax.cc @@ -0,0 +1,29 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/argmax.h" + +namespace mace { +namespace ops { + +void Register_ArgMax(OperatorRegistry *op_registry) { + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ArgMax") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + ArgMaxOp); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/argmax.h b/mace/ops/argmax.h new file mode 100644 index 0000000000000000000000000000000000000000..ce493059387fc0d6aff802b7db053b9e47c8cfcb --- /dev/null +++ b/mace/ops/argmax.h @@ -0,0 +1,49 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef MACE_OPS_ARGMAX_H_ +#define MACE_OPS_ARGMAX_H_ + +#include + +#include "mace/core/operator.h" +#include "mace/kernels/argmax.h" + +namespace mace { +namespace ops { + +template +class ArgMaxOp : public Operator { + public: + ArgMaxOp(const OperatorDef &operator_def, Workspace *ws) + : Operator(operator_def, ws) {} + + MaceStatus Run(StatsFuture *future) override { + const Tensor *input = this->Input(0); + const Tensor *axis = this->Input(1); + Tensor *output = this->Output(0); + return functor_(input, axis, output, future); + } + + private: + kernels::ArgMaxFunctor functor_; + + MACE_OP_INPUT_TAGS(INPUT, AXIS); + MACE_OP_OUTPUT_TAGS(OUTPUT); +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_ARGMAX_H_ diff --git a/mace/ops/argmax_test.cc b/mace/ops/argmax_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf00b57933969f394c61593baa10518085b3c92a --- /dev/null +++ b/mace/ops/argmax_test.cc @@ -0,0 +1,68 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class ArgMaxOpTest : public OpsTestBase {}; + +namespace { +template +void ArgMaxTest(const std::vector &input_shape, + const std::vector &input, + const std::vector &output_shape, + const std::vector &output) { + OpsTestNet net; + + // Add input data + net.AddInputFromArray("Input", input_shape, input); + net.AddInputFromArray("axis", {}, {-1}); + + if (D == DeviceType::CPU) { + OpDefBuilder("ArgMax", "ArgMaxTest") + .Input("Input") + .Input("axis") + .Output("Output") + .OutputType({DT_INT32}) + .Finalize(net.NewOperatorDef()); + // Run + net.RunOp(D); + } else { + MACE_NOT_IMPLEMENTED; + } + + // Check + auto expected = CreateTensor(output_shape, output); + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); +} +} // namespace + +TEST_F(ArgMaxOpTest, Vector) { ArgMaxTest({3}, {-3, -1, -2}, {}, {1}); } + +TEST_F(ArgMaxOpTest, Matrix) { + ArgMaxTest({3, 3}, {4, 5, 6, 9, 8, 7, 1, 2, 3}, {3}, {2, 0, 2}); +} + +TEST_F(ArgMaxOpTest, HighRank) { + ArgMaxTest({1, 2, 2, 3}, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, + {1, 2, 2}, {2, 2, 2, 2}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/cast.h b/mace/ops/cast.h index a8b283d01fc873679f5ee5a3cb5d717a1dddc893..cee022ec4aedb9b848e9dc46b3e564e561c08b36 100644 --- a/mace/ops/cast.h +++ b/mace/ops/cast.h @@ -22,11 +22,11 @@ namespace mace { namespace ops { -template -class CastOp : public Operator { +template +class CastOp : public Operator { public: CastOp(const OperatorDef &op_def, Workspace *ws) - : Operator(op_def, ws) {} + : Operator(op_def, ws) {} MaceStatus Run(StatsFuture *future) override { MACE_UNUSED(future); @@ -36,17 +36,16 @@ class CastOp : public Operator { Tensor::MappingGuard input_guard(input); Tensor::MappingGuard output_guard(output); - auto src_dtype = input->dtype(); - - auto output_data = 
output->mutable_data
(); + auto dst_dtype = output->dtype(); #define MACE_CAST_COPY \ - auto input_data = input->data(); \ + auto output_data = output->mutable_data(); \ + auto input_data = input->data(); \ for (index_t i = 0; i < output->size(); ++i) { \ - output_data[i] = static_cast
(input_data[i]); \ + output_data[i] = static_cast(input_data[i]); \ } - MACE_RUN_WITH_TYPE_ENUM(src_dtype, MACE_CAST_COPY); + MACE_RUN_WITH_TYPE_ENUM(dst_dtype, MACE_CAST_COPY); return MACE_SUCCESS; } diff --git a/mace/ops/cast_test.cc b/mace/ops/cast_test.cc index e483f429054dcc916f5fc9c7031d0f43a1d141cf..f35d3af6ee6237db274e891e6db37e1fd31fa366 100644 --- a/mace/ops/cast_test.cc +++ b/mace/ops/cast_test.cc @@ -30,8 +30,9 @@ void TestCast(const std::vector &input_shape, OpsTestNet net; OpDefBuilder("Cast", "CastTest") .Input("Input") + .OutputType({DataTypeToEnum::v()}) .Output("Output") - .AddIntArg("T", DataTypeToEnum::v()) + .AddIntArg("T", DataTypeToEnum::v()) .Finalize(net.NewOperatorDef()); // Add input data @@ -55,10 +56,12 @@ void TestCast(const std::vector &input_shape, TEST_F(CastOpTest, TestCastFromFloatToInt32) { TestCast({1, 2, 3}, {1.1, 2.2, 3.3, 4.4, 5.5, 6.6}); + TestCast({}, {3.3}); } TEST_F(CastOpTest, TestCastFromInt32ToFloat) { TestCast({1, 2, 3}, {1, 2, 3, 4, 5, 6}); + TestCast({}, {3}); } } // namespace test diff --git a/mace/ops/eltwise.cc b/mace/ops/eltwise.cc index bbb214352d3007e53d87ac859050bee5146d59d0..81050b16e4f5e030e6ff210f9022c3f866cdbe6c 100644 --- a/mace/ops/eltwise.cc +++ b/mace/ops/eltwise.cc @@ -23,6 +23,11 @@ void Register_Eltwise(OperatorRegistry *op_registry) { .TypeConstraint("T") .Build(), EltwiseOp); + MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Eltwise") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + EltwiseOp); #ifdef MACE_ENABLE_OPENCL MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Eltwise") diff --git a/mace/ops/eltwise_test.cc b/mace/ops/eltwise_test.cc index 84ec9d06f6a5f77b1633cfebe00923373b2b768b..ddf113d8f9dd3ef6db847af24332f1d5eb35b918 100644 --- a/mace/ops/eltwise_test.cc +++ b/mace/ops/eltwise_test.cc @@ -23,30 +23,63 @@ namespace test { class EltwiseOpTest : public OpsTestBase {}; namespace { -template +template +void SimpleScalarScalar(const kernels::EltwiseType type, + const T input, + const float x, + const DstType output) { + // Construct graph + OpsTestNet net; + + // Add input data + net.AddInputFromArray("Input", {}, {input}); + + if (D == DeviceType::CPU) { + OpDefBuilder("Eltwise", "EltwiseTest") + .Input("Input") + .AddIntArg("T", DataTypeToEnum::v()) + .AddIntArg("type", static_cast(type)) + .AddFloatArg("value", x) + .OutputType({kernels::IsLogicalType(type) ? DT_INT32 : DT_FLOAT}) + .Output("Output") + .Finalize(net.NewOperatorDef()); + // Run + net.RunOp(D); + } else { + MACE_NOT_IMPLEMENTED; + } + + auto expected = CreateTensor({}, {output}); + + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); +} + +template void SimpleTensorScalar(const kernels::EltwiseType type, const std::vector &shape, - const std::vector &input, + const std::vector &input, const float x, - const std::vector &output) { + const std::vector &output) { // Construct graph OpsTestNet net; // Add input data - net.AddInputFromArray("Input", shape, input); + net.AddInputFromArray("Input", shape, input); if (D == DeviceType::CPU) { - net.TransformDataFormat("Input", NHWC, "TInput", NCHW); + net.TransformDataFormat("Input", NHWC, "TInput", NCHW); OpDefBuilder("Eltwise", "EltwiseTest") .Input("TInput") + .AddIntArg("T", DataTypeToEnum::v()) .AddIntArg("type", static_cast(type)) .AddFloatArg("value", x) .AddIntArg("data_format", DataFormat::NCHW) + .OutputType({kernels::IsLogicalType(type) ? 
DT_INT32 : DT_FLOAT}) .Output("TOutput") .Finalize(net.NewOperatorDef()); // Run net.RunOp(D); - net.TransformDataFormat("TOutput", NCHW, "Output", NHWC); + net.TransformDataFormat("TOutput", NCHW, "Output", NHWC); } else { BufferToImage(&net, "Input", "InputImg", kernels::BufferType::IN_OUT_CHANNEL); @@ -60,44 +93,47 @@ void SimpleTensorScalar(const kernels::EltwiseType type, // Run net.RunOp(D); - ImageToBuffer(&net, "OutputImg", "Output", - kernels::BufferType::IN_OUT_CHANNEL); + ImageToBuffer(&net, "OutputImg", "Output", + kernels::BufferType::IN_OUT_CHANNEL); } - auto expected = CreateTensor(shape, output); + auto expected = CreateTensor(shape, output); - ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); } -template +template void SimpleTensorEltwise(const kernels::EltwiseType type, const std::vector &shape0, - const std::vector &input0, + const std::vector &input0, const std::vector &shape1, - const std::vector &input1, - const std::vector &output, + const std::vector &input1, + const std::vector &output, const std::vector &coeff = {}) { // Construct graph OpsTestNet net; // Add input data - net.AddInputFromArray("Input0", shape0, input0); - net.AddInputFromArray("Input1", shape1, input1); + net.AddInputFromArray("Input0", shape0, input0); + net.AddInputFromArray("Input1", shape1, input1); if (D == DeviceType::CPU) { - auto op_builder = OpDefBuilder("Eltwise", "EltwiseTest") - .AddIntArg("type", static_cast(type)) - .AddFloatsArg("coeff", coeff) - .AddIntArg("data_format", DataFormat::NCHW) - .Output("TOutput"); + auto op_builder = + OpDefBuilder("Eltwise", "EltwiseTest") + .AddIntArg("T", DataTypeToEnum::v()) + .AddIntArg("type", static_cast(type)) + .AddFloatsArg("coeff", coeff) + .AddIntArg("data_format", DataFormat::NCHW) + .OutputType({kernels::IsLogicalType(type) ? 
DT_INT32 : DT_FLOAT}) + .Output("TOutput"); if (shape0.size() > 1) { - net.TransformDataFormat("Input0", NHWC, "TInput0", NCHW); + net.TransformDataFormat("Input0", NHWC, "TInput0", NCHW); op_builder.Input("TInput0"); } else { op_builder.Input("Input0"); } if (shape1.size() > 1) { - net.TransformDataFormat("Input1", NHWC, "TInput1", NCHW); + net.TransformDataFormat("Input1", NHWC, "TInput1", NCHW); op_builder.Input("TInput1"); } else { op_builder.Input("Input1"); @@ -106,7 +142,7 @@ void SimpleTensorEltwise(const kernels::EltwiseType type, // Run net.RunOp(D); - net.TransformDataFormat("TOutput", NCHW, "Output", NHWC); + net.TransformDataFormat("TOutput", NCHW, "Output", NHWC); } else { BufferToImage(&net, "Input0", "InputImg0", kernels::BufferType::IN_OUT_CHANNEL); @@ -123,42 +159,45 @@ void SimpleTensorEltwise(const kernels::EltwiseType type, // Run net.RunOp(D); - ImageToBuffer(&net, "OutputImg", "Output", - kernels::BufferType::IN_OUT_CHANNEL); + ImageToBuffer(&net, "OutputImg", "Output", + kernels::BufferType::IN_OUT_CHANNEL); } std::vector output_shape = shape0; if (input0.size() < input1.size()) { output_shape = shape1; } - auto expected = CreateTensor(output_shape, output); + auto expected = CreateTensor(output_shape, output); - ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); } -template +template void TensorGeneralBroadcastEltwise(const kernels::EltwiseType type, const std::vector &shape0, - const std::vector &input0, + const std::vector &input0, const std::vector &shape1, - const std::vector &input1, + const std::vector &input1, const std::vector &output_shape, - const std::vector &output, + const std::vector &output, const std::vector &coeff = {}) { // Construct graph OpsTestNet net; // Add input data - net.AddInputFromArray("Input0", shape0, input0); - net.AddInputFromArray("Input1", shape1, input1); + net.AddInputFromArray("Input0", shape0, input0); + net.AddInputFromArray("Input1", shape1, input1); if (D == DeviceType::CPU) { - auto op_builder = OpDefBuilder("Eltwise", "EltwiseTest") - .Input("Input0") - .Input("Input1") - .AddIntArg("type", static_cast(type)) - .AddFloatsArg("coeff", coeff) - .Output("Output"); + auto op_builder = + OpDefBuilder("Eltwise", "EltwiseTest") + .AddIntArg("T", DataTypeToEnum::v()) + .Input("Input0") + .Input("Input1") + .AddIntArg("type", static_cast(type)) + .AddFloatsArg("coeff", coeff) + .OutputType({kernels::IsLogicalType(type) ? 
DT_INT32 : DT_FLOAT}) + .Output("Output"); op_builder.Finalize(net.NewOperatorDef()); // Run @@ -167,214 +206,248 @@ void TensorGeneralBroadcastEltwise(const kernels::EltwiseType type, MACE_NOT_IMPLEMENTED; } - auto expected = CreateTensor(output_shape, output); - ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); + auto expected = CreateTensor(output_shape, output); + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5); } } // namespace +TEST_F(EltwiseOpTest, CPUSimpleScalarScalar) { + SimpleScalarScalar( + kernels::EltwiseType::SUM, 1, 2, 3); + SimpleScalarScalar( + kernels::EltwiseType::SUB, 1, 2, -1); + SimpleScalarScalar( + kernels::EltwiseType::PROD, 1, 2, 2); + SimpleScalarScalar( + kernels::EltwiseType::DIV, 1, 2, 0.5); + SimpleScalarScalar( + kernels::EltwiseType::MIN, 1, 2, 1); + SimpleScalarScalar( + kernels::EltwiseType::MAX, 1, 2, 2); + SimpleScalarScalar( + kernels::EltwiseType::NEG, 1, 2, -1); + SimpleScalarScalar( + kernels::EltwiseType::ABS, -1, 3, 1); + SimpleScalarScalar( + kernels::EltwiseType::EQUAL, 1, 3, 0); + SimpleScalarScalar( + kernels::EltwiseType::EQUAL, 3, 3, 1); +} + TEST_F(EltwiseOpTest, CPUSimpleTensorScalar) { - SimpleTensorScalar(kernels::EltwiseType::SUM, - {1, 1, 1, 1}, {1}, 1, {2}); - SimpleTensorScalar(kernels::EltwiseType::SUB, - {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, - 1, {0, 1, 2, 3, 4, 5}); - SimpleTensorScalar(kernels::EltwiseType::PROD, - {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, - 2, {2, 4, 6, 8, 10, 12}); - SimpleTensorScalar(kernels::EltwiseType::DIV, - {1, 1, 2, 3}, {2, 4, 6, 8, 10, 12}, - 2, {1, 2, 3, 4, 5, 6}); - SimpleTensorScalar(kernels::EltwiseType::MIN, - {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, - 1, {1, 1, 1, 1, 1, 1}); - SimpleTensorScalar(kernels::EltwiseType::MAX, - {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, - 3, {3, 3, 3, 4, 5, 6}); - SimpleTensorScalar(kernels::EltwiseType::NEG, - {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, - 3, {-1, -2, -3, -4, -5, -6}); - SimpleTensorScalar( + SimpleTensorScalar(kernels::EltwiseType::SUM, + {1, 1, 1, 1}, {1}, 1, {2}); + SimpleTensorScalar( + kernels::EltwiseType::SUB, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 1, + {0, 1, 2, 3, 4, 5}); + SimpleTensorScalar( + kernels::EltwiseType::PROD, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 2, + {2, 4, 6, 8, 10, 12}); + SimpleTensorScalar( + kernels::EltwiseType::DIV, {1, 1, 2, 3}, {2, 4, 6, 8, 10, 12}, 2, + {1, 2, 3, 4, 5, 6}); + SimpleTensorScalar( + kernels::EltwiseType::MIN, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 1, + {1, 1, 1, 1, 1, 1}); + SimpleTensorScalar( + kernels::EltwiseType::MAX, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 3, + {3, 3, 3, 4, 5, 6}); + SimpleTensorScalar( + kernels::EltwiseType::NEG, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 3, + {-1, -2, -3, -4, -5, -6}); + SimpleTensorScalar( kernels::EltwiseType::ABS, {1, 1, 2, 3}, {-1, -2, -3, -4, -5, -6}, 3, {1, 2, 3, 4, 5, 6}); - SimpleTensorScalar(kernels::EltwiseType::SQR_DIFF, - {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, - 1, {0, 1, 4, 9, 16, 25}); + SimpleTensorScalar( + kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 1, + {0, 1, 4, 9, 16, 25}); + SimpleTensorScalar( + kernels::EltwiseType::EQUAL, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 3, + {0, 0, 1, 0, 0, 0}); } TEST_F(EltwiseOpTest, GPUSimpleTensorScalar) { - SimpleTensorScalar(kernels::EltwiseType::SUM, - {1, 1, 1, 1}, {1}, 1, {2}); - SimpleTensorScalar(kernels::EltwiseType::SUB, - {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, - 1, {0, 1, 2, 3, 4, 5}); - SimpleTensorScalar(kernels::EltwiseType::PROD, - {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, - 2, {2, 4, 6, 8, 10, 12}); - 
SimpleTensorScalar(kernels::EltwiseType::DIV, - {1, 1, 2, 3}, {2, 4, 6, 8, 10, 12}, - 2, {1, 2, 3, 4, 5, 6}); - SimpleTensorScalar(kernels::EltwiseType::MIN, - {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, - 1, {1, 1, 1, 1, 1, 1}); - SimpleTensorScalar(kernels::EltwiseType::MAX, - {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, - 3, {3, 3, 3, 4, 5, 6}); - SimpleTensorScalar(kernels::EltwiseType::NEG, - {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, - 3, {-1, -2, -3, -4, -5, -6}); - SimpleTensorScalar( + SimpleTensorScalar(kernels::EltwiseType::SUM, + {1, 1, 1, 1}, {1}, 1, {2}); + SimpleTensorScalar( + kernels::EltwiseType::SUB, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 1, + {0, 1, 2, 3, 4, 5}); + SimpleTensorScalar( + kernels::EltwiseType::PROD, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 2, + {2, 4, 6, 8, 10, 12}); + SimpleTensorScalar( + kernels::EltwiseType::DIV, {1, 1, 2, 3}, {2, 4, 6, 8, 10, 12}, 2, + {1, 2, 3, 4, 5, 6}); + SimpleTensorScalar( + kernels::EltwiseType::MIN, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 1, + {1, 1, 1, 1, 1, 1}); + SimpleTensorScalar( + kernels::EltwiseType::MAX, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 3, + {3, 3, 3, 4, 5, 6}); + SimpleTensorScalar( + kernels::EltwiseType::NEG, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 3, + {-1, -2, -3, -4, -5, -6}); + SimpleTensorScalar( kernels::EltwiseType::ABS, {1, 1, 2, 3}, {-1, -2, -3, -4, -5, -6}, 3, {1, 2, 3, 4, 5, 6}); - SimpleTensorScalar(kernels::EltwiseType::SQR_DIFF, - {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, - 1, {0, 1, 4, 9, 16, 25}); + SimpleTensorScalar( + kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 1, + {0, 1, 4, 9, 16, 25}); } TEST_F(EltwiseOpTest, CPUSimpleTensorVector) { - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 1, 3}, {1, 2, 3}, {2, 4, 6, 5, 7, 9}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::SUB, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {0, 0, 0, 0, 0, 5, 5, 5, 5, 5}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::SUB, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 0, 0, 0, 0, -5, -5, -5, -5, -5}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::PROD, {1, 1, 1, 3}, {1, 2, 3}, {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6}, {1, 4, 9, 4, 10, 18}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::DIV, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {1, 1, 1, 5}, {1, 1, 1, 1, 5}, {1, 2, 3, 4, 1, 6, 7, 8, 9, 2}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::DIV, {1, 1, 1, 5}, {1, 1, 1, 2, 4}, {1, 2, 1, 5}, {1, 1, 1, 2, 2, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 2, 1, 1, 1, 2, 4}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::MIN, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::MAX, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::SQR_DIFF, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 0, 0, 0, 0, 25, 25, 25, 25, 25}); + SimpleTensorEltwise( + kernels::EltwiseType::EQUAL, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, + {1, 1, 1, 3}, {1, 2, 3}, {1, 1, 1, 0, 0, 0}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {3}, {1, 2, 3}, {2, 4, 6, 5, 7, 9}); - SimpleTensorEltwise( + 
SimpleTensorEltwise( kernels::EltwiseType::SUB, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {5}, {1, 2, 3, 4, 5}, {0, 0, 0, 0, 0, 5, 5, 5, 5, 5}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::SUB, {5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 0, 0, 0, 0, -5, -5, -5, -5, -5}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::PROD, {3}, {1, 2, 3}, {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6}, {1, 4, 9, 4, 10, 18}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::DIV, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {5}, {1, 1, 1, 1, 5}, {1, 2, 3, 4, 1, 6, 7, 8, 9, 2}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::DIV, {5}, {1, 1, 1, 2, 4}, {1, 2, 1, 5}, {1, 1, 1, 2, 2, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 2, 1, 1, 1, 2, 4}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::MIN, {5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::MAX, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {5}, {1, 2, 3, 4, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); - SimpleTensorEltwise( - kernels::EltwiseType::SQR_DIFF, {5}, {1, 2, 3, 4, 5}, - {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, - {0, 0, 0, 0, 0, 25, 25, 25, 25, 25}); + SimpleTensorEltwise( + kernels::EltwiseType::SQR_DIFF, {5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5}, + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 0, 0, 0, 0, 25, 25, 25, 25, 25}); + SimpleTensorEltwise( + kernels::EltwiseType::EQUAL, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {3}, + {1, 2, 3}, {1, 1, 1, 0, 0, 0}); } TEST_F(EltwiseOpTest, GPUSimpleTensorVector) { - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 1, 3}, {1, 2, 3}, {2, 4, 6, 5, 7, 9}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::SUB, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {0, 0, 0, 0, 0, 5, 5, 5, 5, 5}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::SUB, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 0, 0, 0, 0, -5, -5, -5, -5, -5}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::PROD, {1, 1, 1, 3}, {1, 2, 3}, {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6}, {1, 4, 9, 4, 10, 18}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::DIV, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {1, 1, 1, 5}, {1, 1, 1, 1, 5}, {1, 2, 3, 4, 1, 6, 7, 8, 9, 2}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::DIV, {1, 1, 1, 5}, {1, 1, 1, 2, 4}, {1, 2, 1, 5}, {1, 1, 1, 2, 2, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 2, 1, 1, 1, 2, 4}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::MIN, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::MAX, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::SQR_DIFF, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 0, 0, 0, 0, 25, 25, 25, 25, 25}); } TEST_F(EltwiseOpTest, CPUSimpleTensorTensor) { - SimpleTensorEltwise( + SimpleTensorEltwise( kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {2, 4, 6, 8, 10, 12}); - SimpleTensorEltwise( + 
SimpleTensorEltwise(
       kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6},
       {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {0.2, 0.4, 0.6, 0.8, 1, 1.2},
       {0.1, 0.1});
-  SimpleTensorEltwise(
+  SimpleTensorEltwise(
       kernels::EltwiseType::SUB, {1, 1, 1, 5}, {1, 2, 3, 4, 5},
       {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {0, 0, 0, 0, 0});
-  SimpleTensorEltwise(
+  SimpleTensorEltwise(
       kernels::EltwiseType::PROD, {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6},
       {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6}, {1, 4, 9, 16, 25, 36});
-  SimpleTensorEltwise(
+  SimpleTensorEltwise(
       kernels::EltwiseType::DIV, {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6},
       {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 1, 1, 1, 1});
-  SimpleTensorEltwise(
+  SimpleTensorEltwise(
       kernels::EltwiseType::MIN, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
       {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
       {1, 2, 3, 4, 5, 1, 2, 3, 4, 5});
-  SimpleTensorEltwise(
+  SimpleTensorEltwise(
       kernels::EltwiseType::MAX, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
       {1, 2, 1, 5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
       {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
-  SimpleTensorEltwise(
+  SimpleTensorEltwise(
       kernels::EltwiseType::SQR_DIFF, {1, 2, 1, 5},
       {1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, {1, 2, 1, 5},
       {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 0, 0, 0, 0, 25, 25, 25, 25, 25});
+  SimpleTensorEltwise(
+      kernels::EltwiseType::EQUAL, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6},
+      {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 1, 1, 1, 1});
 }
 
 TEST_F(EltwiseOpTest, GPUSimpleTensorTensor) {
-  SimpleTensorEltwise(
+  SimpleTensorEltwise(
       kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6},
       {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {2, 4, 6, 8, 10, 12});
-  SimpleTensorEltwise(
+  SimpleTensorEltwise(
       kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6},
       {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {0.2, 0.4, 0.6, 0.8, 1, 1.2},
       {0.1, 0.1});
-  SimpleTensorEltwise(
+  SimpleTensorEltwise(
       kernels::EltwiseType::SUB, {1, 1, 1, 5}, {1, 2, 3, 4, 5},
       {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {0, 0, 0, 0, 0});
-  SimpleTensorEltwise(
+  SimpleTensorEltwise(
       kernels::EltwiseType::PROD, {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6},
       {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6}, {1, 4, 9, 16, 25, 36});
-  SimpleTensorEltwise(
+  SimpleTensorEltwise(
       kernels::EltwiseType::DIV, {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6},
       {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 1, 1, 1, 1});
-  SimpleTensorEltwise(
+  SimpleTensorEltwise(
       kernels::EltwiseType::MIN, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
       {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
       {1, 2, 3, 4, 5, 1, 2, 3, 4, 5});
-  SimpleTensorEltwise(
+  SimpleTensorEltwise(
       kernels::EltwiseType::MAX, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
       {1, 2, 1, 5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
       {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
-  SimpleTensorEltwise(
+  SimpleTensorEltwise(
       kernels::EltwiseType::SQR_DIFF, {1, 2, 1, 5},
       {1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, {1, 2, 1, 5},
       {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 0, 0, 0, 0, 25, 25, 25, 25, 25});
@@ -595,27 +668,30 @@ TEST_F(EltwiseOpTest, RandomTensorTensorHalf) {
 }
 
 TEST_F(EltwiseOpTest, TensorGeneralBroadcast) {
-  TensorGeneralBroadcastEltwise(
+  TensorGeneralBroadcastEltwise(
       kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6},
       {1, 1, 2, 1}, {1, 2}, {1, 1, 2, 3}, {2, 3, 4, 6, 7, 8});
-  TensorGeneralBroadcastEltwise(
+  TensorGeneralBroadcastEltwise(
       kernels::EltwiseType::SUB, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6},
       {1, 1, 2, 1}, {1, 2}, {1, 1, 2, 3}, {0, 1, 2, 2, 3, 4});
-  TensorGeneralBroadcastEltwise(
+  TensorGeneralBroadcastEltwise(
       kernels::EltwiseType::PROD, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6},
       {1, 1, 2, 1}, {1, 2}, {1, 1, 2, 3}, {1, 2, 3, 8, 10, 12});
-  TensorGeneralBroadcastEltwise(
+  TensorGeneralBroadcastEltwise(
       kernels::EltwiseType::DIV, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6},
       {1, 1, 2, 1}, {1, 2}, {1, 1, 2, 3}, {1, 2, 3, 2, 2.5, 3});
-  TensorGeneralBroadcastEltwise(
+  TensorGeneralBroadcastEltwise(
       kernels::EltwiseType::MIN, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6},
       {1, 1, 2, 1}, {1, 2}, {1, 1, 2, 3}, {1, 1, 1, 2, 2, 2});
-  TensorGeneralBroadcastEltwise(
+  TensorGeneralBroadcastEltwise(
       kernels::EltwiseType::MAX, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6},
       {1, 1, 2, 1}, {1, 2}, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6});
-  TensorGeneralBroadcastEltwise(
+  TensorGeneralBroadcastEltwise(
       kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6},
       {1, 1, 2, 1}, {1, 2}, {1, 1, 2, 3}, {0, 1, 4, 4, 9, 16});
+  TensorGeneralBroadcastEltwise(
+      kernels::EltwiseType::EQUAL, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6},
+      {1, 1, 2, 1}, {1, 2}, {1, 1, 2, 3}, {1, 0, 0, 0, 0, 0});
 }
 
 }  // namespace test
diff --git a/mace/ops/identity.cc b/mace/ops/identity.cc
index ed89561a231d08438f35ae1ec53ecf45c0c806b5..628bfd2d593ec9b817221bc9e5852d3a2ceeef49 100644
--- a/mace/ops/identity.cc
+++ b/mace/ops/identity.cc
@@ -23,6 +23,11 @@ void Register_Identity(OperatorRegistry *op_registry) {
                              .TypeConstraint("T")
                              .Build(),
                          IdentityOp);
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Identity")
+                             .Device(DeviceType::CPU)
+                             .TypeConstraint("T")
+                             .Build(),
+                         IdentityOp);
 
 #ifdef MACE_ENABLE_OPENCL
   MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Identity")
diff --git a/mace/ops/strided_slice_test.cc b/mace/ops/strided_slice_test.cc
index 8662367eddc775a2142878430c39cc93c364ba15..322f1135d14ce281ecc14baf10bb2eb102e9a8d6 100644
--- a/mace/ops/strided_slice_test.cc
+++ b/mace/ops/strided_slice_test.cc
@@ -38,12 +38,12 @@ void TestStridedSlice(const std::vector &input_shape,
   OpsTestNet net;
   net.AddInputFromArray("Input", input_shape, input);
   net.AddInputFromArray(
-      "BeginIndices", {static_cast(input_shape.size())},
+      "BeginIndices", {static_cast(begin_indices.size())},
       begin_indices);
   net.AddInputFromArray(
-      "EndIndices", {static_cast(input_shape.size())}, end_indices);
+      "EndIndices", {static_cast(end_indices.size())}, end_indices);
   net.AddInputFromArray(
-      "Strides", {static_cast(input_shape.size())}, strides);
+      "Strides", {static_cast(strides.size())}, strides);
 
   OpDefBuilder("StridedSlice", "StridedSliceOpTest")
       .Input("Input")
@@ -130,6 +130,8 @@ TEST_F(StridedSliceOpTest, TestStridedSliceRank1) {
 TEST_F(StridedSliceOpTest, TestStridedSliceRank2) {
   TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {1, 1}, 0, 0, 0,
                    0, 0, {2, 3}, {1, 2, 3, 4, 5, 6});
+  TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0}, {2}, {1}, 0, 0, 0,
+                   0, 0, {2, 3}, {1, 2, 3, 4, 5, 6});
   TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1}, {2, 3}, {1, 1}, 0, 0, 0,
                    0, 0, {1, 2}, {5, 6});
   TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {1, 2}, 0, 0, 0,
diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py
index b3731b9803524f6c06f96de5548199a7295715fa..956c9d31d996f93bf0b6a337cc56318247ecc129 100644
--- a/mace/python/tools/converter_tool/base_converter.py
+++ b/mace/python/tools/converter_tool/base_converter.py
@@ -66,11 +66,13 @@ class EltwiseType(Enum):
     ABS = 7
     SQR_DIFF = 8
     POW = 9
+    EQUAL = 10
 
 
 MaceSupportedOps = [
     'Activation',
     'AddN',
+    'ArgMax',
     'BatchNorm',
     'BatchToSpaceND',
     'BiasAdd',
diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py
index 99c077c9994bcf97e69301f9d337191321e39658..0f5f3e21fb4b20cb6230f96b22695136bcd6415e 100644
--- a/mace/python/tools/converter_tool/tensorflow_converter.py
+++ b/mace/python/tools/converter_tool/tensorflow_converter.py
@@ -62,6 +62,7 @@ TFSupportedOps = [
     'Square',
     'SquaredDifference',
     'Rsqrt',
+    'Equal',
     'Relu',
     'Relu6',
     'Tanh',
@@ -93,6 +94,7 @@ TFSupportedOps = [
     'Stack',
     'Pack',
     'Cast',
+    'ArgMax',
 ]
 
 TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str)
@@ -125,7 +127,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
         TFOpType.RealDiv.name: EltwiseType.DIV,
         TFOpType.SquaredDifference.name: EltwiseType.SQR_DIFF,
         TFOpType.Square.name: EltwiseType.POW,
-        TFOpType.Rsqrt.name: EltwiseType.POW
+        TFOpType.Rsqrt.name: EltwiseType.POW,
+        TFOpType.Equal.name: EltwiseType.EQUAL,
     }
     activation_type = {
         TFOpType.Relu.name: ActivationType.RELU,
@@ -153,6 +156,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
             TFOpType.SquaredDifference.name: self.convert_elementwise,
             TFOpType.Square.name: self.convert_elementwise,
             TFOpType.Rsqrt.name: self.convert_elementwise,
+            TFOpType.Equal.name: self.convert_elementwise,
             TFOpType.Relu.name: self.convert_activation,
             TFOpType.Relu6.name: self.convert_activation,
             TFOpType.Tanh.name: self.convert_activation,
@@ -183,7 +187,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
             TFOpType.Slice.name: self.convert_slice,
             TFOpType.Pack.name: self.convert_stack,
             TFOpType.Stack.name: self.convert_stack,
-            TFOpType.Cast.name: self.convert_cast
+            TFOpType.Cast.name: self.convert_cast,
+            TFOpType.ArgMax.name: self.convert_argmax,
         }
         self._option = option
         self._mace_net_def = mace_pb2.NetDef()
@@ -376,18 +381,29 @@ class TensorflowConverter(base_converter.ConverterInterface):
         if type_arg.i != EltwiseType.NEG.value \
                 and type_arg.i != EltwiseType.ABS.value:
-            if len(tf_op.inputs[0].shape) == 0:
-                value_arg = op.arg.add()
-                value_arg.name = MaceKeyword.mace_value_str
-                value_arg.f = tf_op.inputs[0].eval().astype(np.float32)
-                self._skip_tensor.add(tf_op.inputs[0].name)
-                del op.input[0]
-            elif len(tf_op.inputs) > 1 and len(tf_op.inputs[1].shape) == 0:
-                value_arg = op.arg.add()
-                value_arg.name = MaceKeyword.mace_value_str
-                value_arg.f = tf_op.inputs[1].eval().astype(np.float32)
-                self._skip_tensor.add(tf_op.inputs[1].name)
-                del op.input[1]
+            try:
+                def is_commutative(eltwise_type):
+                    return EltwiseType(eltwise_type) in [
+                        EltwiseType.SUM, EltwiseType.PROD,
+                        EltwiseType.MAX, EltwiseType.MIN]
+
+                if len(tf_op.inputs) > 1 and len(tf_op.inputs[1].shape) == 0:
+                    scalar = tf_op.inputs[1].eval().astype(np.float32)
+                    value_arg = op.arg.add()
+                    value_arg.name = MaceKeyword.mace_value_str
+                    value_arg.f = scalar
+                    self._skip_tensor.add(tf_op.inputs[1].name)
+                    del op.input[1]
+                elif len(tf_op.inputs[0].shape) == 0 and \
+                        is_commutative(type_arg.i):
+                    scalar = tf_op.inputs[0].eval().astype(np.float32)
+                    value_arg = op.arg.add()
+                    value_arg.name = MaceKeyword.mace_value_str
+                    value_arg.f = scalar
+                    self._skip_tensor.add(tf_op.inputs[0].name)
+                    del op.input[0]
+            except tf.errors.InvalidArgumentError:
+                pass
 
     def convert_biasadd(self, tf_op):
         op = self.convert_general_op(tf_op)
@@ -550,7 +566,13 @@ class TensorflowConverter(base_converter.ConverterInterface):
             transpose_a_arg.name = MaceKeyword.mace_transpose_a_str
             transpose_a_arg.i = int(adj_x)
         except ValueError:
-            pass
+            try:
+                transpose_a = tf_op.get_attr('transpose_a')
+                transpose_a_arg = op.arg.add()
+                transpose_a_arg.name = MaceKeyword.mace_transpose_a_str
+                transpose_a_arg.i = int(transpose_a)
+            except ValueError:
+                pass
 
         try:
             adj_y = tf_op.get_attr('adj_y')
@@ -558,7 +580,13 @@ class TensorflowConverter(base_converter.ConverterInterface):
             transpose_b_arg.name = MaceKeyword.mace_transpose_b_str
             transpose_b_arg.i = int(adj_y)
         except ValueError:
-            pass
+            try:
+                transpose_b = tf_op.get_attr('transpose_b')
+                transpose_b_arg = op.arg.add()
+                transpose_b_arg.name = MaceKeyword.mace_transpose_b_str
+                transpose_b_arg.i = int(transpose_b)
+            except ValueError:
+                pass
 
     def convert_shape(self, tf_op):
         op = self.convert_general_op(tf_op)
@@ -689,14 +717,18 @@ class TensorflowConverter(base_converter.ConverterInterface):
         op = self.convert_general_op(tf_op)
         op.type = MaceOp.Cast.name
 
-        data_type_arg = ConverterUtil.get_arg(op, 'T')
         try:
             dtype = tf_op.get_attr('DstT')
             if dtype == tf.int32:
-                data_type_arg.i = mace_pb2.DT_INT32
+                op.output_type.extend([mace_pb2.DT_INT32])
             elif dtype == tf.float32:
-                data_type_arg.i = self._option.data_type
+                op.output_type.extend([self._option.data_type])
             else:
                 mace_check(False, "data type %s not supported" % dtype)
         except ValueError:
-            data_type_arg.i = self._option.data_type
+            op.output_type.extend([self._option.data_type])
+
+    def convert_argmax(self, tf_op):
+        op = self.convert_general_op(tf_op)
+        op.type = MaceOp.ArgMax.name
+        op.output_type.extend([mace_pb2.DT_INT32])
diff --git a/mace/tools/validation/mace_run.cc b/mace/tools/validation/mace_run.cc
index 2bf71bc384dd1fc856c9e5f57de92799934fa55b..cba3a926622d8f0867ff1e54ac0425f9752e456f 100644
--- a/mace/tools/validation/mace_run.cc
+++ b/mace/tools/validation/mace_run.cc
@@ -48,6 +48,10 @@ namespace str_util {
 std::vector Split(const std::string &str, char delims) {
   std::vector result;
+  if (str.empty()) {
+    result.push_back("");
+    return result;
+  }
   std::string tmp = str;
   while (!tmp.empty()) {
     size_t next_offset = tmp.find(delims);
diff --git a/tools/sh_commands.py b/tools/sh_commands.py
index df23c488d1d59f7fa8f68e7ae4828a633273981d..92752c5790612791c17595b0fdb83a9fa486aee0 100644
--- a/tools/sh_commands.py
+++ b/tools/sh_commands.py
@@ -773,11 +773,17 @@ def tuning_run(abi,
             (phone_data_dir, os.path.basename(opencl_binary_file)),
         ])
         adb_cmd = ' '.join(adb_cmd)
+        adb_cmd_file = "%s/%s" % (phone_data_dir, 'cmd_file')
+        with open('/tmp/mace_cmd_file', 'w') as cmd_file:
+            cmd_file.write(adb_cmd)
+        adb_push('/tmp/mace_cmd_file', adb_cmd_file, serialno)
+
         sh.adb(
             "-s",
             serialno,
             "shell",
-            adb_cmd,
+            "sh",
+            adb_cmd_file,
             _tty_in=True,
             _out=process_output,
             _err_to_out=True)
@@ -1159,10 +1165,7 @@ def benchmark_model(abi,
             phone_data_dir,
             serialno)
-        sh.adb(
-            "-s",
-            serialno,
-            "shell",
+        adb_cmd = [
             "LD_LIBRARY_PATH=%s" % phone_data_dir,
             "MACE_CPP_MIN_VLOG_LEVEL=%s" % vlog_level,
             "MACE_RUN_PARAMETER_PATH=%s/mace_run.config" %
@@ -1185,6 +1188,19 @@ def benchmark_model(abi,
             "--model_file=%s" % mace_model_phone_path,
             "--opencl_binary_file=%s/%s" %
             (phone_data_dir, os.path.basename(opencl_binary_file)),
+        ]
+        adb_cmd = ' '.join(adb_cmd)
+        adb_cmd_file = "%s/%s" % (phone_data_dir, 'cmd_file')
+        with open('/tmp/mace_cmd_file', 'w') as cmd_file:
+            cmd_file.write(adb_cmd)
+        adb_push('/tmp/mace_cmd_file', adb_cmd_file, serialno)
+
+        sh.adb(
+            "-s",
+            serialno,
+            "shell",
+            "sh",
+            adb_cmd_file,
             _fg=True)
 
     print("Benchmark done!\n")