提交 e4ac3908 编写于 作者: L liutuo

add onnx

add onnx_optimizer tool and update onnx info in docs
上级 09a7b52d
......@@ -82,7 +82,8 @@ the following projects during the development:
[Caffe](https://github.com/BVLC/caffe),
[SNPE](https://developer.qualcomm.com/software/snapdragon-neural-processing-engine-ai),
[ARM ComputeLibrary](https://github.com/ARM-software/ComputeLibrary),
[ncnn](https://github.com/Tencent/ncnn) and many others: we learned many best
[ncnn](https://github.com/Tencent/ncnn),
[ONNX](https://github.com/onnx/onnx) and many others: we learned many best
practices from these projects.
Finally, we also thank the Qualcomm, Pinecone and MediaTek engineering teams for
......
......@@ -64,6 +64,9 @@ Optional dependencies
* - FileLock
- pip install -I filelock==3.0.0
- Required by run on Android
* - ONNX
- pip install onnx
- Required by ONNX model
.. note::
......
......@@ -72,3 +72,9 @@ Install Caffe (Optional)
-------------------------
Please follow the installation instruction of `Caffe <http://caffe.berkeleyvision.org/installation.html>`__.
Install ONNX (Optional)
-------------------------
Please follow the installation instruction of `ONNX <https://github.com/onnx/onnx#source>`__.
......@@ -18,8 +18,7 @@ MACE Model
~~~~~~~~~~
MACE defines a customized model format which is similar to
Caffe2. The MACE model can be converted from exported models by TensorFlow
and Caffe.
Caffe2. The MACE model can be converted from exported models by TensorFlow, Caffe or ONNX Model.
MACE Interpreter
~~~~~~~~~~~~~~~~~
......@@ -50,7 +49,7 @@ Build MACE dynamic or static libraries.
3. Convert model
~~~~~~~~~~~~~~~~~~
Convert TensorFlow or Caffe model to MACE model.
Convert TensorFlow, Caffe or ONNX model to MACE model.
4.1. Deploy
~~~~~~~~~~~~~~~~~~
......@@ -86,7 +85,7 @@ MACE覆盖了常见的移动端计算设备(CPU,GPU和DSP),并且提供
MACE Model
~~~~~~~~~~~~~~~~~~
MACE定义了自有的模型格式(类似于Caffe2),通过MACE提供的工具可以将Caffe和TensorFlow的模型
MACE定义了自有的模型格式(类似于Caffe2),通过MACE提供的工具可以将Caffe/TensorFlow/ONNX格式的模型
转为MACE模型。
MACE Interpreter
......@@ -118,7 +117,7 @@ CPU/GPU/DSP Runtime对应于各个计算设备的算子实现。
3. 转换模型
~~~~~~~~~~~~~~~~~~
将TensorFlow 或者 Caffe的模型转为MACE的模型。
将TensorFlow或者Caffe或者ONNX的模型转为MACE的模型。
4.1. 部署
~~~~~~~~~~~~~~~~~~
......
......@@ -78,6 +78,8 @@ in one deployment file.
- [optional] Specify Numpy validation inputs. When not provided, [-1, 1] random values will be used.
* - validation_threshold
- [optional] Specify the similarity threshold for validation. A dict with key in 'CPU', 'GPU' and/or 'HEXAGON' and value <= 1.0.
* - backend
- The onnx backend framework for validation, could be [tensorflow, caffe2, pytorch], default is tensorflow.
* - runtime
- The running device, one of [cpu, gpu, dsp, cpu_gpu]. cpu_gpu contains CPU and GPU model definition so you can run the model on both CPU and GPU.
* - data_type
......
......@@ -114,6 +114,19 @@ MACE now supports models from TensorFlow and Caffe (more frameworks will be supp
# Upgrade caffemodel
$CAFFE_ROOT/build/tools/upgrade_net_proto_binary MODEL.caffemodel MODEL.new.caffemodel
- ONNX
Prepare your ONNX model.onnx file.
Use `ONNX Optimizer Tool <https://github.com/XiaoMi/mace/tree/master/tools/onnx_optimizer.py>`__ to optimize your model for inference.
This tool will improve the efficiency of inference like the `Graph Transform Tool <https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md>`__
in TensorFlow.
.. code:: bash
# Optimize your model
$python MACE_ROOT/tools/onnx_optimizer.py model.onnx model_opt.onnx
===========================================
2. Create a deployment file for your model
......@@ -137,6 +150,12 @@ Modify one of them and use it for your own case.
.. literalinclude:: models/demo_models_caffe.yml
:language: yaml
- ONNX
.. literalinclude:: models/demo_models_onnx.yml
:language: yaml
More details about model deployment file are in :doc:`advanced_usage`.
======================
......
# The name of library
library_name: mobilenet
target_abis: [arm64-v8a]
model_graph_format: file
model_data_format: file
models:
mobilenet_v1: # model tag, which will be used in model loading and must be specific.
platform: onnx
# path to your onnx model file. Support local path, http:// and https://
model_file_path: https://cnbj1.fds.api.xiaomi.com/mace/miai-models/mobilenet-v1/mobilenet-v1-1.0.pb
# sha256_checksum of your model's onnx file.
# use this command to get the sha256_checksum: sha256sum path/to/your/pb/file
model_sha256_checksum: 71b10f540ece33c49a7b51f5d4095fc9bd78ce46ebf0300487b2ee23d71294e6
# define your model's interface
# if there multiple inputs or outputs, write like blow:
# subgraphs:
# - input_tensors:
# - input0
# - input1
# input_shapes:
# - 1,224,224,3
# - 1,224,224,3
# output_tensors:
# - output0
# - output1
# output_shapes:
# - 1,1001
# - 1,1001
subgraphs:
- input_tensors:
- input
input_shapes:
- 1,224,224,3
output_tensors:
- MobilenetV1/Predictions/Reshape_1
output_shapes:
- 1,1001
# onnx backend framwork for validation. Suppport pytorch/caffe/tensorflow. Default is tensorflow.
backend: tensorflow
# cpu, gpu or cpu+gpu
runtime: cpu+gpu
winograd: 0
\ No newline at end of file
......@@ -32,7 +32,8 @@ enum ActivationType {
RELUX = 2,
PRELU = 3,
TANH = 4,
SIGMOID = 5
SIGMOID = 5,
LEAKYRELU = 6,
};
inline ActivationType StringToActivationType(const std::string type) {
......@@ -48,6 +49,8 @@ inline ActivationType StringToActivationType(const std::string type) {
return ActivationType::SIGMOID;
} else if (type == "NOOP") {
return ActivationType::NOOP;
} else if (type == "LEAKYRELU") {
return ActivationType ::LEAKYRELU;
} else {
LOG(FATAL) << "Unknown activation type: " << type;
}
......@@ -90,6 +93,13 @@ void DoActivation(const T *input_ptr,
output_ptr[i] = 1 / (1 + std::exp(-input_ptr[i]));
}
break;
case LEAKYRELU:
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::max(input_ptr[i],
static_cast<T>(0)) * relux_max_limit;
}
break;
default:
LOG(FATAL) << "Unknown activation type: " << type;
}
......@@ -122,6 +132,9 @@ inline void DoActivation(const float *input_ptr,
output_ptr[i] = 1 / (1 + std::exp(-input_ptr[i]));
}
break;
case LEAKYRELU:
LeakyReluNeon(input_ptr, relux_max_limit, size, output_ptr);
break;
default:
LOG(FATAL) << "Unknown activation type: " << type;
}
......
......@@ -27,18 +27,29 @@ template <DeviceType D, class T>
class ArgMaxOp : public Operation {
public:
explicit ArgMaxOp(OpConstructContext *context)
: Operation(context) {}
: Operation(context),
axis_(Operation::GetOptionalArg<int>("axis", 0)),
keep_dims_(Operation::GetOptionalArg<bool>("keepdims", true)),
argmin_(Operation::GetOptionalArg<bool>("argmin", false)) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
const Tensor *input = this->Input(0);
const Tensor *axis = this->Input(1);
const Tensor *axis = this->InputSize() == 2 ?
this->Input(1) : nullptr;
Tensor *output = this->Output(0);
MACE_CHECK(keep_dims_, "Mace only supports keep_dims ArgMax.");
MACE_CHECK(input->dim_size() > 0, "ArgMax input should not be a scalar");
MACE_CHECK(axis->dim_size() == 0, "Mace argmax only supports scalar axis");
Tensor::MappingGuard axis_guard(axis);
int axis_value = axis->data<int32_t>()[0];
int axis_value = 0;
if (axis != nullptr) {
MACE_CHECK(axis->dim_size() == 0,
"Mace argmax only supports scalar axis");
Tensor::MappingGuard axis_guard(axis);
axis_value = axis->data<int32_t>()[0];
} else {
axis_value = axis_;
}
if (axis_value < 0) {
axis_value += input->dim_size();
}
......@@ -59,22 +70,43 @@ class ArgMaxOp : public Operation {
index_t outer_size = output->size();
index_t inner_size = input->dim(axis_value);
if (argmin_) {
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < outer_size; ++i) {
int idx = 0;
T max_value = std::numeric_limits<T>::lowest();
const T *input_ptr = input_data + i * inner_size;
for (index_t j = 0; j < inner_size; ++j) {
if (input_ptr[j] > max_value) {
max_value = input_ptr[j];
idx = j;
for (index_t i = 0; i < outer_size; ++i) {
int idx = 0;
T min_value = std::numeric_limits<T>::max();
const T *input_ptr = input_data + i * inner_size;
for (index_t j = 0; j < inner_size; ++j) {
if (input_ptr[j] < min_value) {
min_value = input_ptr[j];
idx = j;
}
}
output_data[i] = idx;
}
} else {
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < outer_size; ++i) {
int idx = 0;
T max_value = std::numeric_limits<T>::lowest();
const T *input_ptr = input_data + i * inner_size;
for (index_t j = 0; j < inner_size; ++j) {
if (input_ptr[j] > max_value) {
max_value = input_ptr[j];
idx = j;
}
}
output_data[i] = idx;
}
output_data[i] = idx;
}
return MaceStatus::MACE_SUCCESS;
}
protected:
const int axis_;
bool keep_dims_;
bool argmin_;
};
......
......@@ -67,5 +67,29 @@ void ReluxNeon(const float *input, const float limit,
#endif
}
void LeakyReluNeon(const float *input, const float alpha,
const index_t size, float *output) {
#if defined(MACE_ENABLE_NEON)
float32x4_t vzero = vdupq_n_f32(0.f);
float32x4_t valpha = vdupq_n_f32(alpha);
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i <= size - 4; i += 4) {
float32x4_t v = vld1q_f32(input + i);
v = vmaxq_f32(v, vzero);
v = vmulq_f32(v, valpha);
vst1q_f32(output + i, v);
}
// remain
for (index_t i = (size >> 2) << 2; i < size; ++i) {
output[i] = std::max(input[i], 0.f) * alpha;
}
#else
#pragma omp parallel for schedule(runtime)
for (index_t i = 0; i < size; ++i) {
output[i] = std::max(input[i], 0.f) * alpha;
}
#endif
}
} // namespace ops
} // namespace mace
......@@ -25,6 +25,9 @@ void ReluNeon(const float *input, const index_t size, float *output);
void ReluxNeon(const float *input, const float limit,
const index_t size, float *output);
void LeakyReluNeon(const float *input, const float alpha,
const index_t size, float *output);
} // namespace ops
} // namespace mace
......
......@@ -43,6 +43,7 @@ class PoolingKernel : public OpenCLPoolingKernel {
const Padding &padding_type,
const std::vector<int> &padding_data,
const int *dilations,
const RoundType round_type,
Tensor *output) override;
private:
......@@ -62,6 +63,7 @@ MaceStatus PoolingKernel<T>::Compute(
const Padding &padding_type,
const std::vector<int> &padding_data,
const int *dilations,
const RoundType round_type,
Tensor *output) {
MACE_CHECK(dilations[0] == 1 && dilations[1] == 1)
<< "Pooling opencl kernel not support dilation yet";
......@@ -82,7 +84,7 @@ MaceStatus PoolingKernel<T>::Compute(
} else {
paddings = padding_data;
CalcOutputSize(input->shape().data(), filter_shape.data(),
padding_data.data(), dilations, strides, RoundType::CEIL,
padding_data.data(), dilations, strides, round_type,
output_shape.data());
}
......
......@@ -102,6 +102,9 @@ inline DATA_TYPE4 do_activation(DATA_TYPE4 in,
#endif
#ifdef USE_SIGMOID
out = do_sigmoid(in);
#endif
#ifdef USE_LEAKYRELU
out = fmax(in, (DATA_TYPE)0) * relux_max_limit;
#endif
return out;
}
......
#include <common.h>
__kernel void reduce_mean(OUT_OF_RANGE_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__local float4 *group_sum,
__private const int group_size,
__private const int partial_len,
__private const int remain_index,
__private const int batch,
__private const int in_height,
__private const int in_width,
__private const float image_size_reciprocal,
__private const int channel_blocks,
__write_only image2d_t output) {
__kernel void reduce(OUT_OF_RANGE_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__local float4 *group_sum,
__private const int group_size,
__private const int partial_len,
__private const int remain_index,
__private const int batch,
__private const int in_height,
__private const int in_width,
__private const float image_size_reciprocal,
__private const int channel_blocks,
__write_only image2d_t output) {
const int i = get_local_id(0);
const int j = get_local_id(1);
const int k = get_global_id(2);
......@@ -22,12 +22,22 @@ __kernel void reduce_mean(OUT_OF_RANGE_PARAMS
return;
#endif
const int dim0_size = get_local_size(0);
float4 tmp = (float4){0, 0, 0, 0};
const int index = mad24(j, dim0_size, i);
const int b = k / channel_blocks;
const int ch = mad24(b, -channel_blocks, k);
DATA_TYPE4 in;
#if REDUCE_TYPE == 1
float4 tmp = (float4){MAXFLOAT, MAXFLOAT, MAXFLOAT, MAXFLOAT};
#elif REDUCE_TYPE == 2
float4 tmp = (float4){-MAXFLOAT, -MAXFLOAT, -MAXFLOAT, -MAXFLOAT};
#elif REDUCE_TYPE == 3
float4 tmp = (float4){1, 1, 1, 1};
#else
float4 tmp = (float4){0, 0, 0, 0};
#endif
const int valid_part_len = select(partial_len,
partial_len - 1,
remain_index > 0 && index >= remain_index);
......@@ -43,19 +53,51 @@ __kernel void reduce_mean(OUT_OF_RANGE_PARAMS
int pos_x = mad24(ch, in_width, w_id);
int pos_y = mad24(b, in_height, h_id);
in = READ_IMAGET(input, SAMPLER, (int2)(pos_x, pos_y));
// MIN
#if REDUCE_TYPE == 1
tmp = fmin(tmp, in);
// MAX
#elif REDUCE_TYPE == 2
tmp = fmax(tmp, in);
// PROD
#elif REDUCE_TYPE == 3
tmp = tmp * in;
// MEAN
#else
tmp = tmp + in;
#endif
}
group_sum[index] = tmp * image_size_reciprocal;
#if REDUCE_TYPE == 0
tmp = tmp * image_size_reciprocal;
#endif
group_sum[index] = tmp;
#ifdef NON_QUALCOMM_ADRENO
barrier(CLK_LOCAL_MEM_FENCE);
#endif
if (i == 0 && j == 0) {
#if REDUCE_TYPE == 1
DATA_TYPE4 out = (DATA_TYPE4){MAXFLOAT, MAXFLOAT, MAXFLOAT, MAXFLOAT};
#elif REDUCE_TYPE == 2
DATA_TYPE4 out = (DATA_TYPE4){-MAXFLOAT, -MAXFLOAT, -MAXFLOAT, -MAXFLOAT};
#elif REDUCE_TYPE == 3
DATA_TYPE4 out = (DATA_TYPE4){1, 1, 1, 1};
#else
DATA_TYPE4 out = (DATA_TYPE4){0, 0, 0, 0};
#endif
#pragma unroll
for (int l = 0; l < group_size; ++l) {
#if REDUCE_TYPE == 1
out = fmin(out, group_sum[l]);
#elif REDUCE_TYPE == 2
out = fmax(out, group_sum[l]);
#elif REDUCE_TYPE == 3
out = out * group_sum[l];
#else
out = out + group_sum[l];
#endif
}
WRITE_IMAGET(output, (int2)(ch, b), out);
}
......
......@@ -99,6 +99,10 @@ MaceStatus ActivationKernel<T>::Compute(
tuning_key_prefix_ = "sigmoid_opencl_kernel";
built_options.emplace("-DUSE_SIGMOID");
break;
case LEAKYRELU:
tuning_key_prefix_ = "leakyrelu_opencl_kernel";
built_options.emplace("-DUSE_LEAKYRELU");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation_;
}
......
......@@ -69,6 +69,7 @@ class PoolingKernel : public OpenCLPoolingKernel {
const Padding &padding_type,
const std::vector<int> &padding_data,
const int *dilations,
const RoundType round_type,
Tensor *output) override;
private:
......@@ -87,6 +88,7 @@ MaceStatus PoolingKernel<T>::Compute(
const Padding &padding_type,
const std::vector<int> &padding_data,
const int *dilations,
const RoundType round_type,
Tensor *output) {
MACE_CHECK(dilations[0] == 1 && dilations[1] == 1)
<< "Pooling opencl kernel not support dilation yet";
......@@ -103,7 +105,7 @@ MaceStatus PoolingKernel<T>::Compute(
} else {
paddings = padding_data;
CalcOutputSize(input->shape().data(), filter_shape.data(),
padding_data.data(), dilations, strides, RoundType::CEIL,
padding_data.data(), dilations, strides, round_type,
output_shape.data());
}
......
......@@ -11,10 +11,10 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_OPENCL_IMAGE_REDUCE_MEAN_H_
#define MACE_OPS_OPENCL_IMAGE_REDUCE_MEAN_H_
#ifndef MACE_OPS_OPENCL_IMAGE_REDUCE_H_
#define MACE_OPS_OPENCL_IMAGE_REDUCE_H_
#include "mace/ops/opencl/reduce_mean.h"
#include "mace/ops/opencl/reduce.h"
#include <memory>
#include <set>
......@@ -24,6 +24,7 @@
#include "mace/core/op_context.h"
#include "mace/core/tensor.h"
#include "mace/ops/opencl/helper.h"
#include "mace/ops/reduce.h"
namespace mace {
namespace ops {
......@@ -31,11 +32,12 @@ namespace opencl {
namespace image {
template <typename T>
class ReduceMeanKernel : public OpenCLReduceMeanKernel {
class ReduceKernel : public OpenCLReduceKernel {
public:
ReduceMeanKernel(const std::vector<int> axis,
const bool keep_dims)
: axis_(axis), keep_dims_(keep_dims) {}
ReduceKernel(ReduceType type,
const std::vector<int> axis,
const bool keep_dims)
: reduce_type_(type), axis_(axis), keep_dims_(keep_dims) {}
MaceStatus Compute(
OpContext *context,
......@@ -43,6 +45,7 @@ class ReduceMeanKernel : public OpenCLReduceMeanKernel {
Tensor *output) override;
private:
ReduceType reduce_type_;
const std::vector<int> axis_;
bool keep_dims_;
cl::Kernel kernel_;
......@@ -51,16 +54,16 @@ class ReduceMeanKernel : public OpenCLReduceMeanKernel {
};
template <typename T>
MaceStatus ReduceMeanKernel<T>::Compute(
MaceStatus ReduceKernel<T>::Compute(
OpContext *context,
const Tensor *input,
Tensor *output) {
MACE_CHECK_NOTNULL(input);
MACE_CHECK(keep_dims_, "reduce mean gpu only support keep dims.");
MACE_CHECK(input->dim_size() == 4,
"reduce mean gpu only support 4-dim input");
"reduce gpu only support 4-dim input");
MACE_CHECK(axis_.size() == 2 && axis_[0] == 1 && axis_[1] == 2,
"reduce mean gpu only support 1,2-axis reduce");
"reduce gpu only support 1,2-axis reduce");
index_t batch = input->dim(0);
const index_t in_height = input->dim(1);
const index_t in_width = input->dim(2);
......@@ -84,14 +87,15 @@ MaceStatus ReduceMeanKernel<T>::Compute(
std::set<std::string> built_options;
MACE_OUT_OF_RANGE_CONFIG;
MACE_NON_UNIFORM_WG_CONFIG;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("reduce_mean");
built_options.emplace("-Dreduce_mean=" + kernel_name);
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("reduce");
built_options.emplace("-Dreduce=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpCompatibleCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpCompatibleCLCMDDt(dt));
built_options.emplace(MakeString("-DREDUCE_TYPE=", reduce_type_));
if (runtime->gpu_type() != GPUType::QUALCOMM_ADRENO) {
built_options.emplace("-DNON_QUALCOMM_ADRENO");
}
MACE_RETURN_IF_ERROR(runtime->BuildKernel("reduce_mean",
MACE_RETURN_IF_ERROR(runtime->BuildKernel("reduce",
kernel_name,
built_options,
&kernel_));
......@@ -170,4 +174,4 @@ MaceStatus ReduceMeanKernel<T>::Compute(
} // namespace ops
} // namespace mace
#endif // MACE_OPS_OPENCL_IMAGE_REDUCE_MEAN_H_
#endif // MACE_OPS_OPENCL_IMAGE_REDUCE_H_
......@@ -36,6 +36,7 @@ class OpenCLPoolingKernel {
const Padding &padding_type,
const std::vector<int> &padding_data,
const int *dilations,
const RoundType round_type,
Tensor *output) = 0;
MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLPoolingKernel);
};
......
......@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_OPENCL_REDUCE_MEAN_H_
#define MACE_OPS_OPENCL_REDUCE_MEAN_H_
#ifndef MACE_OPS_OPENCL_REDUCE_H_
#define MACE_OPS_OPENCL_REDUCE_H_
#include "mace/public/mace.h"
#include "mace/utils/utils.h"
......@@ -24,16 +24,16 @@ class OpContext;
class Tensor;
namespace ops {
class OpenCLReduceMeanKernel {
class OpenCLReduceKernel {
public:
virtual MaceStatus Compute(
OpContext *context,
const Tensor *input,
Tensor *output) = 0;
MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLReduceMeanKernel);
MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLReduceKernel);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_OPENCL_REDUCE_MEAN_H_
#endif // MACE_OPS_OPENCL_REDUCE_H_
......@@ -44,7 +44,7 @@ extern void RegisterLocalResponseNorm(OpRegistryBase *op_registry);
extern void RegisterMatMul(OpRegistryBase *op_registry);
extern void RegisterPad(OpRegistryBase *op_registry);
extern void RegisterPooling(OpRegistryBase *op_registry);
extern void RegisterReduceMean(OpRegistryBase *op_registry);
extern void RegisterReduce(OpRegistryBase *op_registry);
extern void RegisterReshape(OpRegistryBase *op_registry);
extern void RegisterResizeBicubic(OpRegistryBase *op_registry);
extern void RegisterResizeBilinear(OpRegistryBase *op_registry);
......@@ -102,7 +102,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() {
ops::RegisterMatMul(this);
ops::RegisterPad(this);
ops::RegisterPooling(this);
ops::RegisterReduceMean(this);
ops::RegisterReduce(this);
ops::RegisterReshape(this);
ops::RegisterResizeBicubic(this);
ops::RegisterResizeBilinear(this);
......
......@@ -43,11 +43,14 @@ class PoolingOpBase : public ConvPool2dOpBase {
kernels_(Operation::GetRepeatedArgs<int>("kernels")),
pooling_type_(
static_cast<PoolingType>(Operation::GetOptionalArg<int>(
"pooling_type", static_cast<int>(AVG)))) {}
"pooling_type", static_cast<int>(AVG)))),
round_type_(static_cast<RoundType>(Operation::GetOptionalArg<int>(
"round_mode", static_cast<int>(CEIL)))) {}
protected:
std::vector<int> kernels_;
PoolingType pooling_type_;
RoundType round_type_;
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
......@@ -82,7 +85,7 @@ class PoolingOp<DeviceType::CPU, float> : public PoolingOpBase {
paddings_.data(),
dilations_.data(),
strides_.data(),
RoundType::CEIL,
round_type_,
output_shape.data());
}
MACE_RETURN_IF_ERROR(output_tensor->Resize(output_shape));
......@@ -255,7 +258,7 @@ class PoolingOp<DeviceType::CPU, uint8_t> : public PoolingOpBase {
paddings_.data(),
dilations_.data(),
strides_.data(),
RoundType::CEIL,
round_type_,
output_shape.data());
}
MACE_RETURN_IF_ERROR(output_tensor->Resize(output_shape));
......@@ -442,7 +445,7 @@ class PoolingOp<DeviceType::GPU, T> : public PoolingOpBase {
return kernel_->Compute(context, input, pooling_type_, kernels_.data(),
strides_.data(), padding_type_, paddings_,
dilations_.data(), output);
dilations_.data(), round_type_, output);
}
private:
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/reduce.h"
#include <algorithm>
#include <memory>
#include <vector>
#include "mace/core/future.h"
#include "mace/core/operator.h"
#include "mace/core/runtime/cpu/cpu_runtime.h"
#include "mace/core/tensor.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/reduce.h"
#endif // MACE_ENABLE_OPENCL
namespace mace {
namespace ops {
class ReduceOpBase : public Operation {
public:
explicit ReduceOpBase(OpConstructContext *context)
: Operation(context),
reduce_type_(
static_cast<ReduceType>(Operation::GetOptionalArg<int>(
"reduce_type", static_cast<int>(MEAN)))),
axis_(Operation::GetRepeatedArgs<int>("axis")),
keep_dims_(Operation::GetOptionalArg<bool>("keepdims", false)) {
}
protected:
inline void Validate() {
const Tensor *input = this->Input(0);
const int left = static_cast<int>(input->dim_size() * -1);
const int right = static_cast<int>(input->dim_size());
if (axis_.size()) {
for (unsigned int i = 0; i < axis_.size(); ++i) {
MACE_CHECK(axis_[i] > left && axis_[i] < right, "Axis is over range.");
}
}
}
protected:
ReduceType reduce_type_;
std::vector<int> axis_;
bool keep_dims_;
};
template <DeviceType D, class T>
class ReduceOp;
template <typename T>
class ReduceOp<DeviceType::CPU, T> : public ReduceOpBase {
public:
explicit ReduceOp(OpConstructContext *context)
: ReduceOpBase(context) {}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
Validate();
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
Simplify(input);
output->Resize(out_shape_);
Compute(input, output);
return MaceStatus::MACE_SUCCESS;
}
private:
void Simplify(const Tensor *input) {
std::vector<bool> bitmap(static_cast<uint32_t>(input->dim_size()), false);
if (axis_.size() == 0) {
for (int i = 0; i < input->dim_size(); ++i) {
bitmap[i] = true;
}
} else {
for (unsigned int i = 0; i < axis_.size(); ++i) {
int index = axis_[i] >= 0 ?
axis_[i] :
axis_[i] + input->dim_size();
if (input->dim_size() == 4) {
if (index == 1 || index == 2) index = index + 1;
else if (index == 3) index = 1;
}
bitmap[index] = true;
}
}
out_shape_.clear();
for (unsigned int i = 0; i < input->dim_size(); ++i) {
if (!bitmap[i]) {
out_shape_.push_back(input->dim(i));
} else if (keep_dims_) {
out_shape_.push_back(1);
}
}
data_reshape_.clear();
unsigned int dim_index = 0;
for (; dim_index < input->dim_size(); ++dim_index) {
if (input->dim(dim_index) != 1) break;
}
if (dim_index >= input->dim_size()) {
reduce_first_axis_ = true;
} else {
reduce_first_axis_ = bitmap[dim_index];
data_reshape_.push_back(input->dim(dim_index));
++dim_index;
for (; dim_index < input->dim_size(); ++dim_index) {
const int n = input->dim(dim_index);
if (n == 1) {
bitmap[dim_index] = bitmap[dim_index - 1];
}
if (bitmap[dim_index-1] != bitmap[dim_index]) {
data_reshape_.push_back(n);
} else {
data_reshape_.back() *= n;
}
}
}
}
void compute_reduce_1(const T *input, ReduceType type, T *output) {
if (reduce_first_axis_) {
if (type == ReduceType::MEAN) {
T tmp = 0;
for (int i = 0; i < data_reshape_[0]; ++i) {
tmp = tmp + input[i];
}
output[0] = tmp / data_reshape_[0];
} else if (type == ReduceType::MIN) {
T tmp = input[0];
for (int i = 1; i < data_reshape_[0]; ++i) {
tmp = std::min<T>(tmp, input[i]);
}
output[0] = tmp;
} else if (type == ReduceType::MAX) {
T tmp = input[0];
for (int i = 1; i < data_reshape_[0]; ++i) {
tmp = std::max<T>(tmp, input[i]);
}
output[0] = tmp;
} else if (type == ReduceType::PROD) {
T tmp = input[0];
for (int i = 1; i < data_reshape_[0]; ++i) {
tmp = tmp * input[i];
}
output[0] = tmp;
} else {
MACE_NOT_IMPLEMENTED;
}
} else {
memcpy(output, input, data_reshape_[0] * sizeof(T));
}
}
void compute_reduce_2(const T *input, ReduceType type, T *output) {
if (reduce_first_axis_) {
if (type == ReduceType::MEAN) {
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < data_reshape_[1]; ++i) {
T tmp = 0;
for (int j = 0; j < data_reshape_[0]; ++j) {
tmp += input[j * data_reshape_[1] + i];
}
output[i] = tmp / data_reshape_[0];
}
} else if (type == ReduceType::MIN) {
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < data_reshape_[1]; ++i) {
T tmp = input[i];
for (int j = 1; j < data_reshape_[0]; ++j) {
tmp = std::min(tmp, input[j * data_reshape_[1] + i]);
}
output[i] = tmp;
}
} else if (type == ReduceType::MAX) {
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < data_reshape_[1]; ++i) {
T tmp = input[i];
for (int j = 1; j < data_reshape_[0]; ++j) {
tmp = std::max(tmp, input[j * data_reshape_[1] + i]);
}
output[i] = tmp;
}
} else if (type == ReduceType::PROD) {
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < data_reshape_[1]; ++i) {
T tmp = input[i];
for (int j = 1; j < data_reshape_[0]; ++j) {
tmp = tmp * input[j * data_reshape_[1] + i];
}
output[i] = tmp;
}
} else {
MACE_NOT_IMPLEMENTED;
}
} else {
if (type == ReduceType::MEAN) {
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < data_reshape_[0]; ++i) {
T tmp = 0;
for (int j = 0; j < data_reshape_[1]; ++j) {
tmp += input[i * data_reshape_[1] + j];
}
output[i] = tmp / data_reshape_[1];
}
} else if (type == ReduceType::MIN) {
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < data_reshape_[0]; ++i) {
T tmp = input[i * data_reshape_[1]];
for (int j = 1; j < data_reshape_[1]; ++j) {
tmp = std::min(tmp, input[i * data_reshape_[1] + j]);
}
output[i] = tmp;
}
} else if (type == ReduceType::MAX) {
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < data_reshape_[0]; ++i) {
T tmp = input[i * data_reshape_[1]];
for (int j = 1; j < data_reshape_[1]; ++j) {
tmp = std::max(tmp, input[i * data_reshape_[1] + j]);
}
output[i] = tmp;
}
} else if (type == ReduceType::PROD) {
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < data_reshape_[0]; ++i) {
T tmp = input[i * data_reshape_[1]];
for (int j = 1; j < data_reshape_[1]; ++j) {
tmp = tmp * input[i * data_reshape_[1] + j];
}
output[i] = tmp;
}
} else {
MACE_NOT_IMPLEMENTED;
}
}
}
void compute_reduce_3(const T *input, ReduceType type, T *output) {
if (reduce_first_axis_) {
if (type == ReduceType::MEAN) {
#pragma omp parallel for collapse(1) schedule(runtime)
for (int i = 0; i < data_reshape_[1]; ++i) {
for (int j = 0; j < data_reshape_[2]; ++j) {
for (int k = 0; k < data_reshape_[0]; ++k) {
output[i] +=
input[(k * data_reshape_[1] + i) * data_reshape_[2]
+ j];
}
}
output[i] /= (data_reshape_[0] * data_reshape_[2]);
}
} else if (type == ReduceType::MIN) {
#pragma omp parallel for collapse(1) schedule(runtime)
for (int i = 0; i < data_reshape_[1]; ++i) {
T tmp = input[i * data_reshape_[2]];
for (int j = 0; j < data_reshape_[2]; ++j) {
for (int k = 0; k < data_reshape_[0]; ++k) {
tmp = std::min(tmp,
input[(k * data_reshape_[1] + i) * data_reshape_[2]
+ j]);
}
}
output[i] = tmp;
}
} else if (type == ReduceType::MAX) {
#pragma omp parallel for collapse(1) schedule(runtime)
for (int i = 0; i < data_reshape_[1]; ++i) {
T tmp = input[i * data_reshape_[2]];
for (int j = 0; j < data_reshape_[2]; ++j) {
for (int k = 0; k < data_reshape_[0]; ++k) {
tmp =
std::max(tmp,
input[(k * data_reshape_[1] + i)
* data_reshape_[2] + j]);
}
}
output[i] = tmp;
}
} else if (type == ReduceType::PROD) {
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < data_reshape_[1]; ++i) {
T tmp = 1;
for (int j = 0; j < data_reshape_[2]; ++j) {
for (int k = 0; k < data_reshape_[0]; ++k) {
tmp *=
input[(k * data_reshape_[1] + i) * data_reshape_[2]
+ j];
}
}
output[i] = tmp;
}
} else {
MACE_NOT_IMPLEMENTED;
}
} else {
if (type == ReduceType::MEAN) {
#pragma omp parallel for collapse(2) schedule(runtime)
for (int i = 0; i < data_reshape_[0]; ++i) {
for (int j = 0; j < data_reshape_[2]; ++j) {
for (int k = 0; k < data_reshape_[1]; ++k) {
output[i * data_reshape_[2] + j] +=
input[(i * data_reshape_[1] + k) * data_reshape_[2]
+ j];
}
output[i * data_reshape_[2] + j] /= data_reshape_[1];
}
}
} else if (type == ReduceType::MIN) {
#pragma omp parallel for collapse(2) schedule(runtime)
for (int i = 0; i < data_reshape_[0]; ++i) {
for (int j = 0; j < data_reshape_[2]; ++j) {
T tmp = input[i * data_reshape_[1] * data_reshape_[2] + j];
for (int k = 1; k < data_reshape_[1]; ++k) {
tmp = std::min(tmp,
input[(i * data_reshape_[1] + k) *
data_reshape_[2] + j]);
}
output[i * data_reshape_[2] + j] = tmp;
}
}
} else if (type == ReduceType::MAX) {
#pragma omp parallel for collapse(2) schedule(runtime)
for (int i = 0; i < data_reshape_[0]; ++i) {
for (int j = 0; j < data_reshape_[2]; ++j) {
T tmp = input[i * data_reshape_[1] * data_reshape_[2] + j];
for (int k = 1; k < data_reshape_[1]; ++k) {
tmp = std::max(tmp,
input[(i * data_reshape_[1] + k) *
data_reshape_[2] + j]);
}
output[i * data_reshape_[2] + j] = tmp;
}
}
} else if (type == ReduceType::PROD) {
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < data_reshape_[0]; ++i) {
for (int j = 0; j < data_reshape_[2]; ++j) {
T tmp = input[i * data_reshape_[1] * data_reshape_[2] + j];
for (int k = 1; k < data_reshape_[1]; ++k) {
tmp *= input[(i * data_reshape_[1] + k) *
data_reshape_[2] + j];
}
output[i * data_reshape_[2] + j] = tmp;
}
}
} else {
MACE_NOT_IMPLEMENTED;
}
}
}
void compute_reduce_4(const T *input, ReduceType type, T *output) {
if (reduce_first_axis_) {
if (type == ReduceType::MEAN) {
#pragma omp parallel for collapse(2) schedule(runtime)
for (int i = 0; i < data_reshape_[1]; ++i) {
for (int j = 0; j < data_reshape_[3]; ++j) {
for (int k = 0; k < data_reshape_[2]; ++k) {
for (int t = 0; t < data_reshape_[0]; ++t) {
output[i * data_reshape_[3] + j] +=
input[((t * data_reshape_[1] + i) *
data_reshape_[2] + k)*data_reshape_[3] + j];
}
}
output[i * data_reshape_[3] + j] /=
(data_reshape_[0] * data_reshape_[2]);
}
}
} else if (type == ReduceType::MIN) {
#pragma omp parallel for collapse(2) schedule(runtime)
for (int i = 0; i < data_reshape_[1]; ++i) {
for (int j = 0; j < data_reshape_[3]; ++j) {
T tmp = input[i * data_reshape_[2] * data_reshape_[3] + j];
for (int k = 0; k < data_reshape_[2]; ++k) {
for (int t = 0; t < data_reshape_[0]; ++t) {
tmp = std::min(tmp,
input[((t * data_reshape_[1] + i) *
data_reshape_[2] + k)*data_reshape_[3] + j]);
}
}
output[i * data_reshape_[3] + j] = tmp;
}
}
} else if (type == ReduceType::MAX) {
#pragma omp parallel for collapse(2) schedule(runtime)
for (int i = 0; i < data_reshape_[1]; ++i) {
for (int j = 0; j < data_reshape_[3]; ++j) {
T tmp = input[i * data_reshape_[2] * data_reshape_[3] + j];
for (int k = 0; k < data_reshape_[2]; ++k) {
for (int t = 0; t < data_reshape_[0]; ++t) {
tmp = std::max(tmp,
input[((t * data_reshape_[1] + i) *
data_reshape_[2] + k)*data_reshape_[3] + j]);
}
}
output[i * data_reshape_[3] + j] = tmp;
}
}
} else if (type == ReduceType::PROD) {
#pragma omp parallel for collapse(2) schedule(runtime)
for (int i = 0; i < data_reshape_[1]; ++i) {
for (int j = 0; j < data_reshape_[3]; ++j) {
T tmp = 1;
for (int k = 0; k < data_reshape_[2]; ++k) {
for (int t = 0; t < data_reshape_[0]; ++t) {
tmp = tmp * input[((t * data_reshape_[1] + i) *
data_reshape_[2] + k)*data_reshape_[3] + j];
}
}
output[i * data_reshape_[3] + j] = tmp;
}
}
} else {
MACE_NOT_IMPLEMENTED;
}
} else {
if (type == ReduceType::MEAN) {
#pragma omp parallel for collapse(2) schedule(runtime)
for (int i = 0; i < data_reshape_[0]; ++i) {
for (int j = 0; j < data_reshape_[2]; ++j) {
for (int k = 0; k < data_reshape_[1]; ++k) {
for (int t = 0; t < data_reshape_[3]; ++t) {
output[i * data_reshape_[2] + j] +=
input[((i * data_reshape_[1] + k) *
data_reshape_[2] + j)*data_reshape_[3] + t];
}
}
output[i * data_reshape_[2] + j] /=
(data_reshape_[1] * data_reshape_[3]);
}
}
} else if (type == ReduceType::MIN) {
#pragma omp parallel for collapse(2) schedule(runtime)
for (int i = 0; i < data_reshape_[0]; ++i) {
for (int j = 0; j < data_reshape_[2]; ++j) {
T tmp = input[(i * data_reshape_[1] *
data_reshape_[2] + j)*data_reshape_[3]];
for (int k = 0; k < data_reshape_[1]; ++k) {
for (int t = 0; t < data_reshape_[3]; ++t) {
tmp =
std::min(tmp,
input[((i * data_reshape_[1] + k) *
data_reshape_[2] + j)*data_reshape_[3] + t]);
}
}
output[i * data_reshape_[2] + j] = tmp;
}
}
} else if (type == ReduceType::MAX) {
#pragma omp parallel for collapse(2) schedule(runtime)
for (int i = 0; i < data_reshape_[0]; ++i) {
for (int j = 0; j < data_reshape_[2]; ++j) {
T tmp = input[(i * data_reshape_[1] *
data_reshape_[2] + j)*data_reshape_[3]];
for (int k = 0; k < data_reshape_[1]; ++k) {
for (int t = 0; t < data_reshape_[3]; ++t) {
tmp =
std::max(tmp,
input[((i * data_reshape_[1] + k) *
data_reshape_[2] + j)*data_reshape_[3] + t]);
}
}
output[i * data_reshape_[2] + j] = tmp;
}
}
} else if (type == ReduceType::PROD) {
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < data_reshape_[0]; ++i) {
for (int j = 0; j < data_reshape_[2]; ++j) {
T tmp = 1;
for (int k = 0; k < data_reshape_[1]; ++k) {
for (int t = 0; t < data_reshape_[3]; ++t) {
tmp = tmp * input[((i * data_reshape_[1] + k) *
data_reshape_[2] + j)*data_reshape_[3] + t];
}
}
output[i * data_reshape_[2] + j] = tmp;
}
}
} else {
MACE_NOT_IMPLEMENTED;
}
}
}
void Compute(const Tensor *input, Tensor *output) {
Tensor::MappingGuard input_mapper(input);
const T *input_ptr = input->data<T>();
Tensor::MappingGuard output_map(output);
T *output_ptr = output->mutable_data<T>();
memset(output_ptr, 0, output->size() * sizeof(T));
switch (data_reshape_.size()) {
case 1:
compute_reduce_1(input_ptr, reduce_type_, output_ptr);
break;
case 2:
compute_reduce_2(input_ptr, reduce_type_, output_ptr);
break;
case 3:
compute_reduce_3(input_ptr, reduce_type_, output_ptr);
break;
case 4:
compute_reduce_4(input_ptr, reduce_type_, output_ptr);
break;
default:
MACE_CHECK(false, "not implemented in mace")
<< "data reshape size" << data_reshape_.size()
<< "reduce first axis:" << reduce_first_axis_;
break;
}
}
private:
bool reduce_first_axis_;
std::vector<int> data_reshape_;
std::vector<index_t> out_shape_;
};
#ifdef MACE_ENABLE_OPENCL
template <typename T>
class ReduceOp<DeviceType::GPU, T> : public ReduceOpBase {
public:
explicit ReduceOp(OpConstructContext *context)
: ReduceOpBase(context) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::ReduceKernel<T>(reduce_type_,
axis_,
keep_dims_));
} else {
MACE_NOT_IMPLEMENTED;
}
}
MaceStatus Run(OpContext *context) override {
Validate();
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
return kernel_->Compute(context, input, output);
}
private:
std::unique_ptr<OpenCLReduceKernel> kernel_;
};
#endif // MACE_ENABLE_OPENCL
void RegisterReduce(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp,
DeviceType::CPU, float);
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp,
DeviceType::GPU, float);
MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp,
DeviceType::GPU, half);
#endif // MACE_ENABLE_OPENCL
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_REDUCE_H_
#define MACE_OPS_REDUCE_H_
namespace mace {
enum ReduceType {
// SUM = 0,
MEAN = 0,
MIN = 1,
MAX = 2,
PROD = 3,
// SUM_SQR = 4,
// SQR_MEAN = 5,
};
} // namespace mace
#endif // MACE_OPS_REDUCE_H_
......@@ -21,7 +21,7 @@ namespace test {
namespace {
template <DeviceType D, typename T>
void ReduceMean(int iters, int batch, int channels,
void Reduce(int iters, int batch, int channels,
int height, int width) {
mace::testing::StopTiming();
......@@ -34,7 +34,7 @@ void ReduceMean(int iters, int batch, int channels,
net.AddRandomInput<D, T>("Input", {batch, channels, height, width});
}
OpDefBuilder("ReduceMean", "ReduceMeanBM")
OpDefBuilder("Reduce", "ReduceBM")
.Input("Input")
.AddIntsArg("axis", axis)
.Output("OutputImage")
......@@ -55,30 +55,30 @@ void ReduceMean(int iters, int batch, int channels,
}
} // namespace
#define MACE_BM_REDUCE_MEAN_MACRO(N, C, H, W, TYPE, DEVICE) \
#define MACE_BM_REDUCE_MACRO(N, C, H, W, TYPE, DEVICE) \
static void \
MACE_BM_REDUCE_MEAN_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE(\
MACE_BM_REDUCE_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE(\
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
ReduceMean<DEVICE, TYPE>(iters, N, C, H, W); \
Reduce<DEVICE, TYPE>(iters, N, C, H, W); \
} \
MACE_BENCHMARK( \
MACE_BM_REDUCE_MEAN_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
MACE_BM_REDUCE_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define MACE_BM_REDUCE_MEAN(N, C, H, W) \
MACE_BM_REDUCE_MEAN_MACRO(N, C, H, W, float, GPU); \
MACE_BM_REDUCE_MEAN_MACRO(N, C, H, W, half, GPU); \
MACE_BM_REDUCE_MEAN_MACRO(N, C, H, W, float, CPU);
#define MACE_BM_REDUCE(N, C, H, W) \
MACE_BM_REDUCE_MACRO(N, C, H, W, float, GPU); \
MACE_BM_REDUCE_MACRO(N, C, H, W, half, GPU); \
MACE_BM_REDUCE_MACRO(N, C, H, W, float, CPU);
MACE_BM_REDUCE_MEAN(1, 1, 512, 512);
MACE_BM_REDUCE_MEAN(4, 3, 128, 128);
MACE_BM_REDUCE_MEAN(4, 1, 512, 512);
MACE_BM_REDUCE_MEAN(16, 32, 112, 112);
MACE_BM_REDUCE_MEAN(8, 64, 256, 256);
MACE_BM_REDUCE_MEAN(1, 32, 480, 640);
MACE_BM_REDUCE(1, 1, 512, 512);
MACE_BM_REDUCE(4, 3, 128, 128);
MACE_BM_REDUCE(4, 1, 512, 512);
MACE_BM_REDUCE(16, 32, 112, 112);
MACE_BM_REDUCE(8, 64, 256, 256);
MACE_BM_REDUCE(1, 32, 480, 640);
} // namespace test
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <memory>
#include <vector>
#include "mace/core/future.h"
#include "mace/core/operator.h"
#include "mace/core/tensor.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/reduce_mean.h"
#endif // MACE_ENABLE_OPENCL
namespace mace {
namespace ops {
class ReduceMeanOpBase : public Operation {
public:
explicit ReduceMeanOpBase(OpConstructContext *context)
: Operation(context),
axis_(Operation::GetRepeatedArgs<int>("axis")),
keep_dims_(Operation::GetOptionalArg<bool>("keepdims", false)) {
}
protected:
inline void Validate() {
const Tensor *input = this->Input(0);
const int left = static_cast<int>(input->dim_size() * -1);
const int right = static_cast<int>(input->dim_size());
if (axis_.size()) {
for (unsigned int i = 0; i < axis_.size(); ++i) {
MACE_CHECK(axis_[i] > left && axis_[i] < right, "Axis is over range.");
}
}
}
protected:
std::vector<int> axis_;
bool keep_dims_;
};
template <DeviceType D, class T>
class ReduceMeanOp;
template <typename T>
class ReduceMeanOp<DeviceType::CPU, T> : public ReduceMeanOpBase {
public:
explicit ReduceMeanOp(OpConstructContext *context)
: ReduceMeanOpBase(context) {
}
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
Validate();
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
Simplify(input);
output->Resize(out_shape_);
Compute(input, output);
return MaceStatus::MACE_SUCCESS;
}
private:
void Simplify(const Tensor *input) {
std::vector<bool> bitmap(static_cast<uint32_t>(input->dim_size()), false);
if (axis_.size() == 0) {
for (int i = 0; i < input->dim_size(); ++i) {
bitmap[i] = true;
}
} else {
for (unsigned int i = 0; i < axis_.size(); ++i) {
int index = axis_[i] >= 0 ?
axis_[i] :
axis_[i] + input->dim_size();
// axis format is NHWC
if (input->dim_size() == 4) {
if (index == 1) index = 2;
else if (index == 2) index = 3;
else if (index == 3) index = 1;
}
bitmap[index] = true;
}
}
out_shape_.clear();
for (unsigned int i = 0; i < input->dim_size(); ++i) {
if (!bitmap[i]) {
out_shape_.push_back(input->dim(i));
} else if (keep_dims_) {
out_shape_.push_back(1);
}
}
data_reshape_.clear();
unsigned int dim_index = 0;
for (; dim_index < input->dim_size(); ++dim_index) {
if (input->dim(dim_index) != 1) break;
}
if (dim_index >= input->dim_size()) {
reduce_first_axis_ = true;
} else {
reduce_first_axis_ = bitmap[dim_index];
data_reshape_.push_back(input->dim(dim_index));
++dim_index;
for (; dim_index < input->dim_size(); ++dim_index) {
const int n = input->dim(dim_index);
if (n == 1) {
bitmap[dim_index] = bitmap[dim_index - 1];
}
if (bitmap[dim_index-1] != bitmap[dim_index]) {
data_reshape_.push_back(n);
} else {
data_reshape_.back() *= n;
}
}
}
}
void Compute(const Tensor *input, Tensor *output) {
Tensor::MappingGuard input_mapper(input);
const T *input_ptr = input->data<T>();
Tensor::MappingGuard output_map(output);
T *output_ptr = output->mutable_data<T>();
memset(output_ptr, 0, output->size() * sizeof(T));
switch (data_reshape_.size()) {
case 1:
if (reduce_first_axis_) {
T sum = 0;
for (int i = 0; i < data_reshape_[0]; ++i) {
sum = sum + input_ptr[i];
}
output_ptr[0] = sum / data_reshape_[0];
} else {
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < data_reshape_[0]; ++i) {
output_ptr[i] = input_ptr[i];
}
}
break;
case 2:
if (reduce_first_axis_) {
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < data_reshape_[1]; ++i) {
for (int j = 0; j < data_reshape_[0]; ++j) {
output_ptr[i] += input_ptr[j * data_reshape_[1] + i];
}
output_ptr[i] /= data_reshape_[0];
}
} else {
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < data_reshape_[0]; ++i) {
for (int j = 0; j < data_reshape_[1]; ++j) {
output_ptr[i] += input_ptr[i * data_reshape_[1] + j];
}
output_ptr[i] /= data_reshape_[1];
}
}
break;
case 3:
if (reduce_first_axis_) {
#pragma omp parallel for schedule(runtime)
for (int i = 0; i < data_reshape_[1]; ++i) {
for (int j = 0; j < data_reshape_[2]; ++j) {
for (int k = 0; k < data_reshape_[0]; ++k) {
output_ptr[i] +=
input_ptr[(k * data_reshape_[1] + i) * data_reshape_[2]
+ j];
}
}
output_ptr[i] /= (data_reshape_[0] * data_reshape_[2]);
}
} else {
#pragma omp parallel for collapse(2) schedule(runtime)
for (int i = 0; i < data_reshape_[0]; ++i) {
for (int j = 0; j < data_reshape_[2]; ++j) {
for (int k = 0; k < data_reshape_[1]; ++k) {
output_ptr[i * data_reshape_[2] + j] +=
input_ptr[(i * data_reshape_[1] + k) * data_reshape_[2]
+ j];
}
output_ptr[i * data_reshape_[2] + j] /= data_reshape_[1];
}
}
}
break;
case 4:
if (reduce_first_axis_) {
#pragma omp parallel for collapse(2) schedule(runtime)
for (int i = 0; i < data_reshape_[1]; ++i) {
for (int j = 0; j < data_reshape_[3]; ++j) {
for (int k = 0; k < data_reshape_[2]; ++k) {
for (int t = 0; t < data_reshape_[0]; ++t) {
output_ptr[i * data_reshape_[3] + j] +=
input_ptr[((t * data_reshape_[1] + i) *
data_reshape_[2] + k)*data_reshape_[3] + j];
}
}
output_ptr[i * data_reshape_[3] + j] /=
(data_reshape_[0] * data_reshape_[2]);
}
}
} else {
#pragma omp parallel for collapse(2) schedule(runtime)
for (int i = 0; i < data_reshape_[0]; ++i) {
for (int j = 0; j < data_reshape_[2]; ++j) {
for (int k = 0; k < data_reshape_[1]; ++k) {
for (int t = 0; t < data_reshape_[3]; ++t) {
output_ptr[i * data_reshape_[2] + j] +=
input_ptr[((i * data_reshape_[1] + k) *
data_reshape_[2] + j)*data_reshape_[3] + t];
}
}
output_ptr[i * data_reshape_[2] + j] /=
(data_reshape_[1] * data_reshape_[3]);
}
}
}
break;
default:
MACE_CHECK(false, "not implemented in mace")
<< "data reshape size" << data_reshape_.size()
<< "reduce first axis:" << reduce_first_axis_;
break;
}
}
private:
bool reduce_first_axis_;
std::vector<int> data_reshape_;
std::vector<index_t> out_shape_;
};
#ifdef MACE_ENABLE_OPENCL
template <typename T>
class ReduceMeanOp<DeviceType::GPU, T> : public ReduceMeanOpBase {
public:
explicit ReduceMeanOp(OpConstructContext *context)
: ReduceMeanOpBase(context) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
kernel_.reset(new opencl::image::ReduceMeanKernel<T>(axis_, keep_dims_));
} else {
MACE_NOT_IMPLEMENTED;
}
}
MaceStatus Run(OpContext *context) override {
Validate();
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
return kernel_->Compute(context, input, output);
}
private:
std::unique_ptr<OpenCLReduceMeanKernel> kernel_;
};
#endif // MACE_ENABLE_OPENCL
void RegisterReduceMean(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "ReduceMean", ReduceMeanOp,
DeviceType::CPU, float);
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OP(op_registry, "ReduceMean", ReduceMeanOp,
DeviceType::GPU, float);
MACE_REGISTER_OP(op_registry, "ReduceMean", ReduceMeanOp,
DeviceType::GPU, half);
#endif // MACE_ENABLE_OPENCL
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class ReduceMeanOpTest : public OpsTestBase {};
namespace {
template <DeviceType D>
void Simple(const std::vector<index_t> &input_shape,
const std::vector<float> &input,
const std::vector<int> &axis,
const std::vector<index_t> &output_shape,
const std::vector<float> &output,
const bool keepdims = true) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input", input_shape, input);
if (D == DeviceType::CPU) {
net.TransformDataFormat<D, float>("Input", NHWC, "InputNCHW", NCHW);
OpDefBuilder("ReduceMean", "ReduceMeanTest")
.Input("InputNCHW")
.AddIntsArg("axis", axis)
.AddIntArg("keepdims", keepdims ? 1 : 0)
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<D, float>("OutputNCHW", NCHW, "Output", NHWC);
} else {
OpDefBuilder("ReduceMean", "ReduceMeanTest")
.Input("Input")
.AddIntsArg("axis", axis)
.AddIntArg("keepdims", keepdims ? 1 : 0)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
}
auto expected = net.CreateTensor<float>(output_shape, output);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5, 1e-3);
}
template <DeviceType D>
void Simple3D(const std::vector<index_t> &input_shape,
const std::vector<float> &input,
const std::vector<int> &axis,
const std::vector<index_t> &output_shape,
const std::vector<float> &output,
const bool keepdims = true) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input", input_shape, input);
OpDefBuilder("ReduceMean", "ReduceMeanTest")
.Input("Input")
.AddIntsArg("axis", axis)
.AddIntArg("keepdims", keepdims ? 1 : 0)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
auto expected = net.CreateTensor<float>(output_shape, output);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5, 1e-3);
}
template <DeviceType D>
void Simple12Test() {
Simple<D>({2, 2, 3, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
{1, 2},
{2, 1, 1, 4},
{10, 11, 12, 13,
10, 11, 12, 13});
}
template <DeviceType D>
void Simple1Axis() {
Simple<D>({2, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23,
0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{1},
{2, 1, 3, 4},
{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{-3},
{1, 1, 3, 4},
{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{2},
{1, 2, 1, 4},
{4, 5, 6, 7, 16, 17, 18, 19});
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{-1},
{1, 2, 3, 1},
{1.5, 5.5, 9.5, 13.5, 17.5, 21.5});
Simple<D>({1, 3, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26},
{1},
{1, 1, 3, 3},
{9, 10, 11, 12, 13, 14, 15, 16, 17});
Simple<D>({1, 3, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26},
{-2},
{1, 3, 1, 3},
{3, 4, 5, 12, 13, 14, 21, 22, 23});
Simple<D>({1, 3, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26},
{3},
{1, 3, 3, 1},
{1, 4, 7, 10, 13, 16, 19, 22, 25});
}
template <DeviceType D>
void Simple2Axis() {
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{0, 1},
{1, 1, 3, 4},
{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{0, 2},
{1, 2, 1, 4},
{4, 5, 6, 7, 16, 17, 18, 19});
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{1, 3},
{1, 1, 3, 1},
{7.5, 11.5, 15.5});
Simple<D>({1, 3, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26},
{1, 2},
{1, 1, 1, 3},
{12, 13, 14});
Simple<D>({1, 3, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26},
{0, 1},
{1, 1, 3, 3},
{9, 10, 11, 12, 13, 14, 15, 16, 17});
Simple<D>({1, 3, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26},
{2, 3},
{1, 3, 1, 1},
{4, 13, 22});
}
template <DeviceType D>
void Simple2Axis3D() {
Simple3D<D>({2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{0, 1},
{1, 1, 4},
{10, 11, 12, 13});
Simple3D<D>({2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{1, 2},
{2, 1, 1},
{5.5, 17.5});
}
template <DeviceType D>
void Simple3Axis() {
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{1, 2, 3},
{1, 1, 1, 1},
{11.5});
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{0, 2, 3},
{1, 2, 1, 1},
{5.5, 17.5});
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{0, 1, 3},
{1, 1, 3, 1},
{7.5, 11.5, 15.5});
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{0, 1, 2},
{1, 1, 1, 4},
{10, 11, 12, 13});
Simple<D>({1, 3, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26},
{1, 2, 3},
{1, 1, 1, 1},
{13});
Simple<D>({1, 3, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26},
{0, 2, 3},
{1, 3, 1, 1},
{4, 13, 22});
Simple<D>({1, 3, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26},
{0, 1, 3},
{1, 1, 3, 1},
{10, 13, 16});
Simple<D>({1, 3, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8,
9, 10, 11, 12, 13, 14, 15, 16, 17,
18, 19, 20, 21, 22, 23, 24, 25, 26},
{0, 1, 2},
{1, 1, 1, 3},
{12, 13, 14});
}
} // namespace
TEST_F(ReduceMeanOpTest, CPUSimple12) {
Simple12Test<DeviceType::CPU>();
}
TEST_F(ReduceMeanOpTest, GPUSimple12) {
Simple12Test<DeviceType::GPU>();
}
TEST_F(ReduceMeanOpTest, CPUSimple1Axis) {
Simple1Axis<DeviceType::CPU>();
}
TEST_F(ReduceMeanOpTest, CPUSimple2Axis) {
Simple2Axis<DeviceType::CPU>();
}
TEST_F(ReduceMeanOpTest, CPUSimple2Axis3D) {
Simple2Axis3D<DeviceType::CPU>();
}
TEST_F(ReduceMeanOpTest, CPUSimple3Axis) {
Simple3Axis<DeviceType::CPU>();
}
TEST_F(ReduceMeanOpTest, CPUSimpleReduceDims) {
Simple3D<CPU>({2, 3, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
{0, 1},
{4},
{10, 11, 12, 13},
false);
}
namespace {
template <DeviceType D, typename T>
void RandomTest(const std::vector<index_t> &input_shape,
const std::vector<int> &axis) {
testing::internal::LogToStderr();
srand(time(NULL));
// Construct graph
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", input_shape);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
NCHW);
OpDefBuilder("ReduceMean", "ReduceMeanTest")
.Input("InputNCHW")
.AddIntsArg("axis", axis)
.AddIntArg("keepdims", 1)
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW,
"Output", NHWC);
OpDefBuilder("ReduceMean", "ReduceMeanTest")
.Input("Input")
.AddIntsArg("axis", axis)
.AddIntArg("keepdims", 1)
.Output("OPENCLOutput")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
if (DataTypeToEnum<T>::value == DT_FLOAT) {
ExpectTensorNear<float>(*net.GetTensor("Output"),
*net.GetOutput("OPENCLOutput"), 1e-5, 1e-4);
} else {
ExpectTensorNear<float>(*net.GetTensor("Output"),
*net.GetOutput("OPENCLOutput"), 1e-2, 1e-2);
}
}
} // namespace
TEST_F(ReduceMeanOpTest, GPURandomFloat) {
RandomTest<DeviceType::GPU, float>({4, 64, 64, 3}, {1, 2});
RandomTest<DeviceType::GPU, float>({2, 64, 64, 4}, {1, 2});
RandomTest<DeviceType::GPU, float>({8, 128, 128, 64}, {1, 2});
RandomTest<DeviceType::GPU, float>({1, 640, 480, 64}, {1, 2});
RandomTest<DeviceType::GPU, float>({1, 480, 640, 32}, {1, 2});
RandomTest<DeviceType::GPU, float>({1, 512, 512, 16}, {1, 2});
RandomTest<DeviceType::GPU, float>({8, 117, 87, 33}, {1, 2});
RandomTest<DeviceType::GPU, float>({1, 619, 450, 61}, {1, 2});
RandomTest<DeviceType::GPU, float>({1, 511, 561, 11}, {1, 2});
}
TEST_F(ReduceMeanOpTest, GPURandomHalf) {
RandomTest<DeviceType::GPU, half>({4, 64, 64, 3}, {1, 2});
RandomTest<DeviceType::GPU, half>({2, 64, 64, 4}, {1, 2});
RandomTest<DeviceType::GPU, half>({8, 128, 128, 64}, {1, 2});
RandomTest<DeviceType::GPU, half>({1, 640, 480, 64}, {1, 2});
RandomTest<DeviceType::GPU, half>({1, 480, 640, 32}, {1, 2});
RandomTest<DeviceType::GPU, half>({1, 512, 512, 16}, {1, 2});
RandomTest<DeviceType::GPU, half>({8, 117, 87, 33}, {1, 2});
RandomTest<DeviceType::GPU, half>({1, 619, 450, 61}, {1, 2});
RandomTest<DeviceType::GPU, half>({1, 511, 561, 11}, {1, 2});
}
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "mace/ops/reduce.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class ReduceOpTest : public OpsTestBase {};
namespace {
template <DeviceType D>
void Simple(const std::vector<index_t> &input_shape,
const std::vector<float> &input,
const std::vector<int> &axis,
const std::vector<index_t> &output_shape,
const std::vector<float> &output,
ReduceType type,
const bool keepdims = true) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input", input_shape, input);
if (D == DeviceType::CPU) {
net.TransformDataFormat<D, float>("Input", NHWC, "InputNCHW", NCHW);
OpDefBuilder("Reduce", "ReduceTest")
.Input("InputNCHW")
.AddIntsArg("axis", axis)
.AddIntArg("keepdims", keepdims ? 1 : 0)
.AddIntArg("reduce_type", type)
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<D, float>("OutputNCHW", NCHW, "Output", NHWC);
} else {
OpDefBuilder("Reduce", "ReduceTest")
.Input("Input")
.AddIntsArg("axis", axis)
.AddIntArg("keepdims", keepdims ? 1 : 0)
.AddIntArg("reduce_type", type)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
}
auto expected = net.CreateTensor<float>(output_shape, output);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5, 1e-3);
}
template <DeviceType D>
void Simple3D(const std::vector<index_t> &input_shape,
const std::vector<float> &input,
const std::vector<int> &axis,
const std::vector<index_t> &output_shape,
const std::vector<float> &output,
ReduceType type,
const bool keepdims = true) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input", input_shape, input);
OpDefBuilder("Reduce", "ReduceTest")
.Input("Input")
.AddIntsArg("axis", axis)
.AddIntArg("keepdims", keepdims ? 1 : 0)
.AddIntArg("reduce_type", type)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
auto expected = net.CreateTensor<float>(output_shape, output);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5, 1e-3);
}
template <DeviceType D>
void SimpleMean12Test() {
Simple<D>({2, 2, 3, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
{1, 2},
{2, 1, 1, 4},
{10, 11, 12, 13,
10, 11, 12, 13}, ReduceType::MEAN);
}
// template <DeviceType D>
// void SimpleSum12Test() {
// Simple<D>({2, 2, 3, 4},
// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
// {1, 2},
// {2, 1, 1, 4},
// {60, 66, 72, 78,
// 60, 66, 72, 78}, ReduceType::SUM);
//}
template <DeviceType D>
void SimpleMin12Test() {
Simple<D>({2, 2, 3, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
{1, 2},
{2, 1, 1, 4},
{0, 1, 2, 3,
0, 1, 2, 3}, ReduceType::MIN);
}
template <DeviceType D>
void SimpleMax12Test() {
Simple<D>({2, 2, 3, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
{1, 2},
{2, 1, 1, 4},
{20, 21, 22, 23,
20, 21, 22, 23}, ReduceType::MAX);
}
// template <DeviceType D>
// void SimpleSumSqr12Test() {
// Simple<D>({2, 2, 3, 4},
// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
// {1, 2},
// {2, 1, 1, 4},
// {880, 1006, 1144, 1294,
// 880, 1006, 1144, 1294}, ReduceType::SUM_SQR);
//}
template <DeviceType D>
void SimpleMean1Axis() {
Simple<D>({2, 2, 3, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
{1},
{2, 1, 3, 4},
{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,
6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
// {-3},
// {1, 1, 3, 4},
// {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
// {2},
// {1, 2, 1, 4},
// {4, 5, 6, 7, 16, 17, 18, 19}, ReduceType::MEAN);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
// {-1},
// {1, 2, 3, 1},
// {1.5, 5.5, 9.5, 13.5, 17.5, 21.5}, ReduceType::MEAN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {1},
// {1, 1, 3, 3},
// {9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {-2},
// {1, 3, 1, 3},
// {3, 4, 5, 12, 13, 14, 21, 22, 23}, ReduceType::MEAN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {3},
// {1, 3, 3, 1},
// {1, 4, 7, 10, 13, 16, 19, 22, 25}, ReduceType::MEAN);
}
// template <DeviceType D>
// void SimpleSum1Axis() {
// Simple<D>({2, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23,
// 0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {1},
// {2, 1, 3, 4},
// {12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34,
// 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34});
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {2},
// {1, 2, 1, 4},
// {12, 15, 18, 21, 48, 51, 54, 57}, ReduceType::SUM);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {-1},
// {1, 2, 3, 1},
// {6, 22, 38, 54, 70, 86}, ReduceType::SUM);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {1},
// {1, 1, 3, 3},
// {27, 30, 33, 36, 39, 42, 45, 48, 51}, ReduceType::SUM);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {3},
// {1, 3, 3, 1},
// {3, 12, 21, 30, 39, 48, 57, 66, 75}, ReduceType::SUM);
//}
template <DeviceType D>
void SimpleMin1Axis() {
Simple<D>({2, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23,
0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{1},
{2, 1, 3, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, ReduceType::MIN);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {2},
// {1, 2, 1, 4},
// {0, 1, 2, 3, 12, 13, 14, 15}, ReduceType::MIN);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {-1},
// {1, 2, 3, 1},
// {0, 4, 8, 12, 16, 20}, ReduceType::MIN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {1},
// {1, 1, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8}, ReduceType::MIN);
}
template <DeviceType D>
void SimpleMax1Axis() {
Simple<D>({2, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23,
0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{1},
{2, 1, 3, 4},
{12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23}, ReduceType::MAX);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {2},
// {1, 2, 1, 4},
// {8, 9, 10, 11, 20, 21, 22, 23}, ReduceType::MAX);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {-1},
// {1, 2, 3, 1},
// {3, 7, 11, 15, 19, 23}, ReduceType::MAX);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {1},
// {1, 1, 3, 3},
// {18, 19, 20, 21, 22, 23, 24, 25, 26}, ReduceType::MAX);
}
// template <DeviceType D>
// void SimpleSumSqr1Axis() {
// Simple<D>({2, 2, 3, 4},
// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
// {1},
// {2, 1, 3, 4},
// {144, 170, 200, 234,
// 272, 314, 360, 410,
// 464, 522, 584, 650,
// 144, 170, 200, 234,
// 272, 314, 360, 410,
// 464, 522, 584, 650}, ReduceType::SUM_SQR);
//}
template <DeviceType D>
void Simple2Axis() {
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{0, 1},
{1, 1, 3, 4},
{6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN);
// Simple3D<D>({2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {0, 1},
// {1, 1, 4},
// {10, 11, 12, 13}, ReduceType::MEAN);
Simple3D<D>({2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{1, 2},
{2, 1, 1},
{5.5, 17.5}, ReduceType::MEAN);
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{0, 2},
{1, 2, 1, 4},
{4, 5, 6, 7, 16, 17, 18, 19}, ReduceType::MEAN);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {1, 3},
// {1, 1, 3, 1},
// {7.5, 11.5, 15.5}, ReduceType::MEAN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {1, 2},
// {1, 1, 1, 3},
// {12, 13, 14}, ReduceType::MEAN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {0, 1},
// {1, 1, 3, 3},
// {9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {2, 3},
// {1, 3, 1, 1},
// {4, 13, 22}, ReduceType::MEAN);
}
template <DeviceType D>
void Simple3Axis() {
Simple<D>({1, 2, 3, 4},
{0, 1, 2, 3,
4, 5, 6, 7,
8, 9, 10, 11,
12, 13, 14, 15,
16, 17, 18, 19,
20, 21, 22, 23},
{1, 2, 3},
{1, 1, 1, 1},
{11.5}, ReduceType::MEAN);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {0, 2, 3},
// {1, 2, 1, 1},
// {5.5, 17.5}, ReduceType::MEAN);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {0, 1, 3},
// {1, 1, 3, 1},
// {7.5, 11.5, 15.5}, ReduceType::MEAN);
// Simple<D>({1, 2, 3, 4},
// {0, 1, 2, 3,
// 4, 5, 6, 7,
// 8, 9, 10, 11,
// 12, 13, 14, 15,
// 16, 17, 18, 19,
// 20, 21, 22, 23},
// {0, 1, 2},
// {1, 1, 1, 4},
// {10, 11, 12, 13}, ReduceType::MEAN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {1, 2, 3},
// {1, 1, 1, 1},
// {13}, ReduceType::MEAN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {0, 2, 3},
// {1, 3, 1, 1},
// {4, 13, 22}, ReduceType::MEAN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {0, 1, 3},
// {1, 1, 3, 1},
// {10, 13, 16}, ReduceType::MEAN);
// Simple<D>({1, 3, 3, 3},
// {0, 1, 2, 3, 4, 5, 6, 7, 8,
// 9, 10, 11, 12, 13, 14, 15, 16, 17,
// 18, 19, 20, 21, 22, 23, 24, 25, 26},
// {0, 1, 2},
// {1, 1, 1, 3},
// {12, 13, 14}, ReduceType::MEAN);
}
} // namespace
TEST_F(ReduceOpTest, CPUSimple12) {
SimpleMean12Test<DeviceType::CPU>();
SimpleMin12Test<DeviceType::CPU>();
SimpleMax12Test<DeviceType::CPU>();
}
TEST_F(ReduceOpTest, GPUSimple12) {
SimpleMean12Test<DeviceType::GPU>();
SimpleMin12Test<DeviceType::GPU>();
SimpleMax12Test<DeviceType::GPU>();
}
TEST_F(ReduceOpTest, CPUSimple1Axis) {
SimpleMean1Axis<DeviceType::CPU>();
SimpleMin1Axis<DeviceType::CPU>();
SimpleMax1Axis<DeviceType::CPU>();
}
TEST_F(ReduceOpTest, CPUSimple2Axis) {
Simple2Axis<DeviceType::CPU>();
}
TEST_F(ReduceOpTest, CPUSimple3Axis) {
Simple3Axis<DeviceType::CPU>();
}
TEST_F(ReduceOpTest, CPUSimpleReduceDims) {
Simple3D<CPU>({2, 3, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
{0, 1},
{4},
{10, 11, 12, 13}, ReduceType::MEAN,
false);
}
namespace {
template <DeviceType D, typename T>
void RandomTest(const std::vector<index_t> &input_shape,
const std::vector<int> &axis) {
testing::internal::LogToStderr();
srand(time(NULL));
auto func = [&](ReduceType type) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", input_shape);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW",
NCHW);
OpDefBuilder("Reduce", "ReduceTest")
.Input("InputNCHW")
.AddIntsArg("axis", axis)
.AddIntArg("keepdims", 1)
.AddIntArg("reduce_type", type)
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW,
"Output", NHWC);
OpDefBuilder("Reduce", "ReduceTest")
.Input("Input")
.AddIntsArg("axis", axis)
.AddIntArg("keepdims", 1)
.AddIntArg("reduce_type", type)
.Output("OPENCLOutput")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
if (DataTypeToEnum<T>::value == DT_FLOAT) {
ExpectTensorNear<float>(*net.GetTensor("Output"),
*net.GetOutput("OPENCLOutput"), 1e-5, 1e-4);
} else {
ExpectTensorNear<float>(*net.GetTensor("Output"),
*net.GetOutput("OPENCLOutput"), 1e-2, 1e-2);
}
};
for (ReduceType type : {MEAN, MIN, MAX, PROD}) {
func(type);
}
}
} // namespace
TEST_F(ReduceOpTest, GPURandomFloat) {
RandomTest<DeviceType::GPU, float>({4, 64, 64, 3}, {1, 2});
// RandomTest<DeviceType::GPU, float>({2, 64, 64, 4}, {1, 2});
RandomTest<DeviceType::GPU, float>({8, 128, 128, 64}, {1, 2});
// RandomTest<DeviceType::GPU, float>({1, 640, 480, 64}, {1, 2});
RandomTest<DeviceType::GPU, float>({1, 480, 640, 32}, {1, 2});
// RandomTest<DeviceType::GPU, float>({1, 512, 512, 16}, {1, 2});
RandomTest<DeviceType::GPU, float>({8, 117, 87, 33}, {1, 2});
// RandomTest<DeviceType::GPU, float>({1, 619, 450, 61}, {1, 2});
RandomTest<DeviceType::GPU, float>({1, 511, 561, 11}, {1, 2});
}
TEST_F(ReduceOpTest, GPURandomHalf) {
RandomTest<DeviceType::GPU, half>({4, 64, 64, 3}, {1, 2});
// RandomTest<DeviceType::GPU, half>({2, 64, 64, 4}, {1, 2});
RandomTest<DeviceType::GPU, half>({8, 128, 128, 64}, {1, 2});
// RandomTest<DeviceType::GPU, half>({1, 640, 480, 64}, {1, 2});
RandomTest<DeviceType::GPU, half>({1, 480, 640, 32}, {1, 2});
// RandomTest<DeviceType::GPU, half>({1, 512, 512, 16}, {1, 2});
RandomTest<DeviceType::GPU, half>({8, 117, 87, 33}, {1, 2});
// RandomTest<DeviceType::GPU, half>({1, 619, 450, 61}, {1, 2});
RandomTest<DeviceType::GPU, half>({1, 511, 561, 11}, {1, 2});
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -36,6 +36,7 @@ class ReshapeOp : public Operation {
int unknown_idx = -1;
index_t product = 1;
std::vector<index_t> out_shape;
index_t n = 0;
for (int i = 0; i < num_dims; ++i) {
if (shape_data[i] == -1) {
......@@ -45,8 +46,15 @@ class ReshapeOp : public Operation {
} else {
MACE_CHECK(shape_data[i] >= 0, "Shape must be non-negative: ",
shape_data[i]);
out_shape.push_back(shape_data[i]);
product *= shape_data[i];
if (shape_data[i] == 0) {
MACE_CHECK(i < input->dim_size(),
"dims:0 out of input dims' range.");
n = input->dim(i);
} else {
n = shape_data[i];
}
out_shape.push_back(n);
product *= n;
}
}
......
......@@ -13,6 +13,7 @@ py_library(
"converter_tool/base_converter.py",
"converter_tool/caffe_converter.py",
"converter_tool/hexagon_converter.py",
"converter_tool/onnx_converter.py",
"converter_tool/shape_inference.py",
"converter_tool/tensorflow_converter.py",
"converter_tool/tf_dsp_converter.py",
......
......@@ -101,7 +101,7 @@ def main(unused_args):
file=sys.stderr)
sys.exit(-1)
if FLAGS.platform not in ['tensorflow', 'caffe']:
if FLAGS.platform not in ['tensorflow', 'caffe', 'onnx']:
six.print_("platform %s is not supported." % FLAGS.platform,
file=sys.stderr)
sys.exit(-1)
......@@ -188,6 +188,9 @@ def main(unused_args):
converter = caffe_converter.CaffeConverter(option,
FLAGS.model_file,
FLAGS.weight_file)
elif FLAGS.platform == 'onnx':
from mace.python.tools.converter_tool import onnx_converter
converter = onnx_converter.OnnxConverter(option, FLAGS.model_file)
else:
six.print_("Mace do not support platorm %s yet." % FLAGS.platform,
file=sys.stderr)
......@@ -231,6 +234,7 @@ def parse_args():
type=str,
default="",
help="TensorFlow \'GraphDef\' file to load, "
"Onnx model file .onnx to load, "
"Caffe prototxt file to load.")
parser.add_argument(
"--weight_file", type=str, default="", help="Caffe data file to load.")
......@@ -300,7 +304,10 @@ def parse_args():
parser.add_argument(
"--check_shape", type=str, default="", help="check shape.")
parser.add_argument(
"--platform", type=str, default="tensorflow", help="tensorflow/caffe")
"--platform",
type=str,
default="tensorflow",
help="tensorflow/caffe/onnx")
parser.add_argument(
"--embed_model_data",
type=str2bool,
......
......@@ -37,11 +37,14 @@ class FilterFormat(Enum):
OHWI = 103
# SAME_LOWER: if the amount of paddings to be added is odd,
# it will add the extra data to the right or bottom
class PaddingMode(Enum):
VALID = 0
SAME = 1
FULL = 2
NA = 3
SAME_LOWER = 3
NA = 4
class PoolingType(Enum):
......@@ -49,6 +52,11 @@ class PoolingType(Enum):
MAX = 2
class RoundMode(Enum):
FLOOR = 0
CEIL = 1
class ActivationType(Enum):
NOOP = 0
RELU = 1
......@@ -56,6 +64,7 @@ class ActivationType(Enum):
PRELU = 3
TANH = 4
SIGMOID = 5
LEAKYRELU = 6
class EltwiseType(Enum):
......@@ -72,9 +81,17 @@ class EltwiseType(Enum):
EQUAL = 10
class ReduceType(Enum):
MEAN = 0
MIN = 1
MAX = 2
PROD = 3
class FrameworkType(Enum):
TENSORFLOW = 0
CAFFE = 1
ONNX = 2
MaceSupportedOps = [
......@@ -108,7 +125,7 @@ MaceSupportedOps = [
'Pooling',
'Proposal',
'Quantize',
'ReduceMean',
'Reduce',
'Reshape',
'ResizeBicubic',
'ResizeBilinear',
......@@ -184,6 +201,10 @@ class MaceKeyword(object):
mace_group_str = "group"
mace_wino_arg_str = "wino_block_size"
mace_quantize_flag_arg_str = "quantize_flag"
mace_epsilon_str = 'epsilon'
mace_reduce_type_str = 'reduce_type'
mace_argmin_str = 'argmin'
mace_round_mode_str = 'round_mode'
class TransformerRule(Enum):
......
# Copyright 2018 Xiaomi, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from enum import Enum
import six
from mace.proto import mace_pb2
from mace.python.tools.converter_tool import base_converter
from mace.python.tools.converter_tool.base_converter import PoolingType
from mace.python.tools.converter_tool.base_converter import PaddingMode
from mace.python.tools.converter_tool.base_converter import ActivationType
from mace.python.tools.converter_tool.base_converter import EltwiseType
from mace.python.tools.converter_tool.base_converter import ReduceType
from mace.python.tools.converter_tool.base_converter import FrameworkType
from mace.python.tools.converter_tool.base_converter import RoundMode
from mace.python.tools.converter_tool.base_converter import DataFormat
from mace.python.tools.converter_tool.base_converter import FilterFormat
from mace.python.tools.converter_tool.base_converter import MaceOp
from mace.python.tools.converter_tool.base_converter import MaceKeyword
from mace.python.tools.converter_tool.base_converter import ConverterUtil
from mace.python.tools.convert_util import mace_check
import onnx
import onnx.utils
from onnx import helper, shape_inference, numpy_helper, optimizer
import numpy as np
from onnx import mapping
from onnx import TensorProto
from numbers import Number
OnnxSupportedOps = [
'Abs',
# 'Acos',
# 'Acosh',
'Add',
# 'And',
'ArgMax',
'ArgMin',
# 'Asin',
# 'Asinh',
# 'Atan',
# 'Atanh',
'AveragePool',
'BatchNormalization',
'Cast',
# 'Ceil',
# 'Clip',
# 'Compress',
'Concat',
# 'Constant',
# 'ConstantLike',
'Conv',
'ConvTranspose',
# 'Cos',
# 'Cosh',
'DepthToSpace',
'Div',
'Dropout',
'Elu',
'Equal',
# 'Exp',
# 'Expand',
# 'EyeLike',
# 'Flatten',
# 'Floor',
# 'GRU',
'Gather',
'Gemm',
'GlobalAveragePool',
# 'GlobalLpPool',
'GlobalMaxPool',
# 'Greater',
# 'HardSigmoid',
# 'Hardmax',
'Identity',
# 'If',
'ImageScaler',
# 'InstanceNormalization',
# 'LRN',
# 'LSTM',
'LeakyRelu',
# 'Less',
# 'Log',
# 'LogSoftmax',
# 'Loop',
# 'LpNormalization',
# 'LpPool',
'MatMul',
'Max',
'MaxPool',
# 'MaxRoiPool',
# 'MaxUnpool',
'Mean',
'Min',
'Mul',
# 'Multinomial',
'Neg',
# 'Not',
# 'OneHot',
# 'Or',
'PRelu',
'Pad',
'Pow',
# 'RNN',
# 'RandomNormal',
# 'RandonNormalLike',
# 'RandonUniform',
# 'RandonUniformLike',
'Reciprocal',
# 'ReduceL1',
# 'ReduceL2',
# 'ReduceLogSum',
# 'ReduceLogSumExp',
'ReduceMax',
'ReduceMean',
'ReduceMin',
'ReduceProd',
# 'ReduceSum',
# 'ReduceSumSquare',
'Relu',
'Reshape',
# 'Scan',
# 'Selu',
'Shape',
'Sigmoid',
# 'Sin',
# 'Sinh',
# 'Size',
# 'Slice',
'Softmax',
# 'Softplus',
# 'Softsign',
'SpaceToDepth',
'Split',
'Sqrt',
'Squeeze',
'Sub',
'Sum',
# 'Tan',
'Tanh',
# 'Tile',
# 'TopK',
'Transpose',
# 'Unsqueeze',
# 'Upsample',
# 'Xor',
]
OnnxOpType = Enum('OnnxOpType',
[(op, op) for op in OnnxSupportedOps],
type=str)
onnx_attr_translator = {
"axis": lambda x: int(x),
"axes": lambda x: [int(a) for a in x],
"dtype": lambda x: data_type.onnx2tf(x),
"keepdims": lambda x: bool(x),
"to": lambda x: data_type.onnx2tf(x),
}
def translate_onnx(key, val):
return onnx_attr_translator.get(key, lambda x: x)(val)
def convert_onnx(attr):
return convert_onnx_attribute_proto(attr)
def convert_onnx_attribute_proto(attr_proto):
if attr_proto.HasField('f'):
return attr_proto.f
elif attr_proto.HasField('i'):
return attr_proto.i
elif attr_proto.HasField('s'):
return str(attr_proto.s, 'utf-8')\
if sys.version_info.major == 3 else attr_proto.s
elif attr_proto.HasField('t'):
return attr_proto.t # this is a proto!
elif attr_proto.floats:
return list(attr_proto.floats)
elif attr_proto.ints:
return list(attr_proto.ints)
elif attr_proto.strings:
str_list = list(attr_proto.strings)
if IS_PYTHON3:
str_list = map(lambda x: str(x, 'utf-8'), str_list)
return str_list
else:
raise ValueError("Unsupported ONNX attribute: {}".format(attr_proto))
def onnx_dtype(dtype):
if isinstance(dtype, Number):
onnx_dtype = dtype
elif isinstance(dtype, str):
onnx_dtype = TensorProto.DataType.Value(dtype)
else:
raise RuntimeError("dtype should be number or str.")
return mapping.TENSOR_TYPE_TO_NP_TYPE[onnx_dtype]
class OnnxNode(object):
def __init__(self, node):
self.name = str(node.name)
self.op_type = str(node.op_type)
self.domain = str(node.domain)
self.attrs = dict([(attr.name,
translate_onnx(attr.name, convert_onnx(attr)))
for attr in node.attribute])
self.inputs = list(node.input)
self.outputs = list(node.output)
self.node_proto = node
def print_info(self):
print "node: ", self.name
print " type: ", self.op_type
print " domain: ", self.domain
print " inputs: ", self.inputs
print " outputs: ", self.outputs
print " attrs:"
for arg in self.attrs:
print " %s: %s" % (arg, self.attrs[arg])
class OnnxTensor(object):
def __init__(self, name, value, shape, dtype):
self._name = name
self._tensor_data = value
self._shape = shape
self._dtype = dtype
class OnnxConverter(base_converter.ConverterInterface):
pooling_type_mode = {
OnnxOpType.AveragePool.name: PoolingType.AVG,
OnnxOpType.MaxPool.name: PoolingType.MAX
}
auto_pad_mode = {
'NOTSET': PaddingMode.NA,
'SAME_UPPER': PaddingMode.SAME,
'SAME_LOWER': PaddingMode.SAME,
'VALID': PaddingMode.VALID,
}
auto_pad_mode = {six.b(k): v for k, v in six.iteritems(auto_pad_mode)}
eltwise_type = {
OnnxOpType.Mul.name: EltwiseType.PROD,
OnnxOpType.Add.name: EltwiseType.SUM,
OnnxOpType.Max.name: EltwiseType.MAX,
OnnxOpType.Min.name: EltwiseType.MIN,
OnnxOpType.Abs.name: EltwiseType.ABS,
OnnxOpType.Pow.name: EltwiseType.POW,
OnnxOpType.Sub.name: EltwiseType.SUB,
OnnxOpType.Div.name: EltwiseType.DIV,
OnnxOpType.Neg.name: EltwiseType.NEG,
OnnxOpType.Sum.name: EltwiseType.SUM,
OnnxOpType.Equal.name: EltwiseType.EQUAL,
OnnxOpType.Sqrt.name: EltwiseType.POW,
OnnxOpType.Reciprocal.name: EltwiseType.POW,
}
reduce_type = {
OnnxOpType.GlobalAveragePool.name: ReduceType.MEAN,
OnnxOpType.GlobalMaxPool.name: ReduceType.MAX,
OnnxOpType.ReduceMax.name: ReduceType.MAX,
OnnxOpType.ReduceMean.name: ReduceType.MEAN,
OnnxOpType.ReduceMin.name: ReduceType.MIN,
OnnxOpType.ReduceProd.name: ReduceType.PROD,
}
activation_type = {
OnnxOpType.Relu.name: ActivationType.RELU,
OnnxOpType.PRelu.name: ActivationType.PRELU,
OnnxOpType.Tanh.name: ActivationType.TANH,
OnnxOpType.Sigmoid.name: ActivationType.SIGMOID,
OnnxOpType.LeakyRelu.name: ActivationType.LEAKYRELU,
}
def __init__(self, option, src_model_file):
self._op_converters = {
OnnxOpType.Abs.name: self.convert_eltwise,
OnnxOpType.Add.name: self.convert_eltwise,
OnnxOpType.ArgMax.name: self.convert_argmax,
OnnxOpType.ArgMin.name: self.convert_argmax,
OnnxOpType.AveragePool.name: self.convert_pooling,
OnnxOpType.BatchNormalization.name: self.convert_fused_batchnorm,
OnnxOpType.Cast.name: self.convert_cast,
OnnxOpType.Concat.name: self.convert_concat,
OnnxOpType.Conv.name: self.convert_conv2d,
OnnxOpType.ConvTranspose.name: self.convert_deconv,
OnnxOpType.DepthToSpace.name: self.convert_depth_space,
OnnxOpType.Dropout.name: self.convert_identity,
OnnxOpType.Div.name: self.convert_eltwise,
OnnxOpType.Equal.name: self.convert_eltwise,
OnnxOpType.Gather.name: self.convert_gather,
OnnxOpType.Gemm.name: self.convert_fully_connected,
OnnxOpType.GlobalAveragePool.name: self.convert_reduce,
OnnxOpType.GlobalMaxPool.name: self.convert_reduce,
OnnxOpType.Identity.name: self.convert_identity,
OnnxOpType.ImageScaler.name: self.convert_imagescaler,
OnnxOpType.LeakyRelu.name: self.convert_activation,
OnnxOpType.Max.name: self.convert_eltwise,
OnnxOpType.MaxPool.name: self.convert_pooling,
OnnxOpType.MatMul.name: self.convert_matmul,
OnnxOpType.Min.name: self.convert_eltwise,
OnnxOpType.Mul.name: self.convert_eltwise,
OnnxOpType.Neg.name: self.convert_eltwise,
OnnxOpType.Pad.name: self.convert_pad,
OnnxOpType.Pow.name: self.convert_eltwise,
OnnxOpType.PRelu.name: self.convert_activation,
OnnxOpType.Relu.name: self.convert_activation,
OnnxOpType.Reshape.name: self.convert_reshape,
OnnxOpType.Reciprocal.name: self.convert_eltwise,
OnnxOpType.Sigmoid.name: self.convert_activation,
OnnxOpType.Softmax.name: self.convert_softmax,
OnnxOpType.SpaceToDepth.name: self.convert_depth_space,
OnnxOpType.Split.name: self.convert_split,
OnnxOpType.Sqrt.name: self.convert_eltwise,
OnnxOpType.Squeeze.name: self.convert_squeeze,
OnnxOpType.Sub.name: self.convert_eltwise,
OnnxOpType.Sum.name: self.convert_eltwise,
OnnxOpType.Tanh.name: self.convert_activation,
OnnxOpType.Transpose.name: self.convert_transpose,
}
self._option = option
self._mace_net_def = mace_pb2.NetDef()
ConverterUtil.set_filter_format(self._mace_net_def, FilterFormat.OIHW)
onnx_model = onnx.load(src_model_file)
polished_model = onnx.utils.polish_model(onnx_model)
print "onnx model IR version: ", onnx_model.ir_version
print "onnx model opset import: ", onnx_model.opset_import
self._onnx_model = shape_inference.infer_shapes(polished_model)
self._graph_shapes_dict = {}
self._consts = {}
self._replace_tensors = {}
def print_graph_info(self, graph):
for value_info in graph.value_info:
print "value info:", value_info
for value_info in graph.input:
print "inputs info:", value_info
for value_info in graph.output:
print "outputs info:", value_info
def extract_shape_info(self, graph):
def extract_value_info(shape_dict, value_info):
t = tuple([int(dim.dim_value)
for dim in value_info.type.tensor_type.shape.dim])
if t:
shape_dict[value_info.name] = t
for value_info in graph.value_info:
extract_value_info(self._graph_shapes_dict, value_info)
for value_info in graph.input:
extract_value_info(self._graph_shapes_dict, value_info)
for value_info in graph.output:
extract_value_info(self._graph_shapes_dict, value_info)
def add_tensor(self, name, shape, data_type, value):
tensor = self._mace_net_def.tensors.add()
tensor.name = name
tensor.dims.extend(list(shape))
tensor.data_type = data_type
tensor.float_data.extend(value.flat)
def run(self):
graph_def = self._onnx_model.graph
self.extract_shape_info(graph_def)
self.convert_tensors(graph_def)
self.convert_ops(graph_def)
# self.print_graph_info(graph_def)
# shape_inferer = mace_shape_inference.ShapeInference(
# self._mace_net_def,
# self._option.input_nodes.values())
# shape_inferer.run()
return self._mace_net_def
def add_stride_pad_kernel_arg(self, attrs, op_def):
if 'strides' in attrs:
strides = attrs['strides']
mace_check(len(strides) == 2, "strides should has 2 values.")
stride = [strides[0], strides[1]]
else:
stride = [1, 1]
strides_arg = op_def.arg.add()
strides_arg.name = MaceKeyword.mace_strides_str
strides_arg.ints.extend(stride)
if 'kernel_shape' in attrs:
kernel_shape = attrs['kernel_shape']
mace_check(len(kernel_shape) == 2,
"kernel shape should has 2 values.")
kernel = [kernel_shape[0], kernel_shape[1]]
kernels_arg = op_def.arg.add()
kernels_arg.name = MaceKeyword.mace_kernel_str
kernels_arg.ints.extend(kernel)
if 'pads' in attrs:
pads = attrs['pads']
if len(pads) == 4:
pad = [pads[0] + pads[2], pads[1] + pads[3]]
else:
pad = [0, 0]
padding_arg = op_def.arg.add()
padding_arg.name = MaceKeyword.mace_padding_values_str
padding_arg.ints.extend(pad)
elif 'auto_pad' in attrs:
auto_pad_arg = op_def.arg.add()
auto_pad_arg.name = MaceKeyword.mace_padding_str
auto_pad_arg.i = self.auto_pad_mode[attrs['auto_pad']].value
else:
pad = [0, 0]
padding_arg = op_def.arg.add()
padding_arg.name = MaceKeyword.mace_padding_values_str
padding_arg.ints.extend(pad)
def convert_ops(self, graph_def):
for n in graph_def.node:
node = OnnxNode(n)
mace_check(node.op_type in self._op_converters,
"Mace does not support onnx op type %s yet"
% node.op_type)
self._op_converters[node.op_type](node)
def convert_tensors(self, graph_def):
initializer = graph_def.initializer
if initializer:
for init in initializer:
tensor = self._mace_net_def.tensors.add()
tensor.name = init.name
onnx_tensor = numpy_helper.to_array(init)
tensor.dims.extend(list(init.dims))
data_type = onnx_dtype(init.data_type)
if data_type == np.float32 or data_type == np.float64:
tensor.data_type = mace_pb2.DT_FLOAT
tensor.float_data.extend(
onnx_tensor.astype(np.float32).flat)
elif data_type == np.int32:
tensor.data_type = mace_pb2.DT_INT32
tensor.int32_data.extend(
onnx_tensor.astype(np.int32).flat)
elif data_type == np.int64:
tensor.data_type = mace_pb2.DT_INT32
tensor.int32_data.extend(
onnx_tensor.astype(np.int32).flat)
else:
mace_check(False,
"Not supported tensor type: %s" % data_type)
self._consts[tensor.name] = tensor
def convert_general_op(self, node):
op = self._mace_net_def.op.add()
op.name = node.name
for input in node.inputs:
if input in self._replace_tensors:
input = self._replace_tensors[input]
op.input.append(input)
for output in node.outputs:
op.output.append(output)
output_shape = op.output_shape.add()
shape_info = self._graph_shapes_dict[output]
output_shape.dims.extend(shape_info)
data_type_arg = op.arg.add()
data_type_arg.name = 'T'
data_type_arg.i = self._option.data_type
framework_type_arg = op.arg.add()
framework_type_arg.name = MaceKeyword.mace_framework_type_str
framework_type_arg.i = FrameworkType.ONNX.value
ConverterUtil.add_data_format_arg(op, DataFormat.NCHW)
return op
def convert_fused_batchnorm(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.BatchNorm.name
if "epsilon" in node.attrs:
epsilon_value = node.attrs["epsilon"]
else:
epsilon_value = 1e-5
mace_check(len(node.inputs) == 5, "batch norm should have 5 inputs.")
gamma_value = np.array(self._consts[node.inputs[1]].float_data)
beta_value = np.array(self._consts[node.inputs[2]].float_data)
mean_value = np.array(self._consts[node.inputs[3]].float_data)
var_value = np.array(self._consts[node.inputs[4]].float_data)
scale_name = node.name + 'scale'
offset_name = node.name + 'offset'
scale_value = (
(1.0 / np.sqrt(
var_value + epsilon_value)) * gamma_value)
offset_value = (-mean_value * scale_value) + beta_value
self.add_tensor(scale_name, scale_value.shape, mace_pb2.DT_FLOAT,
scale_value)
self.add_tensor(offset_name, offset_value.shape, mace_pb2.DT_FLOAT,
offset_value)
del op.input[1:]
op.input.extend([scale_name, offset_name])
del op.output[1:]
del op.output_shape[1:]
def convert_conv2d(self, node):
op = self.convert_general_op(node)
self.add_stride_pad_kernel_arg(node.attrs, op)
group_arg = op.arg.add()
group_arg.name = MaceKeyword.mace_group_str
if 'group' in node.attrs:
group_val = node.attrs["group"]
else:
group_val = 1
group_arg.i = group_val
is_depthwise = False
if group_val > 1:
filter_shape = self._graph_shapes_dict[node.inputs[1]]
mace_check(group_val == filter_shape[0] and
filter_shape[1] == 1,
"Mace does not support group convolution yet")
filter_tensor = self._consts[node.inputs[1]]
new_shape = [filter_shape[1], filter_shape[0],
filter_shape[2], filter_shape[3]]
del filter_tensor.dims[:]
filter_tensor.dims.extend(new_shape)
is_depthwise = True
if is_depthwise:
op.type = MaceOp.DepthwiseConv2d.name
else:
op.type = MaceOp.Conv2D.name
dilation_arg = op.arg.add()
dilation_arg.name = MaceKeyword.mace_dilations_str
if 'dilations' in node.attrs:
dilation_val = node.attrs["dilations"]
else:
dilation_val = [1, 1]
dilation_arg.ints.extend(dilation_val)
def convert_biasadd(self, node):
self.convert_general_op(node)
def convert_concat(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Concat.name
mace_check('axis' in node.attrs,
'Concat op should have axis attribute.')
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = node.attrs['axis']
axis_arg.i = 4 + axis_arg.i if axis_arg.i < 0 else axis_arg.i
mace_check(axis_arg.i == 1,
"only support concat at channel dimension")
def convert_activation(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Activation.name
type_arg = op.arg.add()
type_arg.name = MaceKeyword.mace_activation_type_str
type_arg.s = six.b(self.activation_type[node.op_type].name)
if "alpha" in node.attrs:
alpha_value = node.attrs["alpha"]
else:
alpha_value = 0
alpha_arg = op.arg.add()
alpha_arg.name = MaceKeyword.mace_activation_max_limit_str
alpha_arg.f = alpha_value
def convert_pooling(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Pooling.name
self.add_stride_pad_kernel_arg(node.attrs, op)
pooling_type_arg = op.arg.add()
pooling_type_arg.name = MaceKeyword.mace_pooling_type_str
pooling_type_arg.i = self.pooling_type_mode[node.op_type].value
round_mode_arg = op.arg.add()
round_mode_arg.name = MaceKeyword.mace_round_mode_str
round_mode_arg.i = RoundMode.FLOOR.value
def convert_reshape(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Reshape.name
def convert_flatten(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Reshape.name
def remove_node(self, node):
input_name = node.inputs[0]
output_name = node.outputs[0]
self._replace_tensors[output_name] = input_name
def convert_eltwise(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Eltwise.name
type_arg = op.arg.add()
type_arg.name = MaceKeyword.mace_element_type_str
type_arg.i = self.eltwise_type[node.op_type].value
if node.op_type == OnnxOpType.Sqrt.name:
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str
value_arg.f = 0.5
elif node.op_type == OnnxOpType.Reciprocal.name:
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str
value_arg.f = -1
def convert_reduce(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Reduce.name
reduce_type_arg = op.arg.add()
reduce_type_arg.name = MaceKeyword.mace_reduce_type_str
reduce_type_arg.i = self.reduce_type[node.op_type].value
if node.op_type in [OnnxOpType.GlobalAveragePool.name,
OnnxOpType.GlobalMaxPool.name]:
reduce_dims = [2, 3]
keep_dims = 1
else:
if 'axes' in node.attrs:
reduce_dims = node.attrs['axes']
else:
reduce_dims = []
if 'keepdims' in node.attrs:
keep_dims = node.attrs['keepdims']
else:
keep_dims = 1
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.ints.extend(reduce_dims)
keep_dims_arg = op.arg.add()
keep_dims_arg.name = MaceKeyword.mace_keepdims_str
keep_dims_arg.i = keep_dims
def convert_imagescaler(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.BatchNorm.name
scale = node.attrs['scale']
bias_value = np.array(node.attrs['bias'])
scale_value = scale * np.ones_like(bias_value)
scale_name = node.name + "_scale"
bias_name = node.name + "_bias"
self.add_tensor(scale_name, scale_value.shape, mace_pb2.DT_FLOAT,
scale_value)
self.add_tensor(bias_name, bias_value.shape, mace_pb2.DT_FLOAT,
bias_value)
op.input.extend([scale_name, bias_name])
def convert_matmul(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.MatMul.name
def convert_softmax(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Softmax.name
def convert_argmax(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.ArgMax.name
if 'axis' in node.attrs:
axis_value = node.attrs['axis']
else:
axis_value = 0
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = axis_value
if 'keepdims' in node.attrs:
keepdims = node.attrs['keepdims']
else:
keepdims = 1
keep_dims_arg = op.arg.add()
keep_dims_arg.name = MaceKeyword.mace_keepdims_str
keep_dims_arg.i = keepdims
if node.op_type == OnnxOpType.ArgMin.name:
min_arg = op.arg.add()
min_arg.name = MaceKeyword.mace_argmin_str
min_arg.i = 1
def convert_cast(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Cast.name
if 'to' in node.attrs:
dtype = node.attrs['to']
if dtype == TensorProto.FLOAT:
op.output_type.extend([self._option.data_type])
elif dtype == TensorProto.INT:
op.output_type.extend([mace_pb2.DT_INT32])
else:
mace_check(False, "data type %s not supported" % dtype)
else:
op.output_type.extend([self._option.data_type])
def convert_depth_space(self, node):
op = self.convert_general_op(node)
if op.type == OnnxOpType.DepthToSpace.name:
op.type = MaceOp.DepthToSpace.name
else:
op.type = MaceOp.SpaceToDepth.name
mace_check(('block_size' in node.attrs),
"depth to space op should have block size attribute.")
block_size = node.attrs['block_size']
size_arg = op.arg.add()
size_arg.name = MaceKeyword.mace_space_depth_block_size_str
size_arg.i = block_size
def convert_deconv(self, node):
op = self.convert_general_op(node)
self.add_stride_pad_kernel_arg(node.attrs, op)
if 'group' in node.attrs:
group_val = node.attrs["group"]
else:
group_val = 1
if group_val > 1:
op.type = MaceOp.DepthwiseDeconv2d.name
filter_shape = self._graph_shapes_dict[node.inputs[1]]
filter_tensor = self._consts[node.inputs[1]]
new_shape = [filter_shape[1], filter_shape[0],
filter_shape[2], filter_shape[3]]
del filter_tensor.dims[:]
filter_tensor.dims.extend(new_shape)
else:
op.type = MaceOp.Deconv2D.name
group_arg = op.arg.add()
group_arg.name = MaceKeyword.mace_group_str
group_arg.i = group_val
dilation_arg = op.arg.add()
dilation_arg.name = MaceKeyword.mace_dilations_str
if 'dilations' in node.attrs:
dilation_val = node.attrs["dilations"]
else:
dilation_val = [1, 1]
dilation_arg.ints.extend(dilation_val)
mace_check(dilation_val == [1, 1],
"not support convtranspose with dilation != 1 yet.")
mace_check('output_padding' not in node.attrs,
"not support convtranspose with output_padding yet.")
mace_check('output_shape' not in node.attrs,
"not support convtranspose with output_shape yet.")
# TODO: if output shape specified, calculate padding value
# if 'output_padding' in node.attrs:
# output_padding = node.attrs['output_padding']
# output_padding_arg = op.arg.add()
# output_padding_arg.name = MaceKeyword.mace_output_padding_str
# output_padding_arg.ints.extend(output_padding)
# if 'output_shape' in node.attrs:
# output_shape = node.attrs['output_shape']
# output_shape_arg = op.arg.add()
# output_shape_arg.name = MaceKeyword.mace_output_shape_str
# output_shape_arg.ints.extend(output_shape)
def convert_nop(self, node):
pass
def convert_identity(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Identity.name
def convert_pad(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Pad.name
if 'pads' in node.attrs:
paddings_arg = op.arg.add()
paddings_arg.name = MaceKeyword.mace_paddings_str
paddings_value = node.attrs['pads']
paddings_arg.ints.extend(paddings_value)
if 'value' in node.attrs:
constant_value_arg = op.arg.add()
constant_value_arg.name = MaceKeyword.mace_constant_value_str
constant_value_arg.i = node.attrs['value']
def convert_gather(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Gather.name
if 'axis' in node.attrs:
value = node.attrs['axis']
else:
value = 0
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = value
def convert_split(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Split.name
if 'axis' in node.attrs:
value = node.attrs['axis']
else:
value = 0
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = value
def convert_transpose(self, node):
op = self.convert_general_op(node)
op.type = MaceOp.Transpose.name
if np.array_equal(perm, ordered_perm):
op.type = MaceOp.Identity.name
del op.input[1:]
if 'perm' in node.attrs:
perm = node.attrs['perm']
ordered_perm = np.sort(perm)
if np.array_equal(perm, ordered_perm):
op.type = MaceOp.Identity.name
else:
dims_arg = op.arg.add()
dims_arg.name = MaceKeyword.mace_dims_str
dims_arg.ints.extend(perm)
@staticmethod
def squeeze_shape(shape, axis):
new_shape = []
if len(axis) > 0:
for i in range(len(shape)):
if i not in axis:
new_shape.append(shape[i])
else:
new_shape = shape
return new_shape
def convert_squeeze(self, node):
axis_value = node.attrs['axes']
if node.inputs[0] in self._consts:
tensor = self._consts[node.inputs[0]]
shape = tensor.dims
new_shape = self.squeeze_shape(shape, axis_value)
del tensor.dims[:]
tensor.dims.extend(new_shape)
self.remove_node(node)
else:
op = self.convert_general_op(node)
op.type = MaceOp.Squeeze.name
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
if 'axis' in node.attrs:
axis_value = node.attrs['axis']
else:
axis_value = []
axis_arg.ints.extend(axis_value)
@staticmethod
def transpose_const(tensor):
shape = tensor.dims
mace_check(len(shape) == 2, "gemm only supports 2-dim input.")
tensor_data = np.array(tensor.float_data).reshape(
shape[0], shape[1])
tensor_data = tensor_data.transpose(1, 0)
tensor.float_data[:] = tensor_data.flat
tensor.dims[:] = tensor_data.shape
def convert_fully_connected(self, node):
trans_a = node.attrs['transA'] if 'transA' in node.attrs else 0
trans_b = node.attrs['transB'] if 'transB' in node.attrs else 0
shape_a = self._graph_shapes_dict[node.inputs[0]]
shape_b = self._graph_shapes_dict[node.inputs[1]]
mace_check(trans_a == 0 and trans_b == 1,
"Do not support non-default transpose")
mace_check(len(shape_a) == 4,
"Unexpected fc input ndim.")
mace_check(node.inputs[1] in self._consts, "unexpect fc weight.")
if len(shape_b) == 4:
mace_check(list(shape_b[2:]) == [1, 1],
"Only support 4D weight with shape [*, *, 1, 1]")
elif len(shape_b) == 2:
tensor_b = self._consts[node.inputs[1]]
tensor_data = np.array(tensor_b.float_data).reshape(
shape_b[0], shape_b[1], 1, 1)
tensor_b.float_data[:] = tensor_data.flat
tensor_b.dims[:] = tensor_data.shape
else:
mace_check(False, "Unexpected fc weigth ndim.")
op = self._mace_net_def.op.add()
op.name = node.name
op.type = MaceOp.FullyConnected.name
data_type_arg = op.arg.add()
data_type_arg.name = 'T'
data_type_arg.i = self._option.data_type
framework_type_arg = op.arg.add()
framework_type_arg.name = MaceKeyword.mace_framework_type_str
framework_type_arg.i = FrameworkType.ONNX.value
ConverterUtil.add_data_format_arg(op, DataFormat.NCHW)
for input in node.inputs:
op.input.append(input)
for output in node.outputs:
op.output.append(output)
output_shape = op.output_shape.add()
shape_info = self._graph_shapes_dict[output]
mace_check(len(shape_info) in [2, 4],
"gemm output shape should be 2 or 4 dims.")
if len(shape_info) == 4:
mace_check(shape_info[2] == 1 and shape_info[3] == 1,
"gemm's 4-dim output shape should be [*, * , 1, 1]")
else:
shape_info = [shape_info[0], shape_info[1], 1, 1]
output_shape.dims.extend(shape_info)
return op
......@@ -26,6 +26,7 @@ from mace.python.tools.converter_tool.base_converter import PaddingMode
from mace.python.tools.converter_tool.base_converter import ActivationType
from mace.python.tools.converter_tool.base_converter import EltwiseType
from mace.python.tools.converter_tool.base_converter import FrameworkType
from mace.python.tools.converter_tool.base_converter import ReduceType
from mace.python.tools.converter_tool.base_converter import DataFormat
from mace.python.tools.converter_tool.base_converter import FilterFormat
from mace.python.tools.converter_tool.base_converter import MaceOp
......@@ -465,15 +466,6 @@ class TensorflowConverter(base_converter.ConverterInterface):
"Mace only supports dilation == 1 conv2d_transpose.")
mace_check(len(tf_op.inputs) >= 3,
"deconv should have (>=) 3 inputs.")
output_shape_arg = op.arg.add()
output_shape_arg.name = MaceKeyword.mace_output_shape_str
# if tf_op.inputs[0].op.type == TFOpType.Const.name:
# output_shape_value = \
# tf_op.inputs[0].eval().astype(np.int32).flat
# output_shape_arg.ints.extend(output_shape_value)
# else:
# output_shape_value = {}
# output_shape_arg.ints.extend(output_shape_value)
del op.input[:]
op.input.extend([tf_op.inputs[2].name,
tf_op.inputs[1].name,
......@@ -810,7 +802,12 @@ class TensorflowConverter(base_converter.ConverterInterface):
op = self.convert_general_op(tf_op)
del op.input[1:]
op.type = MaceOp.ReduceMean.name
op.type = MaceOp.Reduce.name
reduce_type_arg = op.arg.add()
reduce_type_arg.name = MaceKeyword.mace_reduce_type_str
reduce_type_arg.i = ReduceType.MEAN
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
if len(tf_op.inputs) > 1:
......
......@@ -352,21 +352,26 @@ class Transformer(base_converter.ConverterInterface):
if elttype == EltwiseType.SQR_DIFF.value and\
self.consumer_count(op.output[0]) == 1:
consumer_op = self._consumers[op.output[0]][0]
axis = ConverterUtil.get_arg(
consumer_op,
MaceKeyword.mace_axis_str).ints
keep_dims = ConverterUtil.get_arg(
consumer_op,
MaceKeyword.mace_keepdims_str).i
if consumer_op.type == MaceOp.ReduceMean.name and\
len(consumer_op.input) == 1 and \
axis[0] == 1 and axis[1] == 2 and keep_dims != 0:
print("Fold SquaredDiff ReduceMean: %s" % op.name)
op.type = MaceOp.SqrDiffMean.name
op.output[0] = consumer_op.output[0]
self.replace_quantize_info(op, consumer_op)
self.safe_remove_node(consumer_op, op)
return True
if consumer_op.type == MaceOp.Reduce.name:
axis = ConverterUtil.get_arg(
consumer_op,
MaceKeyword.mace_axis_str).ints
keep_dims = ConverterUtil.get_arg(
consumer_op,
MaceKeyword.mace_keepdims_str).i
reduce_type = ConverterUtil.get_arg(
consumer_op,
MaceKeyword.mace_reduce_type_str).i
if reduce_type == ReduceType.MEAN and\
len(consumer_op.input) == 1 and\
axis[0] == 1 and axis[1] == 2 and\
keep_dims > 0:
print("Fold SquaredDiff Reduce: %s" % op.name)
op.type = MaceOp.SqrDiffMean.name
op.output[0] = consumer_op.output[0]
self.replace_quantize_info(op, consumer_op)
self.safe_remove_node(consumer_op, op)
return True
return False
......@@ -1005,13 +1010,13 @@ class Transformer(base_converter.ConverterInterface):
'only support squeeze at at [2, 3]')
arg.ints[:] = [1, 2]
elif op.type == MaceOp.ReduceMean.name:
elif op.type == MaceOp.Reduce.name:
for arg in op.arg:
if arg.name == MaceKeyword.mace_axis_str:
if ConverterUtil.data_format(
op) == DataFormat.NCHW \
and self._target_data_format == DataFormat.NHWC: # noqa
print("Transpose reduce mean args: %s(%s)"
print("Transpose reduce args: %s(%s)"
% (op.name, op.type))
reduce_axises = list(arg.ints)
new_axises = []
......
......@@ -48,7 +48,7 @@ def _opencl_encrypt_kernel_impl(repository_ctx):
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/pad.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/pooling.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/pooling_buffer.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/reduce_mean.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/reduce.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/resize_bicubic.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/resize_bilinear.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/split.cl"))
......
......@@ -362,6 +362,7 @@ class YAMLKeyword(object):
validation_threshold = 'validation_threshold'
graph_optimize_options = 'graph_optimize_options' # internal use for now
cl_mem_type = 'cl_mem_type'
backend = 'backend'
################################
......
......@@ -55,6 +55,7 @@ ModelFormatStrs = [
PlatformTypeStrs = [
"tensorflow",
"caffe",
"onnx",
]
PlatformType = Enum('PlatformType', [(ele, ele) for ele in PlatformTypeStrs],
type=str)
......@@ -469,6 +470,10 @@ def format_model_config(flags):
else:
subgraph[YAMLKeyword.validation_inputs_data] = \
validation_inputs_data
onnx_backend = subgraph.get(
YAMLKeyword.backend, "tensorflow")
subgraph[YAMLKeyword.backend] = onnx_backend
input_ranges = subgraph.get(
YAMLKeyword.input_ranges, [])
if not isinstance(input_ranges, list):
......
......@@ -572,7 +572,8 @@ class DeviceWrapper:
YAMLKeyword.input_data_types],
caffe_env=flags.caffe_env,
validation_threshold=subgraphs[0][
YAMLKeyword.validation_threshold][validate_type]
YAMLKeyword.validation_threshold][validate_type],
backend=subgraphs[0][YAMLKeyword.backend]
)
if flags.report and flags.round > 0:
tuned = is_tuned and device_type == DeviceType.GPU
......
# Copyright 2018 Xiaomi, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import onnx
import sys
from onnx import optimizer
# Usage: python onnx_optimizer.py model.onnx model_opt.onnx
def main():
if len(sys.argv) != 3:
print "Usage: python onnx_optimizer.py model.onnx model_opt.onnx"
sys.exit(0)
in_path = sys.argv[1]
out_path = sys.argv[2]
original_model = onnx.load(in_path)
print "Start optimize ONNX model for inference:"
passes = ['eliminate_identity',
'fuse_consecutive_squeezes',
'fuse_consecutive_transposes',
'eliminate_nop_pad',
'eliminate_nop_transpose',
'eliminate_unused_initializer',
'extract_constant_to_initializer',
'fuse_add_bias_into_conv',
'fuse_bn_into_conv',
'fuse_transpose_into_gemm']
for i in range(len(passes)):
print i, ".", passes[i]
optimized_model = optimizer.optimize(original_model, passes)
onnx.save_model(optimized_model, out_path)
print "Optimize Finished!"
print "Please check new model in:", out_path
if __name__ == '__main__':
main()
......@@ -621,7 +621,8 @@ def validate_model(abi,
caffe_env,
input_file_name="model_input",
output_file_name="model_out",
validation_threshold=0.9):
validation_threshold=0.9,
backend="tensorflow"):
six.print_("* Validate with %s" % platform)
if abi != "host":
for output_name in output_nodes:
......@@ -638,7 +639,14 @@ def validate_model(abi,
"%s/%s" % (model_output_dir, output_file_name), device_type,
":".join(input_shapes), ":".join(output_shapes),
",".join(input_nodes), ",".join(output_nodes),
validation_threshold, ",".join(input_data_types))
validation_threshold, ",".join(input_data_types), backend)
elif platform == "onnx":
validate(platform, model_file_path, "",
"%s/%s" % (model_output_dir, input_file_name),
"%s/%s" % (model_output_dir, output_file_name), device_type,
":".join(input_shapes), ":".join(output_shapes),
",".join(input_nodes), ",".join(output_nodes),
validation_threshold, ",".join(input_data_types), backend)
elif platform == "caffe":
image_name = "mace-caffe:latest"
container_name = "mace_caffe_validator"
......@@ -654,7 +662,7 @@ def validate_model(abi,
device_type,
":".join(input_shapes), ":".join(output_shapes),
",".join(input_nodes), ",".join(output_nodes),
validation_threshold, ",".join(input_data_types))
validation_threshold, ",".join(input_data_types), backend)
elif caffe_env == common.CaffeEnvType.DOCKER:
docker_image_id = sh.docker("images", "-q", image_name)
if not docker_image_id:
......@@ -720,6 +728,7 @@ def validate_model(abi,
"--output_shape=%s" % ":".join(output_shapes),
"--validation_threshold=%f" % validation_threshold,
"--input_data_type=%s" % ",".join(input_data_types),
"--backend=%s" % ",".join(backend),
_fg=True)
six.print_("Validation done!\n")
......
......@@ -21,6 +21,10 @@ import re
import common
import onnx
from onnx import helper
from onnx import TensorProto
# Validation Flow:
# 1. Generate input data
# 2. Use mace_run to run model on phone.
......@@ -190,9 +194,64 @@ def validate_caffe_model(platform, device_type, model_file, input_file,
value, validation_threshold)
def validate_onnx_model(platform, device_type, model_file, input_file,
mace_out_file, input_names, input_shapes,
output_names, output_shapes, validation_threshold,
input_data_types, backend):
if backend == "tensorflow":
from onnx_tf.backend import prepare
print "valivate on onnx tensorflow backend."
elif backend == "caffe2" or backend == "pytorch":
from caffe2.python.onnx.backend import prepare
print "valivate on onnx caffe2 backend."
else:
common.MaceLogger.error(
VALIDATION_MODULE,
"onnx backend framwork '" + backend + "' is invalid.")
if not os.path.isfile(model_file):
common.MaceLogger.error(
VALIDATION_MODULE,
"Input graph file '" + model_file + "' does not exist!")
model = onnx.load(model_file)
input_dict = {}
for i in range(len(input_names)):
input_value = load_data(common.formatted_file_name(input_file,
input_names[i]),
input_data_types[i])
input_value = input_value.reshape(input_shapes[i]).transpose((0, 3, 1,
2))
input_dict[input_names[i]] = input_value
onnx_outputs = []
for i in range(len(output_names)):
out_shape = output_shapes[i]
if len(out_shape) == 4:
out_shape[1], out_shape[2], out_shape[3] = \
out_shape[3], out_shape[1], out_shape[2]
onnx_outputs.append(
helper.make_tensor_value_info(output_names[i],
TensorProto.FLOAT,
out_shape))
model.graph.output.extend(onnx_outputs)
rep = prepare(model)
output_values = rep.run(input_dict)
for i in range(len(output_names)):
out_name = output_names[i]
value = output_values[out_name].flatten()
out_shape = output_shapes[i]
if len(out_shape) == 4:
value = value.reshape(out_shape).transpose((0, 2, 3, 1))
output_file_name = common.formatted_file_name(mace_out_file,
output_names[i])
mace_out_value = load_data(output_file_name)
compare_output(platform, device_type, output_names[i],
mace_out_value, value,
validation_threshold)
def validate(platform, model_file, weight_file, input_file, mace_out_file,
device_type, input_shape, output_shape, input_node, output_node,
validation_threshold, input_data_type):
validation_threshold, input_data_type, backend):
input_names = [name for name in input_node.split(',')]
input_shape_strs = [shape for shape in input_shape.split(':')]
input_shapes = [[int(x) for x in shape.split(',')]
......@@ -217,6 +276,15 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
mace_out_file, weight_file, input_names,
input_shapes, output_names, output_shapes,
validation_threshold)
elif platform == 'onnx':
output_shape_strs = [shape for shape in output_shape.split(':')]
output_shapes = [[int(x) for x in shape.split(',')]
for shape in output_shape_strs]
validate_onnx_model(platform, device_type, model_file, input_file,
mace_out_file, input_names, input_shapes,
output_names, output_shapes,
validation_threshold,
input_data_types, backend)
def parse_args():
......@@ -259,6 +327,11 @@ def parse_args():
parser.add_argument(
"--validation_threshold", type=float, default=0.995,
help="validation similarity threshold")
parser.add_argument(
"--backend",
type=str,
default="tensorflow",
help="onnx backend framwork")
return parser.parse_known_args()
......@@ -276,4 +349,5 @@ if __name__ == '__main__':
FLAGS.input_node,
FLAGS.output_node,
FLAGS.validation_threshold,
FLAGS.input_data_type)
FLAGS.input_data_type,
FLAGS.backend)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册