提交 d6adf881 编写于 作者: 李寅

Implement argmax, equal op;

Revise memory allocation logic to make non-float output possible
上级 7ac05858
...@@ -7,22 +7,25 @@ Operator lists ...@@ -7,22 +7,25 @@ Operator lists
:header: "Operator","Android NN","Supported","Remark" :header: "Operator","Android NN","Supported","Remark"
"AVERAGE_POOL_2D","Y","Y","" "AVERAGE_POOL_2D","Y","Y",""
"ARGMAX","","Y","Only CPU and tensorflow is supported"
"BATCH_NORM","","Y","Fusion with activation is supported" "BATCH_NORM","","Y","Fusion with activation is supported"
"BATCH_TO_SPACE_ND","Y","Y","" "BATCH_TO_SPACE_ND","Y","Y",""
"BIAS_ADD","","Y","" "BIAS_ADD","","Y",""
"CAST","","Y","Only CPU and tensorflow model is supported"
"CHANNEL_SHUFFLE","","Y","" "CHANNEL_SHUFFLE","","Y",""
"CONCATENATION","Y","Y","Only support channel axis concatenation" "CONCATENATION","Y","Y","Only support channel axis concatenation"
"CONV_2D","Y","Y","Fusion with BN and activation layer is supported" "CONV_2D","Y","Y","Fusion with BN and activation layer is supported"
"DECONV_2D","N","Y","Only tensorflow model is supported" "DECONV_2D","","Y","Only tensorflow model is supported"
"DEPTHWISE_CONV_2D","Y","Y","Only multiplier = 1 is supported; Fusion is supported" "DEPTHWISE_CONV_2D","Y","Y","Only multiplier = 1 is supported; Fusion is supported"
"DEPTH_TO_SPACE","Y","Y","" "DEPTH_TO_SPACE","Y","Y",""
"DEQUANTIZE","Y","Y","Model quantization will be supported later" "DEQUANTIZE","Y","Y","Model quantization will be supported later"
"ELEMENT_WISE","Y","Y","ADD/MUL/DIV/MIN/MAX/NEG/ABS/SQR_DIFF/POW" "ELEMENT_WISE","Y","Y","ADD/MUL/DIV/MIN/MAX/NEG/ABS/SQR_DIFF/POW/RSQRT/EQUAL"
"EMBEDDING_LOOKUP","Y","","" "EMBEDDING_LOOKUP","Y","Y","Only support channel axis concatenation"
"FLOOR","Y","","" "FLOOR","Y","",""
"FULLY_CONNECTED","Y","Y","" "FULLY_CONNECTED","Y","Y",""
"GROUP_CONV_2D","","","Caffe model with group count = channel count is supported" "GROUP_CONV_2D","","","Caffe model with group count = channel count is supported"
"HASHTABLE_LOOKUP","Y","","" "HASHTABLE_LOOKUP","Y","",""
"IDENTITY","","Y","Only tensorflow model is supported"
"L2_NORMALIZATION","Y","","" "L2_NORMALIZATION","Y","",""
"L2_POOL_2D","Y","","" "L2_POOL_2D","Y","",""
"LOCAL_RESPONSE_NORMALIZATION","Y","Y","" "LOCAL_RESPONSE_NORMALIZATION","Y","Y",""
...@@ -31,9 +34,10 @@ Operator lists ...@@ -31,9 +34,10 @@ Operator lists
"LSTM","Y","","" "LSTM","Y","",""
"MATMUL","","Y","" "MATMUL","","Y",""
"MAX_POOL_2D","Y","Y","" "MAX_POOL_2D","Y","Y",""
"PAD", "N","Y","" "PAD", "Y","Y",""
"PSROI_ALIGN","","Y","" "PSROI_ALIGN","","Y",""
"PRELU","","Y","Only caffe model is supported" "PRELU","","Y","Only caffe model is supported"
"REDUCE_MEAN","Y","Y","Only tensorflow model is supported"
"RELU","Y","Y","" "RELU","Y","Y",""
"RELU1","Y","Y","" "RELU1","Y","Y",""
"RELU6","Y","Y","" "RELU6","Y","Y",""
...@@ -42,9 +46,14 @@ Operator lists ...@@ -42,9 +46,14 @@ Operator lists
"RESIZE_BILINEAR","Y","Y","" "RESIZE_BILINEAR","Y","Y",""
"RNN","Y","","" "RNN","Y","",""
"RPN_PROPOSAL_LAYER","","Y","" "RPN_PROPOSAL_LAYER","","Y",""
"SLICE","N","Y","Only support channel axis slice" "SHAPE","","Y","Only CPU and tensorflow is supported"
"STACK","","Y","Only CPU and tensorflow is supported"
"STRIDEDSLICE","Y","Y","Only CPU and tensorflow is supported"
"SLICE","","Y","In tensorflow, this op is equivalent to SPLIT; Only support channel axis slice"
"SOFTMAX","Y","Y","" "SOFTMAX","Y","Y",""
"SPACE_TO_BATCH_ND","Y", "Y","" "SPACE_TO_BATCH_ND","Y", "Y",""
"SPACE_TO_DEPTH","Y","Y","" "SPACE_TO_DEPTH","Y","Y",""
"SQUEEZE","Y","Y","Only CPU and tensorflow is supported"
"SVDF","Y","","" "SVDF","Y","",""
"TANH","Y","Y","" "TANH","Y","Y",""
"TRANSPOSE","Y","Y","Only CPU and tensorflow is supported"
...@@ -264,7 +264,6 @@ MaceStatus MaceEngine::Impl::Run( ...@@ -264,7 +264,6 @@ MaceStatus MaceEngine::Impl::Run(
auto shape = output_tensor->shape(); auto shape = output_tensor->shape();
int64_t output_size = std::accumulate(shape.begin(), shape.end(), 1, int64_t output_size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int64_t>()); std::multiplies<int64_t>());
MACE_CHECK(!shape.empty()) << "Output's shape must greater than 0";
MACE_CHECK(shape == output.second.shape()) MACE_CHECK(shape == output.second.shape())
<< "Output shape mismatch: " << "Output shape mismatch: "
<< MakeString<int64_t>(output.second.shape()) << MakeString<int64_t>(output.second.shape())
......
...@@ -76,6 +76,7 @@ namespace ops { ...@@ -76,6 +76,7 @@ namespace ops {
// Keep in lexicographical order // Keep in lexicographical order
extern void Register_Activation(OperatorRegistry *op_registry); extern void Register_Activation(OperatorRegistry *op_registry);
extern void Register_AddN(OperatorRegistry *op_registry); extern void Register_AddN(OperatorRegistry *op_registry);
extern void Register_ArgMax(OperatorRegistry *op_registry);
extern void Register_BatchNorm(OperatorRegistry *op_registry); extern void Register_BatchNorm(OperatorRegistry *op_registry);
extern void Register_BatchToSpaceND(OperatorRegistry *op_registry); extern void Register_BatchToSpaceND(OperatorRegistry *op_registry);
extern void Register_BiasAdd(OperatorRegistry *op_registry); extern void Register_BiasAdd(OperatorRegistry *op_registry);
...@@ -124,6 +125,7 @@ OperatorRegistry::OperatorRegistry() { ...@@ -124,6 +125,7 @@ OperatorRegistry::OperatorRegistry() {
// Keep in lexicographical order // Keep in lexicographical order
ops::Register_Activation(this); ops::Register_Activation(this);
ops::Register_AddN(this); ops::Register_AddN(this);
ops::Register_ArgMax(this);
ops::Register_BatchNorm(this); ops::Register_BatchNorm(this);
ops::Register_BatchToSpaceND(this); ops::Register_BatchToSpaceND(this);
ops::Register_BiasAdd(this); ops::Register_BiasAdd(this);
......
...@@ -157,6 +157,8 @@ MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def, ...@@ -157,6 +157,8 @@ MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
} }
} }
MACE_CHECK(dtype != DataType::DT_INVALID, "data type is invalid."); MACE_CHECK(dtype != DataType::DT_INVALID, "data type is invalid.");
// TODO(liyin): a memory block should not have a concept of type, but to be
// consistent with GPU, all memory blocks use float/half as the unit
for (auto &mem_block : net_def.mem_arena().mem_block()) { for (auto &mem_block : net_def.mem_arena().mem_block()) {
if (device_type == DeviceType::GPU) { if (device_type == DeviceType::GPU) {
// TODO(liuqi): refactor based on PB // TODO(liuqi): refactor based on PB
...@@ -191,8 +193,15 @@ MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def, ...@@ -191,8 +193,15 @@ MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
auto mem_ids = op.mem_id(); auto mem_ids = op.mem_id();
int count = mem_ids.size(); int count = mem_ids.size();
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
DataType output_type;
if (i < op.output_type_size()) {
output_type = op.output_type(i);
} else {
output_type = dtype;
}
std::unique_ptr<Tensor> tensor std::unique_ptr<Tensor> tensor
(new Tensor(preallocated_allocator_.GetBuffer(mem_ids[i]), dtype)); (new Tensor(preallocated_allocator_.GetBuffer(mem_ids[i]),
output_type));
tensor->SetSourceOpName(op.name()); tensor->SetSourceOpName(op.name());
if (device_type == DeviceType::GPU) { if (device_type == DeviceType::GPU) {
VLOG(3) << "Tensor: " << op.name() << "(" << op.type() << ")" VLOG(3) << "Tensor: " << op.name() << "(" << op.type() << ")"
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_ARGMAX_H_
#define MACE_KERNELS_ARGMAX_H_
#include <algorithm>
#include <functional>
#include <limits>
#include <memory>
#include <vector>
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/public/mace.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
// Arg-max functor: for every slice along the last dimension of `input`, writes
// the position of the largest element into `output` as an int32 value.
//
// Template parameters:
//   D - device type tag (not referenced inside the body).
//   T - element type read from `input`.
//
// Arguments:
//   input  - tensor to reduce; must have at least one dimension.
//   axis   - scalar (0-d) tensor read as int32; a negative value is offset by
//            the input rank. Only the last dimension is accepted.
//   output - resized to the input shape with the reduced dimension removed.
//   future - unused.
//
// Returns MACE_SUCCESS, or propagates the error reported by output->Resize().
// Ties resolve to the lowest index because the comparison is strict; when no
// element compares greater than numeric_limits<T>::lowest(), index 0 is
// written.
template <DeviceType D, typename T>
struct ArgMaxFunctor {
  MaceStatus operator()(const Tensor *input,
                        const Tensor *axis,
                        Tensor *output,
                        StatsFuture *future) {
    MACE_UNUSED(future);
    MACE_CHECK(input->dim_size() > 0, "ArgMax input should not be a scalar");
    MACE_CHECK(axis->dim_size() == 0, "Mace argmax only supports scalar axis");

    // Read the requested axis and normalise a negative value.
    Tensor::MappingGuard axis_guard(axis);
    int axis_value = axis->data<int32_t>()[0];
    if (axis_value < 0) {
      axis_value += input->dim_size();
    }
    MACE_CHECK(axis_value == input->dim_size() - 1,
               "Mace argmax only supports last dimension as axis");

    // The output keeps every input dimension except the reduced one.
    const auto reduced_rank = input->dim_size() - 1;
    std::vector<index_t> output_shape;
    output_shape.reserve(reduced_rank);
    for (index_t d = 0; d < reduced_rank; ++d) {
      const index_t source_dim = (d >= axis_value) ? d + 1 : d;
      output_shape.push_back(input->dim(source_dim));
    }
    MACE_RETURN_IF_ERROR(output->Resize(output_shape));

    Tensor::MappingGuard input_guard(input);
    Tensor::MappingGuard output_guard(output);
    const T *input_data = input->data<T>();
    int32_t *output_data = output->mutable_data<int32_t>();

    // One output element per contiguous run of `inner_size` input elements.
    const index_t outer_size = output->size();
    const index_t inner_size = input->dim(axis_value);

#pragma omp parallel for
    for (index_t row = 0; row < outer_size; ++row) {
      const T *row_ptr = input_data + row * inner_size;
      T best_value = std::numeric_limits<T>::lowest();
      int best_index = 0;
      for (index_t col = 0; col < inner_size; ++col) {
        const T candidate = row_ptr[col];
        // Strict comparison keeps the first occurrence of the maximum.
        if (candidate > best_value) {
          best_value = candidate;
          best_index = col;
        }
      }
      output_data[row] = best_index;
    }

    return MACE_SUCCESS;
  }
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_ARGMAX_H_
...@@ -16,10 +16,10 @@ ...@@ -16,10 +16,10 @@
#define MACE_KERNELS_ELTWISE_H_ #define MACE_KERNELS_ELTWISE_H_
#include <algorithm> #include <algorithm>
#include <functional>
#include <memory> #include <memory>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <functional>
#include "mace/core/future.h" #include "mace/core/future.h"
#include "mace/core/tensor.h" #include "mace/core/tensor.h"
...@@ -42,9 +42,12 @@ enum EltwiseType { ...@@ -42,9 +42,12 @@ enum EltwiseType {
ABS = 7, ABS = 7,
SQR_DIFF = 8, SQR_DIFF = 8,
POW = 9, POW = 9,
NONE = 10, EQUAL = 10,
NONE = 11,
}; };
// Returns true when `type` is a logical (comparison) element-wise operation;
// EQUAL is the only such type recognised here.
// Declared `inline` like the other helpers in this header, so including
// translation units share one definition instead of each getting a private
// `static` copy (which also warns as unused where it is not called).
inline bool IsLogicalType(EltwiseType type) { return type == EQUAL; }
inline index_t GetIndex(const std::vector<index_t> &shape, inline index_t GetIndex(const std::vector<index_t> &shape,
const std::vector<index_t> &index) { const std::vector<index_t> &index) {
index_t idx = 0; index_t idx = 0;
...@@ -68,22 +71,19 @@ inline void IncreaseIndex(const std::vector<index_t> &shape, ...@@ -68,22 +71,19 @@ inline void IncreaseIndex(const std::vector<index_t> &shape,
} }
} }
inline void TensorGeneralBroadcastEltwise(const EltwiseType type, template <typename T, typename DstType>
const float *input0, inline void TensorGeneralBroadcastEltwise(
const float *input1, const EltwiseType type,
const std::vector<float> &coeff, const T *input0,
const bool swapped, const T *input1,
const std::vector<index_t> const std::vector<float> &coeff,
&input0_shape, const bool swapped,
const std::vector<index_t> const std::vector<index_t> &input0_shape,
&input1_shape, const std::vector<index_t> &input1_shape,
const std::vector<index_t> const std::vector<index_t> &output_shape,
&output_shape, DstType *output) {
float *output) { const index_t output_size = std::accumulate(
const index_t output_size = std::accumulate(output_shape.begin(), output_shape.begin(), output_shape.end(), 1, std::multiplies<index_t>());
output_shape.end(),
1,
std::multiplies<index_t>());
std::vector<index_t> out_index(output_shape.size(), 0); std::vector<index_t> out_index(output_shape.size(), 0);
switch (type) { switch (type) {
case SUM: case SUM:
...@@ -191,19 +191,28 @@ inline void TensorGeneralBroadcastEltwise(const EltwiseType type, ...@@ -191,19 +191,28 @@ inline void TensorGeneralBroadcastEltwise(const EltwiseType type,
} }
} }
break; break;
case EQUAL:
for (index_t i = 0; i < output_size; ++i) {
const index_t idx0 = GetIndex(input0_shape, out_index);
const index_t idx1 = GetIndex(input1_shape, out_index);
output[i] = input1[idx1] == input0[idx0];
IncreaseIndex(output_shape, &out_index);
}
break;
default: default:
LOG(FATAL) << "Eltwise op not support type " << type; LOG(FATAL) << "Eltwise op not support type " << type;
} }
} }
template <typename T, typename DstType>
inline void TensorBroadcastEltwise(const EltwiseType type, inline void TensorBroadcastEltwise(const EltwiseType type,
const float *input0, const T *input0,
const float *input1, const T *input1,
const std::vector<float> &coeff, const std::vector<float> &coeff,
const index_t diff_size, const index_t diff_size,
const index_t common_size, const index_t common_size,
const bool swapped, const bool swapped,
float *output) { DstType *output) {
switch (type) { switch (type) {
case SUM: case SUM:
if (coeff.empty()) { if (coeff.empty()) {
...@@ -333,19 +342,29 @@ inline void TensorBroadcastEltwise(const EltwiseType type, ...@@ -333,19 +342,29 @@ inline void TensorBroadcastEltwise(const EltwiseType type,
output[i] = std::fabs(input0[i]); output[i] = std::fabs(input0[i]);
} }
break; break;
case EQUAL:
#pragma omp parallel for collapse(2)
for (index_t d = 0; d < diff_size; ++d) {
for (index_t i = 0; i < common_size; ++i) {
output[i + d * common_size] =
input0[i + d * common_size] == input1[i];
}
}
break;
default: default:
LOG(FATAL) << "Eltwise op not support type " << type; LOG(FATAL) << "Eltwise op not support type " << type;
} }
} }
// Multiplication is costly, so we specialize the following case. // Multiplication is costly, so we specialize the following case.
template <typename T, typename DstType>
inline void TensorEltwise(const EltwiseType type, inline void TensorEltwise(const EltwiseType type,
const float *input0, const T *input0,
const float *input1, const T *input1,
const std::vector<float> &coeff, const std::vector<float> &coeff,
const index_t size, const index_t size,
const bool swapped, const bool swapped,
float *output) { DstType *output) {
switch (type) { switch (type) {
case SUM: case SUM:
if (coeff.empty()) { if (coeff.empty()) {
...@@ -445,19 +464,26 @@ inline void TensorEltwise(const EltwiseType type, ...@@ -445,19 +464,26 @@ inline void TensorEltwise(const EltwiseType type,
output[i] = std::fabs(input0[i]); output[i] = std::fabs(input0[i]);
} }
break; break;
case EQUAL:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output[i] = input0[i] == input1[i];
}
break;
default: default:
LOG(FATAL) << "Eltwise op not support type " << type; LOG(FATAL) << "Eltwise op not support type " << type;
} }
} }
// Multiplication is costly, so we specialize the following case. // Multiplication is costly, so we specialize the following case.
template <typename T, typename DstType>
inline void TensorScalarEltwise(const EltwiseType type, inline void TensorScalarEltwise(const EltwiseType type,
const float *input0, const T *input0,
const float input1, const T input1,
const std::vector<float> &coeff, const std::vector<float> &coeff,
const index_t size, const index_t size,
const bool swapped, const bool swapped,
float *output) { DstType *output) {
switch (type) { switch (type) {
case SUM: case SUM:
if (coeff.empty()) { if (coeff.empty()) {
...@@ -556,31 +582,39 @@ inline void TensorScalarEltwise(const EltwiseType type, ...@@ -556,31 +582,39 @@ inline void TensorScalarEltwise(const EltwiseType type,
for (index_t i = 0; i < size; ++i) { for (index_t i = 0; i < size; ++i) {
output[i] = std::fabs(input0[i]); output[i] = std::fabs(input0[i]);
} }
break;
case EQUAL:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output[i] = input0[i] == input1;
}
break; break;
default: default:
LOG(FATAL) << "Eltwise op not support type " << type; LOG(FATAL) << "Eltwise op not support type " << type;
} }
} }
template <typename T, typename DstType>
inline void TensorEltwisePerChannel(const EltwiseType type, inline void TensorEltwisePerChannel(const EltwiseType type,
const float *input0, const T *input0,
const float *input1, const T *input1,
const std::vector<float> &coeff, const std::vector<float> &coeff,
const index_t batch0, const index_t batch0,
const index_t batch1, const index_t batch1,
const index_t channel, const index_t channel,
const index_t image_size, const index_t image_size,
const bool swapped, const bool swapped,
float *output) { DstType *output) {
switch (type) { switch (type) {
case SUM: case SUM:
if (coeff.empty()) { if (coeff.empty()) {
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch0; ++b) { for (index_t b = 0; b < batch0; ++b) {
for (index_t c = 0; c < channel; ++c) { for (index_t c = 0; c < channel; ++c) {
const float *in0_ptr = input0 + ((b * channel) + c) * image_size; const T *in0_ptr = input0 + ((b * channel) + c) * image_size;
const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0);
float *out_ptr = output + ((b * channel) + c) * image_size; DstType *out_ptr = output + ((b * channel) + c) * image_size;
for (index_t i = 0; i < image_size; ++i) { for (index_t i = 0; i < image_size; ++i) {
out_ptr[i] = in0_ptr[i] + in1_ptr[c]; out_ptr[i] = in0_ptr[i] + in1_ptr[c];
} }
...@@ -594,9 +628,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, ...@@ -594,9 +628,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type,
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch0; ++b) { for (index_t b = 0; b < batch0; ++b) {
for (index_t c = 0; c < channel; ++c) { for (index_t c = 0; c < channel; ++c) {
const float *in0_ptr = input0 + ((b * channel) + c) * image_size; const T *in0_ptr = input0 + ((b * channel) + c) * image_size;
const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0);
float *out_ptr = output + ((b * channel) + c) * image_size; DstType *out_ptr = output + ((b * channel) + c) * image_size;
for (index_t i = 0; i < image_size; ++i) { for (index_t i = 0; i < image_size; ++i) {
out_ptr[i] = out_ptr[i] =
in0_ptr[i] * coeff_copy[0] + in1_ptr[c] * coeff_copy[1]; in0_ptr[i] * coeff_copy[0] + in1_ptr[c] * coeff_copy[1];
...@@ -610,9 +644,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, ...@@ -610,9 +644,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type,
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch0; ++b) { for (index_t b = 0; b < batch0; ++b) {
for (index_t c = 0; c < channel; ++c) { for (index_t c = 0; c < channel; ++c) {
const float *in0_ptr = input0 + ((b * channel) + c) * image_size; const T *in0_ptr = input0 + ((b * channel) + c) * image_size;
const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0);
float *out_ptr = output + ((b * channel) + c) * image_size; DstType *out_ptr = output + ((b * channel) + c) * image_size;
for (index_t i = 0; i < image_size; ++i) { for (index_t i = 0; i < image_size; ++i) {
out_ptr[i] = in0_ptr[i] - in1_ptr[c]; out_ptr[i] = in0_ptr[i] - in1_ptr[c];
} }
...@@ -622,9 +656,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, ...@@ -622,9 +656,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type,
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch0; ++b) { for (index_t b = 0; b < batch0; ++b) {
for (index_t c = 0; c < channel; ++c) { for (index_t c = 0; c < channel; ++c) {
const float *in0_ptr = input0 + ((b * channel) + c) * image_size; const T *in0_ptr = input0 + ((b * channel) + c) * image_size;
const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0);
float *out_ptr = output + ((b * channel) + c) * image_size; DstType *out_ptr = output + ((b * channel) + c) * image_size;
for (index_t i = 0; i < image_size; ++i) { for (index_t i = 0; i < image_size; ++i) {
out_ptr[i] = in1_ptr[c] - in0_ptr[i]; out_ptr[i] = in1_ptr[c] - in0_ptr[i];
} }
...@@ -636,9 +670,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, ...@@ -636,9 +670,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type,
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch0; ++b) { for (index_t b = 0; b < batch0; ++b) {
for (index_t c = 0; c < channel; ++c) { for (index_t c = 0; c < channel; ++c) {
const float *in0_ptr = input0 + ((b * channel) + c) * image_size; const T *in0_ptr = input0 + ((b * channel) + c) * image_size;
const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0);
float *out_ptr = output + ((b * channel) + c) * image_size; DstType *out_ptr = output + ((b * channel) + c) * image_size;
for (index_t i = 0; i < image_size; ++i) { for (index_t i = 0; i < image_size; ++i) {
out_ptr[i] = in0_ptr[i] * in1_ptr[c]; out_ptr[i] = in0_ptr[i] * in1_ptr[c];
} }
...@@ -650,9 +684,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, ...@@ -650,9 +684,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type,
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch0; ++b) { for (index_t b = 0; b < batch0; ++b) {
for (index_t c = 0; c < channel; ++c) { for (index_t c = 0; c < channel; ++c) {
const float *in0_ptr = input0 + ((b * channel) + c) * image_size; const T *in0_ptr = input0 + ((b * channel) + c) * image_size;
const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0);
float *out_ptr = output + ((b * channel) + c) * image_size; DstType *out_ptr = output + ((b * channel) + c) * image_size;
for (index_t i = 0; i < image_size; ++i) { for (index_t i = 0; i < image_size; ++i) {
out_ptr[i] = in0_ptr[i] / in1_ptr[c]; out_ptr[i] = in0_ptr[i] / in1_ptr[c];
} }
...@@ -662,9 +696,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, ...@@ -662,9 +696,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type,
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch0; ++b) { for (index_t b = 0; b < batch0; ++b) {
for (index_t c = 0; c < channel; ++c) { for (index_t c = 0; c < channel; ++c) {
const float *in0_ptr = input0 + ((b * channel) + c) * image_size; const T *in0_ptr = input0 + ((b * channel) + c) * image_size;
const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0);
float *out_ptr = output + ((b * channel) + c) * image_size; DstType *out_ptr = output + ((b * channel) + c) * image_size;
for (index_t i = 0; i < image_size; ++i) { for (index_t i = 0; i < image_size; ++i) {
out_ptr[i] = in1_ptr[c] / in0_ptr[i]; out_ptr[i] = in1_ptr[c] / in0_ptr[i];
} }
...@@ -676,9 +710,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, ...@@ -676,9 +710,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type,
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch0; ++b) { for (index_t b = 0; b < batch0; ++b) {
for (index_t c = 0; c < channel; ++c) { for (index_t c = 0; c < channel; ++c) {
const float *in0_ptr = input0 + ((b * channel) + c) * image_size; const T *in0_ptr = input0 + ((b * channel) + c) * image_size;
const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0);
float *out_ptr = output + ((b * channel) + c) * image_size; DstType *out_ptr = output + ((b * channel) + c) * image_size;
for (index_t i = 0; i < image_size; ++i) { for (index_t i = 0; i < image_size; ++i) {
out_ptr[i] = std::min(in0_ptr[i], in1_ptr[c]); out_ptr[i] = std::min(in0_ptr[i], in1_ptr[c]);
} }
...@@ -689,9 +723,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, ...@@ -689,9 +723,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type,
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch0; ++b) { for (index_t b = 0; b < batch0; ++b) {
for (index_t c = 0; c < channel; ++c) { for (index_t c = 0; c < channel; ++c) {
const float *in0_ptr = input0 + ((b * channel) + c) * image_size; const T *in0_ptr = input0 + ((b * channel) + c) * image_size;
const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0);
float *out_ptr = output + ((b * channel) + c) * image_size; DstType *out_ptr = output + ((b * channel) + c) * image_size;
for (index_t i = 0; i < image_size; ++i) { for (index_t i = 0; i < image_size; ++i) {
out_ptr[i] = std::max(in0_ptr[i], in1_ptr[c]); out_ptr[i] = std::max(in0_ptr[i], in1_ptr[c]);
} }
...@@ -702,9 +736,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, ...@@ -702,9 +736,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type,
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch0; ++b) { for (index_t b = 0; b < batch0; ++b) {
for (index_t c = 0; c < channel; ++c) { for (index_t c = 0; c < channel; ++c) {
const float *in0_ptr = input0 + ((b * channel) + c) * image_size; const T *in0_ptr = input0 + ((b * channel) + c) * image_size;
const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0);
float *out_ptr = output + ((b * channel) + c) * image_size; DstType *out_ptr = output + ((b * channel) + c) * image_size;
for (index_t i = 0; i < image_size; ++i) { for (index_t i = 0; i < image_size; ++i) {
out_ptr[i] = std::pow(in0_ptr[i] - in1_ptr[c], 2.f); out_ptr[i] = std::pow(in0_ptr[i] - in1_ptr[c], 2.f);
} }
...@@ -716,9 +750,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, ...@@ -716,9 +750,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type,
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch0; ++b) { for (index_t b = 0; b < batch0; ++b) {
for (index_t c = 0; c < channel; ++c) { for (index_t c = 0; c < channel; ++c) {
const float *in0_ptr = input0 + ((b * channel) + c) * image_size; const T *in0_ptr = input0 + ((b * channel) + c) * image_size;
const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0);
float *out_ptr = output + ((b * channel) + c) * image_size; DstType *out_ptr = output + ((b * channel) + c) * image_size;
for (index_t i = 0; i < image_size; ++i) { for (index_t i = 0; i < image_size; ++i) {
out_ptr[i] = std::pow(in0_ptr[i], in1_ptr[c]); out_ptr[i] = std::pow(in0_ptr[i], in1_ptr[c]);
} }
...@@ -728,9 +762,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type, ...@@ -728,9 +762,9 @@ inline void TensorEltwisePerChannel(const EltwiseType type,
#pragma omp parallel for collapse(2) #pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch0; ++b) { for (index_t b = 0; b < batch0; ++b) {
for (index_t c = 0; c < channel; ++c) { for (index_t c = 0; c < channel; ++c) {
const float *in0_ptr = input0 + ((b * channel) + c) * image_size; const T *in0_ptr = input0 + ((b * channel) + c) * image_size;
const float *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0); const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0);
float *out_ptr = output + ((b * channel) + c) * image_size; DstType *out_ptr = output + ((b * channel) + c) * image_size;
for (index_t i = 0; i < image_size; ++i) { for (index_t i = 0; i < image_size; ++i) {
out_ptr[i] = std::pow(in1_ptr[c], in0_ptr[i]); out_ptr[i] = std::pow(in1_ptr[c], in0_ptr[i]);
} }
...@@ -750,6 +784,19 @@ inline void TensorEltwisePerChannel(const EltwiseType type, ...@@ -750,6 +784,19 @@ inline void TensorEltwisePerChannel(const EltwiseType type,
output[i] = std::fabs(input0[i]); output[i] = std::fabs(input0[i]);
} }
break; break;
case EQUAL:
#pragma omp parallel for collapse(2)
for (index_t b = 0; b < batch0; ++b) {
for (index_t c = 0; c < channel; ++c) {
const T *in0_ptr = input0 + ((b * channel) + c) * image_size;
const T *in1_ptr = input1 + (batch1 > 1 ? b * channel : 0);
DstType *out_ptr = output + ((b * channel) + c) * image_size;
for (index_t i = 0; i < image_size; ++i) {
out_ptr[i] = in0_ptr[i] == in1_ptr[c];
}
}
}
break;
default: default:
LOG(FATAL) << "Eltwise op not support type " << type; LOG(FATAL) << "Eltwise op not support type " << type;
} }
...@@ -769,30 +816,17 @@ struct EltwiseFunctorBase { ...@@ -769,30 +816,17 @@ struct EltwiseFunctorBase {
}; };
template <DeviceType D, typename T> template <DeviceType D, typename T>
struct EltwiseFunctor; struct EltwiseFunctor : EltwiseFunctorBase {
template <>
struct EltwiseFunctor<DeviceType::CPU, float> : EltwiseFunctorBase {
EltwiseFunctor(const EltwiseType type, EltwiseFunctor(const EltwiseType type,
const std::vector<float> &coeff, const std::vector<float> &coeff,
const float value, const float value, // keep it float as it comes from arg
const DataFormat data_format) const DataFormat data_format)
: EltwiseFunctorBase(type, coeff, value, data_format) {} : EltwiseFunctorBase(type, coeff, value, data_format) {}
MaceStatus operator()(const Tensor *input0, template <typename DstType>
const Tensor *input1, MaceStatus DoEltwise(const Tensor *input0,
Tensor *output, const Tensor *input1,
StatsFuture *future) { Tensor *output) {
MACE_UNUSED(future);
if (input1 == nullptr) {
scalar_tensor_.Resize({});
Tensor::MappingGuard guard(&scalar_tensor_);
auto scalar_data = scalar_tensor_.mutable_data<float>();
scalar_data[0] = value_;
input1 = &scalar_tensor_;
}
bool swapped = false; bool swapped = false;
if (input0->size() < input1->size()) { if (input0->size() < input1->size()) {
std::swap(input0, input1); std::swap(input0, input1);
...@@ -804,20 +838,16 @@ struct EltwiseFunctor<DeviceType::CPU, float> : EltwiseFunctorBase { ...@@ -804,20 +838,16 @@ struct EltwiseFunctor<DeviceType::CPU, float> : EltwiseFunctorBase {
static_cast<uint32_t>(input0->dim_size() - input1->dim_size()); static_cast<uint32_t>(input0->dim_size() - input1->dim_size());
if (data_format_ == NCHW) { if (data_format_ == NCHW) {
MACE_CHECK( MACE_CHECK(
(input0->dim_size() == 4) (input0->dim_size() == 4) &&
&& ((input1->dim_size() == 0) ((input1->dim_size() == 0) ||
|| (input1->dim_size() == 4 (input1->dim_size() == 4 && input1->dim(1) == input0->dim(1) &&
&& input1->dim(1) == input0->dim(1) (input1->dim(0) == input0->dim(0) || input1->dim(0) == 1)) ||
&& (input1->dim(0) == input0->dim(0) (input1->dim_size() == 1 && input1->dim(0) == input0->dim(1))),
|| input1->dim(0) == 1))
|| (input1->dim_size() == 1
&& input1->dim(0) == input0->dim(1))),
"only support broadcast channel dimension"); "only support broadcast channel dimension");
} else { } else {
for (uint32_t i = 0; i < input1->dim_size(); ++i) { for (uint32_t i = 0; i < input1->dim_size(); ++i) {
MACE_CHECK(input0->dim(rank_diff + i) == 1 MACE_CHECK(input0->dim(rank_diff + i) == 1 || input1->dim(i) == 1 ||
|| input1->dim(i) == 1 input0->dim(rank_diff + i) == input1->dim(i),
|| input0->dim(rank_diff + i) == input1->dim(i),
"Element-Wise op only support tail dimensions broadcast"); "Element-Wise op only support tail dimensions broadcast");
} }
} }
...@@ -825,14 +855,14 @@ struct EltwiseFunctor<DeviceType::CPU, float> : EltwiseFunctorBase { ...@@ -825,14 +855,14 @@ struct EltwiseFunctor<DeviceType::CPU, float> : EltwiseFunctorBase {
Tensor::MappingGuard input0_guard(input0); Tensor::MappingGuard input0_guard(input0);
Tensor::MappingGuard input1_guard(input1); Tensor::MappingGuard input1_guard(input1);
const float *input0_ptr = input0->data<float>(); const T *input0_ptr = input0->data<T>();
const float *input1_ptr = input1->data<float>(); const T *input1_ptr = input1->data<T>();
if (data_format_ == NCHW && input1->dim_size() > 0 && if (data_format_ == NCHW && input1->dim_size() > 0 &&
input1->size() < input0->size()) { input1->size() < input0->size()) {
MACE_RETURN_IF_ERROR(output->ResizeLike(input0)); MACE_RETURN_IF_ERROR(output->ResizeLike(input0));
Tensor::MappingGuard output_guard(output); Tensor::MappingGuard output_guard(output);
float *output_ptr = output->mutable_data<float>(); DstType *output_ptr = output->mutable_data<DstType>();
TensorEltwisePerChannel( TensorEltwisePerChannel(
type_, input0_ptr, input1_ptr, coeff_, input0->dim(0), type_, input0_ptr, input1_ptr, coeff_, input0->dim(0),
input1->dim_size() == 1 ? 1 : input1->dim(0), input0->dim(1), input1->dim_size() == 1 ? 1 : input1->dim(0), input0->dim(1),
...@@ -841,8 +871,7 @@ struct EltwiseFunctor<DeviceType::CPU, float> : EltwiseFunctorBase { ...@@ -841,8 +871,7 @@ struct EltwiseFunctor<DeviceType::CPU, float> : EltwiseFunctorBase {
} else { } else {
const std::vector<index_t> &input0_shape = input0->shape(); const std::vector<index_t> &input0_shape = input0->shape();
std::vector<index_t> input1_shape(rank_diff, 1); std::vector<index_t> input1_shape(rank_diff, 1);
input1_shape.insert(input1_shape.end(), input1_shape.insert(input1_shape.end(), input1->shape().begin(),
input1->shape().begin(),
input1->shape().end()); input1->shape().end());
std::vector<index_t> output_shape(input0->dim_size(), 0); std::vector<index_t> output_shape(input0->dim_size(), 0);
...@@ -851,27 +880,21 @@ struct EltwiseFunctor<DeviceType::CPU, float> : EltwiseFunctorBase { ...@@ -851,27 +880,21 @@ struct EltwiseFunctor<DeviceType::CPU, float> : EltwiseFunctorBase {
} }
MACE_RETURN_IF_ERROR(output->Resize(output_shape)); MACE_RETURN_IF_ERROR(output->Resize(output_shape));
Tensor::MappingGuard output_guard(output); Tensor::MappingGuard output_guard(output);
float *output_ptr = output->mutable_data<float>(); DstType *output_ptr = output->mutable_data<DstType>();
bool need_general_broadcast = false; bool need_general_broadcast = false;
for (uint32_t i = 0; i < input1->dim_size(); ++i) { for (uint32_t i = 0; i < input1->dim_size(); ++i) {
if ((input0->dim(rank_diff + i) == 1 && input1->dim(i) > 1) if ((input0->dim(rank_diff + i) == 1 && input1->dim(i) > 1) ||
|| (input0->dim(rank_diff + i) > 1 && input1->dim(i) == 1)) { (input0->dim(rank_diff + i) > 1 && input1->dim(i) == 1)) {
need_general_broadcast = true; need_general_broadcast = true;
break; break;
} }
} }
if (need_general_broadcast) { if (need_general_broadcast) {
TensorGeneralBroadcastEltwise(type_, TensorGeneralBroadcastEltwise(type_, input0_ptr, input1_ptr, coeff_,
input0_ptr, swapped, input0_shape, input1_shape,
input1_ptr, output_shape, output_ptr);
coeff_,
swapped,
input0_shape,
input1_shape,
output_shape,
output_ptr);
} else if (input1->size() == input0->size()) { } else if (input1->size() == input0->size()) {
TensorEltwise(type_, input0_ptr, input1_ptr, coeff_, input0->size(), TensorEltwise(type_, input0_ptr, input1_ptr, coeff_, input0->size(),
swapped, output_ptr); swapped, output_ptr);
...@@ -891,6 +914,28 @@ struct EltwiseFunctor<DeviceType::CPU, float> : EltwiseFunctorBase { ...@@ -891,6 +914,28 @@ struct EltwiseFunctor<DeviceType::CPU, float> : EltwiseFunctorBase {
return MACE_SUCCESS; return MACE_SUCCESS;
} }
// Applies the element-wise op `type_` to `input0` and `input1`, writing the
// result into `output`.
//
// When `input1` is null, `value_` is cast to T, stored in the empty-shape
// member tensor `scalar_tensor_`, and that tensor is used as the second
// operand. For logical op types (per IsLogicalType) the output element type
// is int32_t; for all other op types it is T. `future` is not used.
//
// Returns the status of DoEltwise, or the error produced while resizing
// `scalar_tensor_` if that step fails.
MaceStatus operator()(const Tensor *input0,
                      const Tensor *input1,
                      Tensor *output,
                      StatsFuture *future) {
  MACE_UNUSED(future);
  if (input1 == nullptr) {
    // Propagate a failed resize instead of writing into an unsized buffer;
    // this matches how the other Resize calls in this file are handled.
    MACE_RETURN_IF_ERROR(scalar_tensor_.Resize({}));
    Tensor::MappingGuard guard(&scalar_tensor_);
    auto scalar_data = scalar_tensor_.mutable_data<T>();
    scalar_data[0] = static_cast<T>(value_);
    input1 = &scalar_tensor_;
  }

  if (IsLogicalType(type_)) {
    // as we do not have bool-type tensor, we use int type
    return DoEltwise<int32_t>(input0, input1, output);
  }
  return DoEltwise<T>(input0, input1, output);
}
Tensor scalar_tensor_; Tensor scalar_tensor_;
}; };
......
...@@ -67,6 +67,26 @@ struct StridedSliceFunctor { ...@@ -67,6 +67,26 @@ struct StridedSliceFunctor {
const T *input_data = input->data<T>(); const T *input_data = input->data<T>();
const int32_t *begin_indices_data = begin_indices->data<int32_t>(); const int32_t *begin_indices_data = begin_indices->data<int32_t>();
const int32_t *end_indices_data = end_indices->data<int32_t>(); const int32_t *end_indices_data = end_indices->data<int32_t>();
const int32_t *strides_data = strides->data<int32_t>();
std::vector<int32_t> pad_begin_indices(input->dim_size(), 0);
std::vector<int32_t> pad_end_indices(input->dim_size(), 0);
std::vector<int32_t> pad_strides_indices(input->dim_size(), 1);
if (begin_indices->size() < input->dim_size()) {
for (index_t i = 0; i < begin_indices->size(); ++i) {
pad_begin_indices[i] = begin_indices_data[i];
pad_end_indices[i] = end_indices_data[i];
pad_strides_indices[i] = strides_data[i];
}
for (index_t i = begin_indices->size(); i < input->dim_size(); ++i) {
pad_end_indices[i] = input->dim(i);
}
begin_indices_data = pad_begin_indices.data();
end_indices_data = pad_end_indices.data();
strides_data = pad_strides_indices.data();
}
std::vector<int32_t> slice_end_data; std::vector<int32_t> slice_end_data;
if (is_slice_) { if (is_slice_) {
// if this op is slice, the end_indices_data is size actually // if this op is slice, the end_indices_data is size actually
...@@ -80,7 +100,6 @@ struct StridedSliceFunctor { ...@@ -80,7 +100,6 @@ struct StridedSliceFunctor {
} }
end_indices_data = slice_end_data.data(); end_indices_data = slice_end_data.data();
} }
const int32_t *strides_data = strides->data<int32_t>();
std::vector<index_t> output_shape; std::vector<index_t> output_shape;
std::vector<index_t> real_begin_indices(input->dim_size(), 0); std::vector<index_t> real_begin_indices(input->dim_size(), 0);
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/argmax.h"
namespace mace {
namespace ops {
// Registers the ArgMax operator in `op_registry`. The only variant registered
// here is the CPU one with type constraint "T" = float, implemented by
// ArgMaxOp<DeviceType::CPU, float>.
void Register_ArgMax(OperatorRegistry *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ArgMax")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
ArgMaxOp<DeviceType::CPU, float>);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_ARGMAX_H_
#define MACE_OPS_ARGMAX_H_
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/argmax.h"
namespace mace {
namespace ops {
// Operator wrapper for ArgMax: hands its two inputs and its single output to
// kernels::ArgMaxFunctor<D, T> and returns whatever status the functor yields.
template <DeviceType D, typename T>
class ArgMaxOp : public Operator<D, T> {
 public:
  ArgMaxOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws) {}

  // Input 0 is the data tensor and input 1 the axis tensor (see the input
  // tags below); output 0 receives the result.
  MaceStatus Run(StatsFuture *future) override {
    const Tensor *data_tensor = this->Input(0);
    const Tensor *axis_tensor = this->Input(1);
    Tensor *result_tensor = this->Output(0);
    return functor_(data_tensor, axis_tensor, result_tensor, future);
  }

 private:
  kernels::ArgMaxFunctor<D, T> functor_;
  MACE_OP_INPUT_TAGS(INPUT, AXIS);
  MACE_OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_ARGMAX_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
// Test fixture for the ArgMax operator tests; adds nothing on top of
// OpsTestBase.
class ArgMaxOpTest : public OpsTestBase {};
namespace {
template <DeviceType D>
void ArgMaxTest(const std::vector<index_t> &input_shape,
const std::vector<float> &input,
const std::vector<index_t> &output_shape,
const std::vector<int32_t> &output) {
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input", input_shape, input);
net.AddInputFromArray<D, int32_t>("axis", {}, {-1});
if (D == DeviceType::CPU) {
OpDefBuilder("ArgMax", "ArgMaxTest")
.Input("Input")
.Input("axis")
.Output("Output")
.OutputType({DT_INT32})
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} else {
MACE_NOT_IMPLEMENTED;
}
// Check
auto expected = CreateTensor<int32_t>(output_shape, output);
ExpectTensorNear<int32_t>(*expected, *net.GetOutput("Output"), 1e-5);
}
} // namespace
// 1-D input: the largest value (-1) is at index 1; the expected output has an
// empty shape.
TEST_F(ArgMaxOpTest, Vector) {
  const std::vector<float> values = {-3, -1, -2};
  ArgMaxTest<CPU>({3}, values, {}, {1});
}

// 3x3 input: one index per row along the last axis.
TEST_F(ArgMaxOpTest, Matrix) {
  const std::vector<float> values = {4, 5, 6, 9, 8, 7, 1, 2, 3};
  const std::vector<int32_t> expected_indices = {2, 0, 2};
  ArgMaxTest<CPU>({3, 3}, values, {3}, expected_indices);
}

// Rank-4 input whose last axis (size 3) is ascending everywhere, so every
// expected index is 2.
TEST_F(ArgMaxOpTest, HighRank) {
  const std::vector<float> values = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  const std::vector<int32_t> expected_indices = {2, 2, 2, 2};
  ArgMaxTest<CPU>({1, 2, 2, 3}, values, {1, 2, 2}, expected_indices);
}
} // namespace test
} // namespace ops
} // namespace mace
...@@ -22,11 +22,11 @@ ...@@ -22,11 +22,11 @@
namespace mace { namespace mace {
namespace ops { namespace ops {
template <DeviceType D, typename DT> template <DeviceType D, typename SrcType>
class CastOp : public Operator<D, DT> { class CastOp : public Operator<D, SrcType> {
public: public:
CastOp(const OperatorDef &op_def, Workspace *ws) CastOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, DT>(op_def, ws) {} : Operator<D, SrcType>(op_def, ws) {}
MaceStatus Run(StatsFuture *future) override { MaceStatus Run(StatsFuture *future) override {
MACE_UNUSED(future); MACE_UNUSED(future);
...@@ -36,17 +36,16 @@ class CastOp : public Operator<D, DT> { ...@@ -36,17 +36,16 @@ class CastOp : public Operator<D, DT> {
Tensor::MappingGuard input_guard(input); Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output); Tensor::MappingGuard output_guard(output);
auto src_dtype = input->dtype(); auto dst_dtype = output->dtype();
auto output_data = output->mutable_data<DT>();
#define MACE_CAST_COPY \ #define MACE_CAST_COPY \
auto input_data = input->data<T>(); \ auto output_data = output->mutable_data<T>(); \
auto input_data = input->data<SrcType>(); \
for (index_t i = 0; i < output->size(); ++i) { \ for (index_t i = 0; i < output->size(); ++i) { \
output_data[i] = static_cast<DT>(input_data[i]); \ output_data[i] = static_cast<T>(input_data[i]); \
} }
MACE_RUN_WITH_TYPE_ENUM(src_dtype, MACE_CAST_COPY); MACE_RUN_WITH_TYPE_ENUM(dst_dtype, MACE_CAST_COPY);
return MACE_SUCCESS; return MACE_SUCCESS;
} }
......
...@@ -30,8 +30,9 @@ void TestCast(const std::vector<index_t> &input_shape, ...@@ -30,8 +30,9 @@ void TestCast(const std::vector<index_t> &input_shape,
OpsTestNet net; OpsTestNet net;
OpDefBuilder("Cast", "CastTest") OpDefBuilder("Cast", "CastTest")
.Input("Input") .Input("Input")
.OutputType({DataTypeToEnum<DstType>::v()})
.Output("Output") .Output("Output")
.AddIntArg("T", DataTypeToEnum<DstType>::v()) .AddIntArg("T", DataTypeToEnum<SrcType>::v())
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Add input data // Add input data
...@@ -55,10 +56,12 @@ void TestCast(const std::vector<index_t> &input_shape, ...@@ -55,10 +56,12 @@ void TestCast(const std::vector<index_t> &input_shape,
TEST_F(CastOpTest, TestCastFromFloatToInt32) { TEST_F(CastOpTest, TestCastFromFloatToInt32) {
TestCast<float, int32_t>({1, 2, 3}, {1.1, 2.2, 3.3, 4.4, 5.5, 6.6}); TestCast<float, int32_t>({1, 2, 3}, {1.1, 2.2, 3.3, 4.4, 5.5, 6.6});
TestCast<float, int32_t>({}, {3.3});
} }
TEST_F(CastOpTest, TestCastFromInt32ToFloat) { TEST_F(CastOpTest, TestCastFromInt32ToFloat) {
TestCast<int32_t, float>({1, 2, 3}, {1, 2, 3, 4, 5, 6}); TestCast<int32_t, float>({1, 2, 3}, {1, 2, 3, 4, 5, 6});
TestCast<int32_t, float>({}, {3});
} }
} // namespace test } // namespace test
......
...@@ -23,6 +23,11 @@ void Register_Eltwise(OperatorRegistry *op_registry) { ...@@ -23,6 +23,11 @@ void Register_Eltwise(OperatorRegistry *op_registry) {
.TypeConstraint<float>("T") .TypeConstraint<float>("T")
.Build(), .Build(),
EltwiseOp<DeviceType::CPU, float>); EltwiseOp<DeviceType::CPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Eltwise")
.Device(DeviceType::CPU)
.TypeConstraint<int32_t>("T")
.Build(),
EltwiseOp<DeviceType::CPU, int32_t>);
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Eltwise") MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Eltwise")
......
...@@ -23,30 +23,63 @@ namespace test { ...@@ -23,30 +23,63 @@ namespace test {
class EltwiseOpTest : public OpsTestBase {}; class EltwiseOpTest : public OpsTestBase {};
namespace { namespace {
template <DeviceType D, typename T> template <DeviceType D, typename T, typename DstType>
void SimpleScalarScalar(const kernels::EltwiseType type,
const T input,
const float x,
const DstType output) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, T>("Input", {}, {input});
if (D == DeviceType::CPU) {
OpDefBuilder("Eltwise", "EltwiseTest")
.Input("Input")
.AddIntArg("T", DataTypeToEnum<T>::v())
.AddIntArg("type", static_cast<int>(type))
.AddFloatArg("value", x)
.OutputType({kernels::IsLogicalType(type) ? DT_INT32 : DT_FLOAT})
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} else {
MACE_NOT_IMPLEMENTED;
}
auto expected = CreateTensor<DstType>({}, {output});
ExpectTensorNear<DstType>(*expected, *net.GetOutput("Output"), 1e-5);
}
template <DeviceType D, typename T, typename DstType>
void SimpleTensorScalar(const kernels::EltwiseType type, void SimpleTensorScalar(const kernels::EltwiseType type,
const std::vector<index_t> &shape, const std::vector<index_t> &shape,
const std::vector<float> &input, const std::vector<T> &input,
const float x, const float x,
const std::vector<float> &output) { const std::vector<DstType> &output) {
// Construct graph // Construct graph
OpsTestNet net; OpsTestNet net;
// Add input data // Add input data
net.AddInputFromArray<D, float>("Input", shape, input); net.AddInputFromArray<D, T>("Input", shape, input);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<D, float>("Input", NHWC, "TInput", NCHW); net.TransformDataFormat<D, T>("Input", NHWC, "TInput", NCHW);
OpDefBuilder("Eltwise", "EltwiseTest") OpDefBuilder("Eltwise", "EltwiseTest")
.Input("TInput") .Input("TInput")
.AddIntArg("T", DataTypeToEnum<T>::v())
.AddIntArg("type", static_cast<int>(type)) .AddIntArg("type", static_cast<int>(type))
.AddFloatArg("value", x) .AddFloatArg("value", x)
.AddIntArg("data_format", DataFormat::NCHW) .AddIntArg("data_format", DataFormat::NCHW)
.OutputType({kernels::IsLogicalType(type) ? DT_INT32 : DT_FLOAT})
.Output("TOutput") .Output("TOutput")
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<D, float>("TOutput", NCHW, "Output", NHWC); net.TransformDataFormat<D, DstType>("TOutput", NCHW, "Output", NHWC);
} else { } else {
BufferToImage<D, T>(&net, "Input", "InputImg", BufferToImage<D, T>(&net, "Input", "InputImg",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
...@@ -60,44 +93,47 @@ void SimpleTensorScalar(const kernels::EltwiseType type, ...@@ -60,44 +93,47 @@ void SimpleTensorScalar(const kernels::EltwiseType type,
// Run // Run
net.RunOp(D); net.RunOp(D);
ImageToBuffer<D, float>(&net, "OutputImg", "Output", ImageToBuffer<D, DstType>(&net, "OutputImg", "Output",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
} }
auto expected = CreateTensor<float>(shape, output); auto expected = CreateTensor<DstType>(shape, output);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5); ExpectTensorNear<DstType>(*expected, *net.GetOutput("Output"), 1e-5);
} }
template <DeviceType D, typename T> template <DeviceType D, typename T, typename DstType>
void SimpleTensorEltwise(const kernels::EltwiseType type, void SimpleTensorEltwise(const kernels::EltwiseType type,
const std::vector<index_t> &shape0, const std::vector<index_t> &shape0,
const std::vector<float> &input0, const std::vector<T> &input0,
const std::vector<index_t> &shape1, const std::vector<index_t> &shape1,
const std::vector<float> &input1, const std::vector<T> &input1,
const std::vector<float> &output, const std::vector<DstType> &output,
const std::vector<float> &coeff = {}) { const std::vector<float> &coeff = {}) {
// Construct graph // Construct graph
OpsTestNet net; OpsTestNet net;
// Add input data // Add input data
net.AddInputFromArray<D, float>("Input0", shape0, input0); net.AddInputFromArray<D, T>("Input0", shape0, input0);
net.AddInputFromArray<D, float>("Input1", shape1, input1); net.AddInputFromArray<D, T>("Input1", shape1, input1);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
auto op_builder = OpDefBuilder("Eltwise", "EltwiseTest") auto op_builder =
.AddIntArg("type", static_cast<int>(type)) OpDefBuilder("Eltwise", "EltwiseTest")
.AddFloatsArg("coeff", coeff) .AddIntArg("T", DataTypeToEnum<T>::v())
.AddIntArg("data_format", DataFormat::NCHW) .AddIntArg("type", static_cast<int>(type))
.Output("TOutput"); .AddFloatsArg("coeff", coeff)
.AddIntArg("data_format", DataFormat::NCHW)
.OutputType({kernels::IsLogicalType(type) ? DT_INT32 : DT_FLOAT})
.Output("TOutput");
if (shape0.size() > 1) { if (shape0.size() > 1) {
net.TransformDataFormat<D, float>("Input0", NHWC, "TInput0", NCHW); net.TransformDataFormat<D, T>("Input0", NHWC, "TInput0", NCHW);
op_builder.Input("TInput0"); op_builder.Input("TInput0");
} else { } else {
op_builder.Input("Input0"); op_builder.Input("Input0");
} }
if (shape1.size() > 1) { if (shape1.size() > 1) {
net.TransformDataFormat<D, float>("Input1", NHWC, "TInput1", NCHW); net.TransformDataFormat<D, T>("Input1", NHWC, "TInput1", NCHW);
op_builder.Input("TInput1"); op_builder.Input("TInput1");
} else { } else {
op_builder.Input("Input1"); op_builder.Input("Input1");
...@@ -106,7 +142,7 @@ void SimpleTensorEltwise(const kernels::EltwiseType type, ...@@ -106,7 +142,7 @@ void SimpleTensorEltwise(const kernels::EltwiseType type,
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<D, float>("TOutput", NCHW, "Output", NHWC); net.TransformDataFormat<D, DstType>("TOutput", NCHW, "Output", NHWC);
} else { } else {
BufferToImage<D, T>(&net, "Input0", "InputImg0", BufferToImage<D, T>(&net, "Input0", "InputImg0",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
...@@ -123,42 +159,45 @@ void SimpleTensorEltwise(const kernels::EltwiseType type, ...@@ -123,42 +159,45 @@ void SimpleTensorEltwise(const kernels::EltwiseType type,
// Run // Run
net.RunOp(D); net.RunOp(D);
ImageToBuffer<D, float>(&net, "OutputImg", "Output", ImageToBuffer<D, DstType>(&net, "OutputImg", "Output",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
} }
std::vector<index_t> output_shape = shape0; std::vector<index_t> output_shape = shape0;
if (input0.size() < input1.size()) { if (input0.size() < input1.size()) {
output_shape = shape1; output_shape = shape1;
} }
auto expected = CreateTensor<float>(output_shape, output); auto expected = CreateTensor<DstType>(output_shape, output);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5); ExpectTensorNear<DstType>(*expected, *net.GetOutput("Output"), 1e-5);
} }
template<DeviceType D, typename T> template <DeviceType D, typename T, typename DstType>
void TensorGeneralBroadcastEltwise(const kernels::EltwiseType type, void TensorGeneralBroadcastEltwise(const kernels::EltwiseType type,
const std::vector<index_t> &shape0, const std::vector<index_t> &shape0,
const std::vector<float> &input0, const std::vector<T> &input0,
const std::vector<index_t> &shape1, const std::vector<index_t> &shape1,
const std::vector<float> &input1, const std::vector<T> &input1,
const std::vector<index_t> &output_shape, const std::vector<index_t> &output_shape,
const std::vector<float> &output, const std::vector<DstType> &output,
const std::vector<float> &coeff = {}) { const std::vector<float> &coeff = {}) {
// Construct graph // Construct graph
OpsTestNet net; OpsTestNet net;
// Add input data // Add input data
net.AddInputFromArray<D, float>("Input0", shape0, input0); net.AddInputFromArray<D, T>("Input0", shape0, input0);
net.AddInputFromArray<D, float>("Input1", shape1, input1); net.AddInputFromArray<D, T>("Input1", shape1, input1);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
auto op_builder = OpDefBuilder("Eltwise", "EltwiseTest") auto op_builder =
.Input("Input0") OpDefBuilder("Eltwise", "EltwiseTest")
.Input("Input1") .AddIntArg("T", DataTypeToEnum<T>::v())
.AddIntArg("type", static_cast<int>(type)) .Input("Input0")
.AddFloatsArg("coeff", coeff) .Input("Input1")
.Output("Output"); .AddIntArg("type", static_cast<int>(type))
.AddFloatsArg("coeff", coeff)
.OutputType({kernels::IsLogicalType(type) ? DT_INT32 : DT_FLOAT})
.Output("Output");
op_builder.Finalize(net.NewOperatorDef()); op_builder.Finalize(net.NewOperatorDef());
// Run // Run
...@@ -167,214 +206,248 @@ void TensorGeneralBroadcastEltwise(const kernels::EltwiseType type, ...@@ -167,214 +206,248 @@ void TensorGeneralBroadcastEltwise(const kernels::EltwiseType type,
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
} }
auto expected = CreateTensor<float>(output_shape, output); auto expected = CreateTensor<DstType>(output_shape, output);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5); ExpectTensorNear<DstType>(*expected, *net.GetOutput("Output"), 1e-5);
} }
} // namespace } // namespace
// Scalar-with-scalar Eltwise on CPU. Each call passes (op type, input scalar,
// "value" argument, expected result). The arithmetic ops run on float; EQUAL
// runs on int32 and yields 0 or 1.
TEST_F(EltwiseOpTest, CPUSimpleScalarScalar) {
  using kernels::EltwiseType;

  // Binary arithmetic ops.
  SimpleScalarScalar<DeviceType::CPU, float, float>(EltwiseType::SUM, 1, 2, 3);
  SimpleScalarScalar<DeviceType::CPU, float, float>(EltwiseType::SUB, 1, 2, -1);
  SimpleScalarScalar<DeviceType::CPU, float, float>(EltwiseType::PROD, 1, 2, 2);
  SimpleScalarScalar<DeviceType::CPU, float, float>(EltwiseType::DIV, 1, 2,
                                                    0.5);
  SimpleScalarScalar<DeviceType::CPU, float, float>(EltwiseType::MIN, 1, 2, 1);
  SimpleScalarScalar<DeviceType::CPU, float, float>(EltwiseType::MAX, 1, 2, 2);

  // NEG and ABS: the expected results here depend only on the input scalar.
  SimpleScalarScalar<DeviceType::CPU, float, float>(EltwiseType::NEG, 1, 2, -1);
  SimpleScalarScalar<DeviceType::CPU, float, float>(EltwiseType::ABS, -1, 3, 1);

  // Logical comparison with int32 input and output.
  SimpleScalarScalar<DeviceType::CPU, int32_t, int32_t>(EltwiseType::EQUAL, 1,
                                                        3, 0);
  SimpleScalarScalar<DeviceType::CPU, int32_t, int32_t>(EltwiseType::EQUAL, 3,
                                                        3, 1);
}
TEST_F(EltwiseOpTest, CPUSimpleTensorScalar) { TEST_F(EltwiseOpTest, CPUSimpleTensorScalar) {
SimpleTensorScalar<DeviceType::CPU, float>(kernels::EltwiseType::SUM, SimpleTensorScalar<DeviceType::CPU, float, float>(kernels::EltwiseType::SUM,
{1, 1, 1, 1}, {1}, 1, {2}); {1, 1, 1, 1}, {1}, 1, {2});
SimpleTensorScalar<DeviceType::CPU, float>(kernels::EltwiseType::SUB, SimpleTensorScalar<DeviceType::CPU, float, float>(
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, kernels::EltwiseType::SUB, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 1,
1, {0, 1, 2, 3, 4, 5}); {0, 1, 2, 3, 4, 5});
SimpleTensorScalar<DeviceType::CPU, float>(kernels::EltwiseType::PROD, SimpleTensorScalar<DeviceType::CPU, float, float>(
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, kernels::EltwiseType::PROD, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 2,
2, {2, 4, 6, 8, 10, 12}); {2, 4, 6, 8, 10, 12});
SimpleTensorScalar<DeviceType::CPU, float>(kernels::EltwiseType::DIV, SimpleTensorScalar<DeviceType::CPU, float, float>(
{1, 1, 2, 3}, {2, 4, 6, 8, 10, 12}, kernels::EltwiseType::DIV, {1, 1, 2, 3}, {2, 4, 6, 8, 10, 12}, 2,
2, {1, 2, 3, 4, 5, 6}); {1, 2, 3, 4, 5, 6});
SimpleTensorScalar<DeviceType::CPU, float>(kernels::EltwiseType::MIN, SimpleTensorScalar<DeviceType::CPU, float, float>(
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, kernels::EltwiseType::MIN, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 1,
1, {1, 1, 1, 1, 1, 1}); {1, 1, 1, 1, 1, 1});
SimpleTensorScalar<DeviceType::CPU, float>(kernels::EltwiseType::MAX, SimpleTensorScalar<DeviceType::CPU, float, float>(
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, kernels::EltwiseType::MAX, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 3,
3, {3, 3, 3, 4, 5, 6}); {3, 3, 3, 4, 5, 6});
SimpleTensorScalar<DeviceType::CPU, float>(kernels::EltwiseType::NEG, SimpleTensorScalar<DeviceType::CPU, float, float>(
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, kernels::EltwiseType::NEG, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 3,
3, {-1, -2, -3, -4, -5, -6}); {-1, -2, -3, -4, -5, -6});
SimpleTensorScalar<DeviceType::CPU, float>( SimpleTensorScalar<DeviceType::CPU, float, float>(
kernels::EltwiseType::ABS, {1, 1, 2, 3}, {-1, -2, -3, -4, -5, -6}, 3, kernels::EltwiseType::ABS, {1, 1, 2, 3}, {-1, -2, -3, -4, -5, -6}, 3,
{1, 2, 3, 4, 5, 6}); {1, 2, 3, 4, 5, 6});
SimpleTensorScalar<DeviceType::CPU, float>(kernels::EltwiseType::SQR_DIFF, SimpleTensorScalar<DeviceType::CPU, float, float>(
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 1,
1, {0, 1, 4, 9, 16, 25}); {0, 1, 4, 9, 16, 25});
SimpleTensorScalar<DeviceType::CPU, int32_t, int32_t>(
kernels::EltwiseType::EQUAL, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 3,
{0, 0, 1, 0, 0, 0});
} }
TEST_F(EltwiseOpTest, GPUSimpleTensorScalar) { TEST_F(EltwiseOpTest, GPUSimpleTensorScalar) {
SimpleTensorScalar<DeviceType::GPU, float>(kernels::EltwiseType::SUM, SimpleTensorScalar<DeviceType::GPU, float, float>(kernels::EltwiseType::SUM,
{1, 1, 1, 1}, {1}, 1, {2}); {1, 1, 1, 1}, {1}, 1, {2});
SimpleTensorScalar<DeviceType::GPU, float>(kernels::EltwiseType::SUB, SimpleTensorScalar<DeviceType::GPU, float, float>(
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, kernels::EltwiseType::SUB, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 1,
1, {0, 1, 2, 3, 4, 5}); {0, 1, 2, 3, 4, 5});
SimpleTensorScalar<DeviceType::GPU, float>(kernels::EltwiseType::PROD, SimpleTensorScalar<DeviceType::GPU, float, float>(
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, kernels::EltwiseType::PROD, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 2,
2, {2, 4, 6, 8, 10, 12}); {2, 4, 6, 8, 10, 12});
SimpleTensorScalar<DeviceType::GPU, float>(kernels::EltwiseType::DIV, SimpleTensorScalar<DeviceType::GPU, float, float>(
{1, 1, 2, 3}, {2, 4, 6, 8, 10, 12}, kernels::EltwiseType::DIV, {1, 1, 2, 3}, {2, 4, 6, 8, 10, 12}, 2,
2, {1, 2, 3, 4, 5, 6}); {1, 2, 3, 4, 5, 6});
SimpleTensorScalar<DeviceType::GPU, float>(kernels::EltwiseType::MIN, SimpleTensorScalar<DeviceType::GPU, float, float>(
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, kernels::EltwiseType::MIN, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 1,
1, {1, 1, 1, 1, 1, 1}); {1, 1, 1, 1, 1, 1});
SimpleTensorScalar<DeviceType::GPU, float>(kernels::EltwiseType::MAX, SimpleTensorScalar<DeviceType::GPU, float, float>(
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, kernels::EltwiseType::MAX, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 3,
3, {3, 3, 3, 4, 5, 6}); {3, 3, 3, 4, 5, 6});
SimpleTensorScalar<DeviceType::GPU, float>(kernels::EltwiseType::NEG, SimpleTensorScalar<DeviceType::GPU, float, float>(
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, kernels::EltwiseType::NEG, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 3,
3, {-1, -2, -3, -4, -5, -6}); {-1, -2, -3, -4, -5, -6});
SimpleTensorScalar<DeviceType::GPU, float>( SimpleTensorScalar<DeviceType::GPU, float, float>(
kernels::EltwiseType::ABS, {1, 1, 2, 3}, {-1, -2, -3, -4, -5, -6}, 3, kernels::EltwiseType::ABS, {1, 1, 2, 3}, {-1, -2, -3, -4, -5, -6}, 3,
{1, 2, 3, 4, 5, 6}); {1, 2, 3, 4, 5, 6});
SimpleTensorScalar<DeviceType::GPU, float>(kernels::EltwiseType::SQR_DIFF, SimpleTensorScalar<DeviceType::GPU, float, float>(
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, 1,
1, {0, 1, 4, 9, 16, 25}); {0, 1, 4, 9, 16, 25});
} }
TEST_F(EltwiseOpTest, CPUSimpleTensorVector) { TEST_F(EltwiseOpTest, CPUSimpleTensorVector) {
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 1, 3}, kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 1, 3},
{1, 2, 3}, {2, 4, 6, 5, 7, 9}); {1, 2, 3}, {2, 4, 6, 5, 7, 9});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::SUB, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, kernels::EltwiseType::SUB, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
{1, 1, 1, 5}, {1, 2, 3, 4, 5}, {0, 0, 0, 0, 0, 5, 5, 5, 5, 5}); {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {0, 0, 0, 0, 0, 5, 5, 5, 5, 5});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::SUB, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5}, kernels::EltwiseType::SUB, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 0, 0, 0, 0, -5, -5, -5, -5, -5}); {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 0, 0, 0, 0, -5, -5, -5, -5, -5});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::PROD, {1, 1, 1, 3}, {1, 2, 3}, {1, 2, 1, 3}, kernels::EltwiseType::PROD, {1, 1, 1, 3}, {1, 2, 3}, {1, 2, 1, 3},
{1, 2, 3, 4, 5, 6}, {1, 4, 9, 4, 10, 18}); {1, 2, 3, 4, 5, 6}, {1, 4, 9, 4, 10, 18});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::DIV, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, kernels::EltwiseType::DIV, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
{1, 1, 1, 5}, {1, 1, 1, 1, 5}, {1, 2, 3, 4, 1, 6, 7, 8, 9, 2}); {1, 1, 1, 5}, {1, 1, 1, 1, 5}, {1, 2, 3, 4, 1, 6, 7, 8, 9, 2});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::DIV, {1, 1, 1, 5}, {1, 1, 1, 2, 4}, {1, 2, 1, 5}, kernels::EltwiseType::DIV, {1, 1, 1, 5}, {1, 1, 1, 2, 4}, {1, 2, 1, 5},
{1, 1, 1, 2, 2, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 2, 1, 1, 1, 2, 4}); {1, 1, 1, 2, 2, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 2, 1, 1, 1, 2, 4});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::MIN, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5}, kernels::EltwiseType::MIN, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5}); {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::MAX, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, kernels::EltwiseType::MAX, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
{1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::SQR_DIFF, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, kernels::EltwiseType::SQR_DIFF, {1, 1, 1, 5}, {1, 2, 3, 4, 5},
{1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
{0, 0, 0, 0, 0, 25, 25, 25, 25, 25}); {0, 0, 0, 0, 0, 25, 25, 25, 25, 25});
SimpleTensorEltwise<DeviceType::CPU, int32_t, int32_t>(
kernels::EltwiseType::EQUAL, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6},
{1, 1, 1, 3}, {1, 2, 3}, {1, 1, 1, 0, 0, 0});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {3}, kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {3},
{1, 2, 3}, {2, 4, 6, 5, 7, 9}); {1, 2, 3}, {2, 4, 6, 5, 7, 9});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::SUB, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, kernels::EltwiseType::SUB, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
{5}, {1, 2, 3, 4, 5}, {0, 0, 0, 0, 0, 5, 5, 5, 5, 5}); {5}, {1, 2, 3, 4, 5}, {0, 0, 0, 0, 0, 5, 5, 5, 5, 5});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::SUB, {5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5}, kernels::EltwiseType::SUB, {5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 0, 0, 0, 0, -5, -5, -5, -5, -5}); {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 0, 0, 0, 0, -5, -5, -5, -5, -5});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::PROD, {3}, {1, 2, 3}, {1, 2, 1, 3}, kernels::EltwiseType::PROD, {3}, {1, 2, 3}, {1, 2, 1, 3},
{1, 2, 3, 4, 5, 6}, {1, 4, 9, 4, 10, 18}); {1, 2, 3, 4, 5, 6}, {1, 4, 9, 4, 10, 18});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::DIV, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, kernels::EltwiseType::DIV, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
{5}, {1, 1, 1, 1, 5}, {1, 2, 3, 4, 1, 6, 7, 8, 9, 2}); {5}, {1, 1, 1, 1, 5}, {1, 2, 3, 4, 1, 6, 7, 8, 9, 2});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::DIV, {5}, {1, 1, 1, 2, 4}, {1, 2, 1, 5}, kernels::EltwiseType::DIV, {5}, {1, 1, 1, 2, 4}, {1, 2, 1, 5},
{1, 1, 1, 2, 2, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 2, 1, 1, 1, 2, 4}); {1, 1, 1, 2, 2, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 2, 1, 1, 1, 2, 4});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::MIN, {5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5}, kernels::EltwiseType::MIN, {5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5}); {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::MAX, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, kernels::EltwiseType::MAX, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
{5}, {1, 2, 3, 4, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); {5}, {1, 2, 3, 4, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::SQR_DIFF, {5}, {1, 2, 3, 4, 5}, kernels::EltwiseType::SQR_DIFF, {5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5},
{1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 0, 0, 0, 0, 25, 25, 25, 25, 25});
{0, 0, 0, 0, 0, 25, 25, 25, 25, 25}); SimpleTensorEltwise<DeviceType::CPU, int32_t, int32_t>(
kernels::EltwiseType::EQUAL, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {3},
{1, 2, 3}, {1, 1, 1, 0, 0, 0});
} }
TEST_F(EltwiseOpTest, GPUSimpleTensorVector) { TEST_F(EltwiseOpTest, GPUSimpleTensorVector) {
SimpleTensorEltwise<DeviceType::GPU, float>( SimpleTensorEltwise<DeviceType::GPU, float, float>(
kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 1, 3}, kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 1, 3},
{1, 2, 3}, {2, 4, 6, 5, 7, 9}); {1, 2, 3}, {2, 4, 6, 5, 7, 9});
SimpleTensorEltwise<DeviceType::GPU, float>( SimpleTensorEltwise<DeviceType::GPU, float, float>(
kernels::EltwiseType::SUB, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, kernels::EltwiseType::SUB, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
{1, 1, 1, 5}, {1, 2, 3, 4, 5}, {0, 0, 0, 0, 0, 5, 5, 5, 5, 5}); {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {0, 0, 0, 0, 0, 5, 5, 5, 5, 5});
SimpleTensorEltwise<DeviceType::GPU, float>( SimpleTensorEltwise<DeviceType::GPU, float, float>(
kernels::EltwiseType::SUB, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5}, kernels::EltwiseType::SUB, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 0, 0, 0, 0, -5, -5, -5, -5, -5}); {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 0, 0, 0, 0, -5, -5, -5, -5, -5});
SimpleTensorEltwise<DeviceType::GPU, float>( SimpleTensorEltwise<DeviceType::GPU, float, float>(
kernels::EltwiseType::PROD, {1, 1, 1, 3}, {1, 2, 3}, {1, 2, 1, 3}, kernels::EltwiseType::PROD, {1, 1, 1, 3}, {1, 2, 3}, {1, 2, 1, 3},
{1, 2, 3, 4, 5, 6}, {1, 4, 9, 4, 10, 18}); {1, 2, 3, 4, 5, 6}, {1, 4, 9, 4, 10, 18});
SimpleTensorEltwise<DeviceType::GPU, float>( SimpleTensorEltwise<DeviceType::GPU, float, float>(
kernels::EltwiseType::DIV, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, kernels::EltwiseType::DIV, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
{1, 1, 1, 5}, {1, 1, 1, 1, 5}, {1, 2, 3, 4, 1, 6, 7, 8, 9, 2}); {1, 1, 1, 5}, {1, 1, 1, 1, 5}, {1, 2, 3, 4, 1, 6, 7, 8, 9, 2});
SimpleTensorEltwise<DeviceType::GPU, float>( SimpleTensorEltwise<DeviceType::GPU, float, float>(
kernels::EltwiseType::DIV, {1, 1, 1, 5}, {1, 1, 1, 2, 4}, {1, 2, 1, 5}, kernels::EltwiseType::DIV, {1, 1, 1, 5}, {1, 1, 1, 2, 4}, {1, 2, 1, 5},
{1, 1, 1, 2, 2, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 2, 1, 1, 1, 2, 4}); {1, 1, 1, 2, 2, 1, 1, 1, 1, 1}, {1, 1, 1, 1, 2, 1, 1, 1, 2, 4});
SimpleTensorEltwise<DeviceType::GPU, float>( SimpleTensorEltwise<DeviceType::GPU, float, float>(
kernels::EltwiseType::MIN, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5}, kernels::EltwiseType::MIN, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 2, 1, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5}); {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5});
SimpleTensorEltwise<DeviceType::GPU, float>( SimpleTensorEltwise<DeviceType::GPU, float, float>(
kernels::EltwiseType::MAX, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, kernels::EltwiseType::MAX, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
{1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
SimpleTensorEltwise<DeviceType::GPU, float>( SimpleTensorEltwise<DeviceType::GPU, float, float>(
kernels::EltwiseType::SQR_DIFF, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, kernels::EltwiseType::SQR_DIFF, {1, 1, 1, 5}, {1, 2, 3, 4, 5},
{1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
{0, 0, 0, 0, 0, 25, 25, 25, 25, 25}); {0, 0, 0, 0, 0, 25, 25, 25, 25, 25});
} }
TEST_F(EltwiseOpTest, CPUSimpleTensorTensor) { TEST_F(EltwiseOpTest, CPUSimpleTensorTensor) {
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 3}, kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, {2, 4, 6, 8, 10, 12}); {1, 2, 3, 4, 5, 6}, {2, 4, 6, 8, 10, 12});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 3}, kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, {0.2, 0.4, 0.6, 0.8, 1, 1.2}, {0.1, 0.1}); {1, 2, 3, 4, 5, 6}, {0.2, 0.4, 0.6, 0.8, 1, 1.2}, {0.1, 0.1});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::SUB, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 1, 1, 5}, kernels::EltwiseType::SUB, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 1, 1, 5},
{1, 2, 3, 4, 5}, {0, 0, 0, 0, 0}); {1, 2, 3, 4, 5}, {0, 0, 0, 0, 0});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::PROD, {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6}, kernels::EltwiseType::PROD, {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6},
{1, 2, 1, 3}, {1, 2, 3, 4, 5, 6}, {1, 4, 9, 16, 25, 36}); {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6}, {1, 4, 9, 16, 25, 36});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::DIV, {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6}, {1, 2, 1, 3}, kernels::EltwiseType::DIV, {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6}, {1, 2, 1, 3},
{1, 2, 3, 4, 5, 6}, {1, 1, 1, 1, 1, 1}); {1, 2, 3, 4, 5, 6}, {1, 1, 1, 1, 1, 1});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::MIN, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, kernels::EltwiseType::MIN, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
{1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
{1, 2, 3, 4, 5, 1, 2, 3, 4, 5}); {1, 2, 3, 4, 5, 1, 2, 3, 4, 5});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::MAX, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, kernels::EltwiseType::MAX, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
{1, 2, 1, 5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
SimpleTensorEltwise<DeviceType::CPU, float>( SimpleTensorEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::SQR_DIFF, {1, 2, 1, 5}, kernels::EltwiseType::SQR_DIFF, {1, 2, 1, 5},
{1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, {1, 2, 1, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 0, 0, 0, 0, 25, 25, 25, 25, 25}); {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 0, 0, 0, 0, 25, 25, 25, 25, 25});
SimpleTensorEltwise<DeviceType::CPU, int32_t, int32_t>(
kernels::EltwiseType::EQUAL, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6},
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 1, 1, 1, 1});
} }
TEST_F(EltwiseOpTest, GPUSimpleTensorTensor) { TEST_F(EltwiseOpTest, GPUSimpleTensorTensor) {
SimpleTensorEltwise<DeviceType::GPU, float>( SimpleTensorEltwise<DeviceType::GPU, float, float>(
kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 3}, kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, {2, 4, 6, 8, 10, 12}); {1, 2, 3, 4, 5, 6}, {2, 4, 6, 8, 10, 12});
SimpleTensorEltwise<DeviceType::GPU, float>( SimpleTensorEltwise<DeviceType::GPU, float, float>(
kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 3}, kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, {0.2, 0.4, 0.6, 0.8, 1, 1.2}, {0.1, 0.1}); {1, 2, 3, 4, 5, 6}, {0.2, 0.4, 0.6, 0.8, 1, 1.2}, {0.1, 0.1});
SimpleTensorEltwise<DeviceType::GPU, float>( SimpleTensorEltwise<DeviceType::GPU, float, float>(
kernels::EltwiseType::SUB, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 1, 1, 5}, kernels::EltwiseType::SUB, {1, 1, 1, 5}, {1, 2, 3, 4, 5}, {1, 1, 1, 5},
{1, 2, 3, 4, 5}, {0, 0, 0, 0, 0}); {1, 2, 3, 4, 5}, {0, 0, 0, 0, 0});
SimpleTensorEltwise<DeviceType::GPU, float>( SimpleTensorEltwise<DeviceType::GPU, float, float>(
kernels::EltwiseType::PROD, {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6}, kernels::EltwiseType::PROD, {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6},
{1, 2, 1, 3}, {1, 2, 3, 4, 5, 6}, {1, 4, 9, 16, 25, 36}); {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6}, {1, 4, 9, 16, 25, 36});
SimpleTensorEltwise<DeviceType::GPU, float>( SimpleTensorEltwise<DeviceType::GPU, float, float>(
kernels::EltwiseType::DIV, {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6}, {1, 2, 1, 3}, kernels::EltwiseType::DIV, {1, 2, 1, 3}, {1, 2, 3, 4, 5, 6}, {1, 2, 1, 3},
{1, 2, 3, 4, 5, 6}, {1, 1, 1, 1, 1, 1}); {1, 2, 3, 4, 5, 6}, {1, 1, 1, 1, 1, 1});
SimpleTensorEltwise<DeviceType::GPU, float>( SimpleTensorEltwise<DeviceType::GPU, float, float>(
kernels::EltwiseType::MIN, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, kernels::EltwiseType::MIN, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
{1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
{1, 2, 3, 4, 5, 1, 2, 3, 4, 5}); {1, 2, 3, 4, 5, 1, 2, 3, 4, 5});
SimpleTensorEltwise<DeviceType::GPU, float>( SimpleTensorEltwise<DeviceType::GPU, float, float>(
kernels::EltwiseType::MAX, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, kernels::EltwiseType::MAX, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10},
{1, 2, 1, 5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); {1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
SimpleTensorEltwise<DeviceType::GPU, float>( SimpleTensorEltwise<DeviceType::GPU, float, float>(
kernels::EltwiseType::SQR_DIFF, {1, 2, 1, 5}, kernels::EltwiseType::SQR_DIFF, {1, 2, 1, 5},
{1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, {1, 2, 1, 5}, {1, 2, 3, 4, 5, 1, 2, 3, 4, 5}, {1, 2, 1, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 0, 0, 0, 0, 25, 25, 25, 25, 25}); {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {0, 0, 0, 0, 0, 25, 25, 25, 25, 25});
...@@ -595,27 +668,30 @@ TEST_F(EltwiseOpTest, RandomTensorTensorHalf) { ...@@ -595,27 +668,30 @@ TEST_F(EltwiseOpTest, RandomTensorTensorHalf) {
} }
TEST_F(EltwiseOpTest, TensorGeneralBroadcast) { TEST_F(EltwiseOpTest, TensorGeneralBroadcast) {
TensorGeneralBroadcastEltwise<DeviceType::CPU, float>( TensorGeneralBroadcastEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 1}, kernels::EltwiseType::SUM, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 1},
{1, 2}, {1, 1, 2, 3}, {2, 3, 4, 6, 7, 8}); {1, 2}, {1, 1, 2, 3}, {2, 3, 4, 6, 7, 8});
TensorGeneralBroadcastEltwise<DeviceType::CPU, float>( TensorGeneralBroadcastEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::SUB, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 1}, kernels::EltwiseType::SUB, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 1},
{1, 2}, {1, 1, 2, 3}, {0, 1, 2, 2, 3, 4}); {1, 2}, {1, 1, 2, 3}, {0, 1, 2, 2, 3, 4});
TensorGeneralBroadcastEltwise<DeviceType::CPU, float>( TensorGeneralBroadcastEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::PROD, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, kernels::EltwiseType::PROD, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6},
{1, 1, 2, 1}, {1, 2}, {1, 1, 2, 3}, {1, 2, 3, 8, 10, 12}); {1, 1, 2, 1}, {1, 2}, {1, 1, 2, 3}, {1, 2, 3, 8, 10, 12});
TensorGeneralBroadcastEltwise<DeviceType::CPU, float>( TensorGeneralBroadcastEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::DIV, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 1}, kernels::EltwiseType::DIV, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 1},
{1, 2}, {1, 1, 2, 3}, {1, 2, 3, 2, 2.5, 3}); {1, 2}, {1, 1, 2, 3}, {1, 2, 3, 2, 2.5, 3});
TensorGeneralBroadcastEltwise<DeviceType::CPU, float>( TensorGeneralBroadcastEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::MIN, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 1}, kernels::EltwiseType::MIN, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 1},
{1, 2}, {1, 1, 2, 3}, {1, 1, 1, 2, 2, 2}); {1, 2}, {1, 1, 2, 3}, {1, 1, 1, 2, 2, 2});
TensorGeneralBroadcastEltwise<DeviceType::CPU, float>( TensorGeneralBroadcastEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::MAX, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 1}, kernels::EltwiseType::MAX, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1, 2, 1},
{1, 2}, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}); {1, 2}, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6});
TensorGeneralBroadcastEltwise<DeviceType::CPU, float>( TensorGeneralBroadcastEltwise<DeviceType::CPU, float, float>(
kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, kernels::EltwiseType::SQR_DIFF, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6},
{1, 1, 2, 1}, {1, 2}, {1, 1, 2, 3}, {0, 1, 4, 4, 9, 16}); {1, 1, 2, 1}, {1, 2}, {1, 1, 2, 3}, {0, 1, 4, 4, 9, 16});
TensorGeneralBroadcastEltwise<DeviceType::CPU, int32_t, int32_t>(
kernels::EltwiseType::EQUAL, {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6},
{1, 1, 2, 1}, {1, 2}, {1, 1, 2, 3}, {1, 0, 0, 0, 0, 0});
} }
} // namespace test } // namespace test
......
...@@ -23,6 +23,11 @@ void Register_Identity(OperatorRegistry *op_registry) { ...@@ -23,6 +23,11 @@ void Register_Identity(OperatorRegistry *op_registry) {
.TypeConstraint<float>("T") .TypeConstraint<float>("T")
.Build(), .Build(),
IdentityOp<DeviceType::CPU, float>); IdentityOp<DeviceType::CPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Identity")
.Device(DeviceType::CPU)
.TypeConstraint<int32_t>("T")
.Build(),
IdentityOp<DeviceType::CPU, int32_t>);
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Identity") MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Identity")
......
...@@ -38,12 +38,12 @@ void TestStridedSlice(const std::vector<index_t> &input_shape, ...@@ -38,12 +38,12 @@ void TestStridedSlice(const std::vector<index_t> &input_shape,
OpsTestNet net; OpsTestNet net;
net.AddInputFromArray<CPU, float>("Input", input_shape, input); net.AddInputFromArray<CPU, float>("Input", input_shape, input);
net.AddInputFromArray<CPU, int32_t>( net.AddInputFromArray<CPU, int32_t>(
"BeginIndices", {static_cast<int32_t>(input_shape.size())}, "BeginIndices", {static_cast<int32_t>(begin_indices.size())},
begin_indices); begin_indices);
net.AddInputFromArray<CPU, int32_t>( net.AddInputFromArray<CPU, int32_t>(
"EndIndices", {static_cast<int32_t>(input_shape.size())}, end_indices); "EndIndices", {static_cast<int32_t>(end_indices.size())}, end_indices);
net.AddInputFromArray<CPU, int32_t>( net.AddInputFromArray<CPU, int32_t>(
"Strides", {static_cast<int32_t>(input_shape.size())}, strides); "Strides", {static_cast<int32_t>(strides.size())}, strides);
OpDefBuilder("StridedSlice", "StridedSliceOpTest") OpDefBuilder("StridedSlice", "StridedSliceOpTest")
.Input("Input") .Input("Input")
...@@ -130,6 +130,8 @@ TEST_F(StridedSliceOpTest, TestStridedSliceRank1) { ...@@ -130,6 +130,8 @@ TEST_F(StridedSliceOpTest, TestStridedSliceRank1) {
TEST_F(StridedSliceOpTest, TestStridedSliceRank2) { TEST_F(StridedSliceOpTest, TestStridedSliceRank2) {
TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {1, 1}, 0, 0, 0, TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {1, 1}, 0, 0, 0,
0, 0, {2, 3}, {1, 2, 3, 4, 5, 6}); 0, 0, {2, 3}, {1, 2, 3, 4, 5, 6});
TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0}, {2}, {1}, 0, 0, 0,
0, 0, {2, 3}, {1, 2, 3, 4, 5, 6});
TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1}, {2, 3}, {1, 1}, 0, 0, 0, TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {1, 1}, {2, 3}, {1, 1}, 0, 0, 0,
0, 0, {1, 2}, {5, 6}); 0, 0, {1, 2}, {5, 6});
TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {1, 2}, 0, 0, 0, TestStridedSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {1, 2}, 0, 0, 0,
......
...@@ -66,11 +66,13 @@ class EltwiseType(Enum): ...@@ -66,11 +66,13 @@ class EltwiseType(Enum):
ABS = 7 ABS = 7
SQR_DIFF = 8 SQR_DIFF = 8
POW = 9 POW = 9
EQUAL = 10
MaceSupportedOps = [ MaceSupportedOps = [
'Activation', 'Activation',
'AddN', 'AddN',
'ArgMax',
'BatchNorm', 'BatchNorm',
'BatchToSpaceND', 'BatchToSpaceND',
'BiasAdd', 'BiasAdd',
......
...@@ -62,6 +62,7 @@ TFSupportedOps = [ ...@@ -62,6 +62,7 @@ TFSupportedOps = [
'Square', 'Square',
'SquaredDifference', 'SquaredDifference',
'Rsqrt', 'Rsqrt',
'Equal',
'Relu', 'Relu',
'Relu6', 'Relu6',
'Tanh', 'Tanh',
...@@ -93,6 +94,7 @@ TFSupportedOps = [ ...@@ -93,6 +94,7 @@ TFSupportedOps = [
'Stack', 'Stack',
'Pack', 'Pack',
'Cast', 'Cast',
'ArgMax',
] ]
TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str) TFOpType = Enum('TFOpType', [(op, op) for op in TFSupportedOps], type=str)
...@@ -125,7 +127,8 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -125,7 +127,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.RealDiv.name: EltwiseType.DIV, TFOpType.RealDiv.name: EltwiseType.DIV,
TFOpType.SquaredDifference.name: EltwiseType.SQR_DIFF, TFOpType.SquaredDifference.name: EltwiseType.SQR_DIFF,
TFOpType.Square.name: EltwiseType.POW, TFOpType.Square.name: EltwiseType.POW,
TFOpType.Rsqrt.name: EltwiseType.POW TFOpType.Rsqrt.name: EltwiseType.POW,
TFOpType.Equal.name: EltwiseType.EQUAL,
} }
activation_type = { activation_type = {
TFOpType.Relu.name: ActivationType.RELU, TFOpType.Relu.name: ActivationType.RELU,
...@@ -153,6 +156,7 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -153,6 +156,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.SquaredDifference.name: self.convert_elementwise, TFOpType.SquaredDifference.name: self.convert_elementwise,
TFOpType.Square.name: self.convert_elementwise, TFOpType.Square.name: self.convert_elementwise,
TFOpType.Rsqrt.name: self.convert_elementwise, TFOpType.Rsqrt.name: self.convert_elementwise,
TFOpType.Equal.name: self.convert_elementwise,
TFOpType.Relu.name: self.convert_activation, TFOpType.Relu.name: self.convert_activation,
TFOpType.Relu6.name: self.convert_activation, TFOpType.Relu6.name: self.convert_activation,
TFOpType.Tanh.name: self.convert_activation, TFOpType.Tanh.name: self.convert_activation,
...@@ -183,7 +187,8 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -183,7 +187,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.Slice.name: self.convert_slice, TFOpType.Slice.name: self.convert_slice,
TFOpType.Pack.name: self.convert_stack, TFOpType.Pack.name: self.convert_stack,
TFOpType.Stack.name: self.convert_stack, TFOpType.Stack.name: self.convert_stack,
TFOpType.Cast.name: self.convert_cast TFOpType.Cast.name: self.convert_cast,
TFOpType.ArgMax.name: self.convert_argmax,
} }
self._option = option self._option = option
self._mace_net_def = mace_pb2.NetDef() self._mace_net_def = mace_pb2.NetDef()
...@@ -376,18 +381,29 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -376,18 +381,29 @@ class TensorflowConverter(base_converter.ConverterInterface):
if type_arg.i != EltwiseType.NEG.value \ if type_arg.i != EltwiseType.NEG.value \
and type_arg.i != EltwiseType.ABS.value: and type_arg.i != EltwiseType.ABS.value:
if len(tf_op.inputs[0].shape) == 0: try:
value_arg = op.arg.add() def is_commutative(eltwise_type):
value_arg.name = MaceKeyword.mace_value_str return EltwiseType(eltwise_type) in [
value_arg.f = tf_op.inputs[0].eval().astype(np.float32) EltwiseType.SUM, EltwiseType.PROD,
self._skip_tensor.add(tf_op.inputs[0].name) EltwiseType.MAX, EltwiseType.MIN]
del op.input[0]
elif len(tf_op.inputs) > 1 and len(tf_op.inputs[1].shape) == 0: if len(tf_op.inputs) > 1 and len(tf_op.inputs[1].shape) == 0:
value_arg = op.arg.add() scalar = tf_op.inputs[1].eval().astype(np.float32)
value_arg.name = MaceKeyword.mace_value_str value_arg = op.arg.add()
value_arg.f = tf_op.inputs[1].eval().astype(np.float32) value_arg.name = MaceKeyword.mace_value_str
self._skip_tensor.add(tf_op.inputs[1].name) value_arg.f = scalar
del op.input[1] self._skip_tensor.add(tf_op.inputs[1].name)
del op.input[1]
elif len(tf_op.inputs[0].shape) == 0 and \
is_commutative(type_arg.i):
scalar = tf_op.inputs[0].eval().astype(np.float32)
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_value_str
value_arg.f = scalar
self._skip_tensor.add(tf_op.inputs[0].name)
del op.input[0]
except tf.errors.InvalidArgumentError:
pass
def convert_biasadd(self, tf_op): def convert_biasadd(self, tf_op):
op = self.convert_general_op(tf_op) op = self.convert_general_op(tf_op)
...@@ -550,7 +566,13 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -550,7 +566,13 @@ class TensorflowConverter(base_converter.ConverterInterface):
transpose_a_arg.name = MaceKeyword.mace_transpose_a_str transpose_a_arg.name = MaceKeyword.mace_transpose_a_str
transpose_a_arg.i = int(adj_x) transpose_a_arg.i = int(adj_x)
except ValueError: except ValueError:
pass try:
transpose_a = tf_op.get_attr('transpose_a')
transpose_a_arg = op.arg.add()
transpose_a_arg.name = MaceKeyword.mace_transpose_a_str
transpose_a_arg.i = int(transpose_a)
except ValueError:
pass
try: try:
adj_y = tf_op.get_attr('adj_y') adj_y = tf_op.get_attr('adj_y')
...@@ -558,7 +580,13 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -558,7 +580,13 @@ class TensorflowConverter(base_converter.ConverterInterface):
transpose_b_arg.name = MaceKeyword.mace_transpose_b_str transpose_b_arg.name = MaceKeyword.mace_transpose_b_str
transpose_b_arg.i = int(adj_y) transpose_b_arg.i = int(adj_y)
except ValueError: except ValueError:
pass try:
transpose_b = tf_op.get_attr('transpose_b')
transpose_b_arg = op.arg.add()
transpose_b_arg.name = MaceKeyword.mace_transpose_b_str
transpose_b_arg.i = int(transpose_b)
except ValueError:
pass
def convert_shape(self, tf_op): def convert_shape(self, tf_op):
op = self.convert_general_op(tf_op) op = self.convert_general_op(tf_op)
...@@ -689,14 +717,18 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -689,14 +717,18 @@ class TensorflowConverter(base_converter.ConverterInterface):
op = self.convert_general_op(tf_op) op = self.convert_general_op(tf_op)
op.type = MaceOp.Cast.name op.type = MaceOp.Cast.name
data_type_arg = ConverterUtil.get_arg(op, 'T')
try: try:
dtype = tf_op.get_attr('DstT') dtype = tf_op.get_attr('DstT')
if dtype == tf.int32: if dtype == tf.int32:
data_type_arg.i = mace_pb2.DT_INT32 op.output_type.extend([mace_pb2.DT_INT32])
elif dtype == tf.float32: elif dtype == tf.float32:
data_type_arg.i = self._option.data_type op.output_type.extend([self._option.data_type])
else: else:
mace_check(False, "data type %s not supported" % dtype) mace_check(False, "data type %s not supported" % dtype)
except ValueError: except ValueError:
data_type_arg.i = self._option.data_type op.output_type.extend([self._option.data_type])
def convert_argmax(self, tf_op):
op = self.convert_general_op(tf_op)
op.type = MaceOp.ArgMax.name
op.output_type.extend([mace_pb2.DT_INT32])
...@@ -48,6 +48,10 @@ namespace str_util { ...@@ -48,6 +48,10 @@ namespace str_util {
std::vector<std::string> Split(const std::string &str, char delims) { std::vector<std::string> Split(const std::string &str, char delims) {
std::vector<std::string> result; std::vector<std::string> result;
if (str.empty()) {
result.push_back("");
return result;
}
std::string tmp = str; std::string tmp = str;
while (!tmp.empty()) { while (!tmp.empty()) {
size_t next_offset = tmp.find(delims); size_t next_offset = tmp.find(delims);
......
...@@ -773,11 +773,17 @@ def tuning_run(abi, ...@@ -773,11 +773,17 @@ def tuning_run(abi,
(phone_data_dir, os.path.basename(opencl_binary_file)), (phone_data_dir, os.path.basename(opencl_binary_file)),
]) ])
adb_cmd = ' '.join(adb_cmd) adb_cmd = ' '.join(adb_cmd)
adb_cmd_file = "%s/%s" % (phone_data_dir, 'cmd_file')
with open('/tmp/mace_cmd_file', 'w') as cmd_file:
cmd_file.write(adb_cmd)
adb_push('/tmp/mace_cmd_file', adb_cmd_file, serialno)
sh.adb( sh.adb(
"-s", "-s",
serialno, serialno,
"shell", "shell",
adb_cmd, "sh",
adb_cmd_file,
_tty_in=True, _tty_in=True,
_out=process_output, _out=process_output,
_err_to_out=True) _err_to_out=True)
...@@ -1159,10 +1165,7 @@ def benchmark_model(abi, ...@@ -1159,10 +1165,7 @@ def benchmark_model(abi,
phone_data_dir, phone_data_dir,
serialno) serialno)
sh.adb( adb_cmd = [
"-s",
serialno,
"shell",
"LD_LIBRARY_PATH=%s" % phone_data_dir, "LD_LIBRARY_PATH=%s" % phone_data_dir,
"MACE_CPP_MIN_VLOG_LEVEL=%s" % vlog_level, "MACE_CPP_MIN_VLOG_LEVEL=%s" % vlog_level,
"MACE_RUN_PARAMETER_PATH=%s/mace_run.config" % "MACE_RUN_PARAMETER_PATH=%s/mace_run.config" %
...@@ -1185,6 +1188,19 @@ def benchmark_model(abi, ...@@ -1185,6 +1188,19 @@ def benchmark_model(abi,
"--model_file=%s" % mace_model_phone_path, "--model_file=%s" % mace_model_phone_path,
"--opencl_binary_file=%s/%s" % "--opencl_binary_file=%s/%s" %
(phone_data_dir, os.path.basename(opencl_binary_file)), (phone_data_dir, os.path.basename(opencl_binary_file)),
]
adb_cmd = ' '.join(adb_cmd)
adb_cmd_file = "%s/%s" % (phone_data_dir, 'cmd_file')
with open('/tmp/mace_cmd_file', 'w') as cmd_file:
cmd_file.write(adb_cmd)
adb_push('/tmp/mace_cmd_file', adb_cmd_file, serialno)
sh.adb(
"-s",
serialno,
"shell",
"sh",
adb_cmd_file,
_fg=True) _fg=True)
print("Benchmark done!\n") print("Benchmark done!\n")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册