diff --git a/README.md b/README.md index 60aa3180931e45e8cbcaa8f71993849e91fe5d61..ed119e31a08ee7236c594441ce45ba36903ce2f4 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,8 @@ the following projects during the development: [Caffe](https://github.com/BVLC/caffe), [SNPE](https://developer.qualcomm.com/software/snapdragon-neural-processing-engine-ai), [ARM ComputeLibrary](https://github.com/ARM-software/ComputeLibrary), - [ncnn](https://github.com/Tencent/ncnn) and many others: we learned many best + [ncnn](https://github.com/Tencent/ncnn), + [ONNX](https://github.com/onnx/onnx) and many others: we learned many best practices from these projects. Finally, we also thank the Qualcomm, Pinecone and MediaTek engineering teams for diff --git a/docs/installation/env_requirement.rst b/docs/installation/env_requirement.rst index 12af0e61e8ea7705f4bf7c766bc897e7bd99df88..dac154194aef8be952376b769b7d540cf6a01111 100644 --- a/docs/installation/env_requirement.rst +++ b/docs/installation/env_requirement.rst @@ -64,6 +64,9 @@ Optional dependencies * - FileLock - pip install -I filelock==3.0.0 - Required by run on Android + * - ONNX + - pip install onnx + - Required by ONNX model .. note:: diff --git a/docs/installation/manual_setup.rst b/docs/installation/manual_setup.rst index fc130301386353898ab7eecf0ef12347ea99b2cb..764ae9c49e3fd79af99302ee0606253e6409035a 100644 --- a/docs/installation/manual_setup.rst +++ b/docs/installation/manual_setup.rst @@ -72,3 +72,9 @@ Install Caffe (Optional) ------------------------- Please follow the installation instruction of `Caffe `__. + + +Install ONNX (Optional) +------------------------- + +Please follow the installation instruction of `ONNX `__. diff --git a/docs/introduction.rst b/docs/introduction.rst index 38aa930cca87d9e04d24bb72fee01b74415d349e..f8ba81e9202c33466943065bb97bb43d0774fbf5 100644 --- a/docs/introduction.rst +++ b/docs/introduction.rst @@ -18,8 +18,7 @@ MACE Model ~~~~~~~~~~ MACE defines a customized model format which is similar to -Caffe2. The MACE model can be converted from exported models by TensorFlow -and Caffe. +Caffe2. The MACE model can be converted from exported models by TensorFlow, Caffe or ONNX Model. MACE Interpreter ~~~~~~~~~~~~~~~~~ @@ -50,7 +49,7 @@ Build MACE dynamic or static libraries. 3. Convert model ~~~~~~~~~~~~~~~~~~ -Convert TensorFlow or Caffe model to MACE model. +Convert TensorFlow, Caffe or ONNX model to MACE model. 4.1. Deploy ~~~~~~~~~~~~~~~~~~ @@ -86,7 +85,7 @@ MACE覆盖了常见的移动端计算设备(CPU,GPU和DSP),并且提供 MACE Model ~~~~~~~~~~~~~~~~~~ -MACE定义了自有的模型格式(类似于Caffe2),通过MACE提供的工具可以将Caffe和TensorFlow的模型 +MACE定义了自有的模型格式(类似于Caffe2),通过MACE提供的工具可以将Caffe/TensorFlow/ONNX格式的模型 转为MACE模型。 MACE Interpreter @@ -118,7 +117,7 @@ CPU/GPU/DSP Runtime对应于各个计算设备的算子实现。 3. 转换模型 ~~~~~~~~~~~~~~~~~~ -将TensorFlow 或者 Caffe的模型转为MACE的模型。 +将TensorFlow或者Caffe或者ONNX的模型转为MACE的模型。 4.1. 部署 ~~~~~~~~~~~~~~~~~~ diff --git a/docs/user_guide/advanced_usage.rst b/docs/user_guide/advanced_usage.rst index 3cae38d7df6930f6512aa4ac8f1b2d7e0934720a..93ebb4f8d1c9f66d9aa600c1b063f2a6b8d488da 100644 --- a/docs/user_guide/advanced_usage.rst +++ b/docs/user_guide/advanced_usage.rst @@ -78,6 +78,8 @@ in one deployment file. - [optional] Specify Numpy validation inputs. When not provided, [-1, 1] random values will be used. * - validation_threshold - [optional] Specify the similarity threshold for validation. A dict with key in 'CPU', 'GPU' and/or 'HEXAGON' and value <= 1.0. 
+    * - backend
+      - [optional] The ONNX backend framework used for validation, one of [tensorflow, caffe2, pytorch]; defaults to tensorflow.
     * - runtime
       - The running device, one of [cpu, gpu, dsp, cpu_gpu]. cpu_gpu contains CPU and GPU model definition so you can run the model on both CPU and GPU.
     * - data_type
diff --git a/docs/user_guide/basic_usage.rst b/docs/user_guide/basic_usage.rst
index ea886a233a1a956f793264051e7200df65847299..d4d404baf8d652f169fe029be2a4966880351dd6 100644
--- a/docs/user_guide/basic_usage.rst
+++ b/docs/user_guide/basic_usage.rst
@@ -114,6 +114,19 @@ MACE now supports models from TensorFlow and Caffe (more frameworks will be supp
       # Upgrade caffemodel
       $CAFFE_ROOT/build/tools/upgrade_net_proto_binary MODEL.caffemodel MODEL.new.caffemodel
 
+- ONNX
+
+  Prepare your ONNX model (.onnx) file.
+
+  Use the `ONNX Optimizer Tool `__ to optimize your model for inference.
+  Like the `Graph Transform Tool `__ in TensorFlow,
+  it improves the efficiency of inference.
+
+  .. code:: bash
+
+      # Optimize your model
+      python MACE_ROOT/tools/onnx_optimizer.py model.onnx model_opt.onnx
+
 ===========================================
 2. Create a deployment file for your model
@@ -137,6 +150,12 @@ Modify one of them and use it for your own case.
   .. literalinclude:: models/demo_models_caffe.yml
      :language: yaml
 
+- ONNX
+
+  .. literalinclude:: models/demo_models_onnx.yml
+     :language: yaml
+
+
 More details about model deployment file are in :doc:`advanced_usage`.
 
 ======================
diff --git a/docs/user_guide/models/demo_models_onnx.yml b/docs/user_guide/models/demo_models_onnx.yml
new file mode 100644
index 0000000000000000000000000000000000000000..d4e6cbc75f08a64a74432166b6673bbdf739032e
--- /dev/null
+++ b/docs/user_guide/models/demo_models_onnx.yml
@@ -0,0 +1,42 @@
+# The name of library
+library_name: mobilenet
+target_abis: [arm64-v8a]
+model_graph_format: file
+model_data_format: file
+models:
+  mobilenet_v1: # model tag, which will be used in model loading and must be specific.
+    platform: onnx
+    # path to your ONNX model file. Local paths, http:// and https:// are supported
+    model_file_path: https://cnbj1.fds.api.xiaomi.com/mace/miai-models/mobilenet-v1/mobilenet-v1-1.0.pb
+    # sha256_checksum of your ONNX model file.
+    # use this command to get the sha256_checksum: sha256sum path/to/your/onnx/file
+    model_sha256_checksum: 71b10f540ece33c49a7b51f5d4095fc9bd78ce46ebf0300487b2ee23d71294e6
+    # define your model's interface
+    # if there are multiple inputs or outputs, write them like below:
+    # subgraphs:
+    #   - input_tensors:
+    #       - input0
+    #       - input1
+    #     input_shapes:
+    #       - 1,224,224,3
+    #       - 1,224,224,3
+    #     output_tensors:
+    #       - output0
+    #       - output1
+    #     output_shapes:
+    #       - 1,1001
+    #       - 1,1001
+    subgraphs:
+      - input_tensors:
+          - input
+        input_shapes:
+          - 1,224,224,3
+        output_tensors:
+          - MobilenetV1/Predictions/Reshape_1
+        output_shapes:
+          - 1,1001
+    # ONNX backend framework for validation. Supports tensorflow/caffe2/pytorch. Default is tensorflow.
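+    # The backend is only used when validating the converted model: MACE runs the
+    # same inputs through the chosen framework and compares its outputs with MACE's,
+    # so that framework must be installed locally (in addition to `pip install onnx`).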
+ backend: tensorflow + # cpu, gpu or cpu+gpu + runtime: cpu+gpu + winograd: 0 \ No newline at end of file diff --git a/mace/ops/activation.h b/mace/ops/activation.h index 36fb45d6bdeef39eb9214d398a5cd33fea7c4a07..07051cc1cbf184d5f4ed3f9e5d2e2f35e77d01cd 100644 --- a/mace/ops/activation.h +++ b/mace/ops/activation.h @@ -32,7 +32,8 @@ enum ActivationType { RELUX = 2, PRELU = 3, TANH = 4, - SIGMOID = 5 + SIGMOID = 5, + LEAKYRELU = 6, }; inline ActivationType StringToActivationType(const std::string type) { @@ -48,6 +49,8 @@ inline ActivationType StringToActivationType(const std::string type) { return ActivationType::SIGMOID; } else if (type == "NOOP") { return ActivationType::NOOP; + } else if (type == "LEAKYRELU") { + return ActivationType ::LEAKYRELU; } else { LOG(FATAL) << "Unknown activation type: " << type; } @@ -90,6 +93,13 @@ void DoActivation(const T *input_ptr, output_ptr[i] = 1 / (1 + std::exp(-input_ptr[i])); } break; + case LEAKYRELU: +#pragma omp parallel for schedule(runtime) + for (index_t i = 0; i < size; ++i) { + output_ptr[i] = std::max(input_ptr[i], + static_cast(0)) * relux_max_limit; + } + break; default: LOG(FATAL) << "Unknown activation type: " << type; } @@ -122,6 +132,9 @@ inline void DoActivation(const float *input_ptr, output_ptr[i] = 1 / (1 + std::exp(-input_ptr[i])); } break; + case LEAKYRELU: + LeakyReluNeon(input_ptr, relux_max_limit, size, output_ptr); + break; default: LOG(FATAL) << "Unknown activation type: " << type; } diff --git a/mace/ops/argmax.cc b/mace/ops/argmax.cc index 2b3e2f0be6aa223ef4eb8d0c47aa5733bd13cac6..3d8b27fe91e752af855c2310e9c36c8659249515 100644 --- a/mace/ops/argmax.cc +++ b/mace/ops/argmax.cc @@ -27,18 +27,29 @@ template class ArgMaxOp : public Operation { public: explicit ArgMaxOp(OpConstructContext *context) - : Operation(context) {} + : Operation(context), + axis_(Operation::GetOptionalArg("axis", 0)), + keep_dims_(Operation::GetOptionalArg("keepdims", true)), + argmin_(Operation::GetOptionalArg("argmin", false)) {} MaceStatus Run(OpContext *context) override { MACE_UNUSED(context); const Tensor *input = this->Input(0); - const Tensor *axis = this->Input(1); + const Tensor *axis = this->InputSize() == 2 ? 
+ this->Input(1) : nullptr; Tensor *output = this->Output(0); + MACE_CHECK(keep_dims_, "Mace only supports keep_dims ArgMax."); MACE_CHECK(input->dim_size() > 0, "ArgMax input should not be a scalar"); - MACE_CHECK(axis->dim_size() == 0, "Mace argmax only supports scalar axis"); - Tensor::MappingGuard axis_guard(axis); - int axis_value = axis->data()[0]; + int axis_value = 0; + if (axis != nullptr) { + MACE_CHECK(axis->dim_size() == 0, + "Mace argmax only supports scalar axis"); + Tensor::MappingGuard axis_guard(axis); + axis_value = axis->data()[0]; + } else { + axis_value = axis_; + } if (axis_value < 0) { axis_value += input->dim_size(); } @@ -59,22 +70,43 @@ class ArgMaxOp : public Operation { index_t outer_size = output->size(); index_t inner_size = input->dim(axis_value); + if (argmin_) { #pragma omp parallel for schedule(runtime) - for (index_t i = 0; i < outer_size; ++i) { - int idx = 0; - T max_value = std::numeric_limits::lowest(); - const T *input_ptr = input_data + i * inner_size; - for (index_t j = 0; j < inner_size; ++j) { - if (input_ptr[j] > max_value) { - max_value = input_ptr[j]; - idx = j; + for (index_t i = 0; i < outer_size; ++i) { + int idx = 0; + T min_value = std::numeric_limits::max(); + const T *input_ptr = input_data + i * inner_size; + for (index_t j = 0; j < inner_size; ++j) { + if (input_ptr[j] < min_value) { + min_value = input_ptr[j]; + idx = j; + } } + output_data[i] = idx; + } + } else { +#pragma omp parallel for schedule(runtime) + for (index_t i = 0; i < outer_size; ++i) { + int idx = 0; + T max_value = std::numeric_limits::lowest(); + const T *input_ptr = input_data + i * inner_size; + for (index_t j = 0; j < inner_size; ++j) { + if (input_ptr[j] > max_value) { + max_value = input_ptr[j]; + idx = j; + } + } + output_data[i] = idx; } - output_data[i] = idx; } return MaceStatus::MACE_SUCCESS; } + + protected: + const int axis_; + bool keep_dims_; + bool argmin_; }; diff --git a/mace/ops/arm/activation_neon.cc b/mace/ops/arm/activation_neon.cc index ec9ba357425ac9c6603b08bac604b6d7f79c57f4..c9cbc3c10c76c626d27a1b1c6c9c0d2ef543c185 100644 --- a/mace/ops/arm/activation_neon.cc +++ b/mace/ops/arm/activation_neon.cc @@ -67,5 +67,29 @@ void ReluxNeon(const float *input, const float limit, #endif } +void LeakyReluNeon(const float *input, const float alpha, + const index_t size, float *output) { +#if defined(MACE_ENABLE_NEON) + float32x4_t vzero = vdupq_n_f32(0.f); + float32x4_t valpha = vdupq_n_f32(alpha); +#pragma omp parallel for schedule(runtime) + for (index_t i = 0; i <= size - 4; i += 4) { + float32x4_t v = vld1q_f32(input + i); + v = vmaxq_f32(v, vzero); + v = vmulq_f32(v, valpha); + vst1q_f32(output + i, v); + } + // remain + for (index_t i = (size >> 2) << 2; i < size; ++i) { + output[i] = std::max(input[i], 0.f) * alpha; + } +#else +#pragma omp parallel for schedule(runtime) + for (index_t i = 0; i < size; ++i) { + output[i] = std::max(input[i], 0.f) * alpha; + } +#endif +} + } // namespace ops } // namespace mace diff --git a/mace/ops/arm/activation_neon.h b/mace/ops/arm/activation_neon.h index cbd1974f22c9da7099b1dca95c67e75361d6851c..e1e24a71264c91cbd1a486a61c97fbdc4d4d19cc 100644 --- a/mace/ops/arm/activation_neon.h +++ b/mace/ops/arm/activation_neon.h @@ -25,6 +25,9 @@ void ReluNeon(const float *input, const index_t size, float *output); void ReluxNeon(const float *input, const float limit, const index_t size, float *output); +void LeakyReluNeon(const float *input, const float alpha, + const index_t size, float *output); + } // namespace ops } // 
namespace mace diff --git a/mace/ops/opencl/buffer/pooling.h b/mace/ops/opencl/buffer/pooling.h index de7d76108fd40e65cb745aa4172adcc993cc6302..4b3dbd1b8780c402b14cd6bac17e1819666bc6da 100644 --- a/mace/ops/opencl/buffer/pooling.h +++ b/mace/ops/opencl/buffer/pooling.h @@ -43,6 +43,7 @@ class PoolingKernel : public OpenCLPoolingKernel { const Padding &padding_type, const std::vector &padding_data, const int *dilations, + const RoundType round_type, Tensor *output) override; private: @@ -62,6 +63,7 @@ MaceStatus PoolingKernel::Compute( const Padding &padding_type, const std::vector &padding_data, const int *dilations, + const RoundType round_type, Tensor *output) { MACE_CHECK(dilations[0] == 1 && dilations[1] == 1) << "Pooling opencl kernel not support dilation yet"; @@ -82,7 +84,7 @@ MaceStatus PoolingKernel::Compute( } else { paddings = padding_data; CalcOutputSize(input->shape().data(), filter_shape.data(), - padding_data.data(), dilations, strides, RoundType::CEIL, + padding_data.data(), dilations, strides, round_type, output_shape.data()); } diff --git a/mace/ops/opencl/cl/common.h b/mace/ops/opencl/cl/common.h index 069130d4d18a67022ede6b55e86e0b9880c87347..29054ad3751ad62c322e6ed793d3742b081eab9a 100644 --- a/mace/ops/opencl/cl/common.h +++ b/mace/ops/opencl/cl/common.h @@ -102,6 +102,9 @@ inline DATA_TYPE4 do_activation(DATA_TYPE4 in, #endif #ifdef USE_SIGMOID out = do_sigmoid(in); +#endif +#ifdef USE_LEAKYRELU + out = fmax(in, (DATA_TYPE)0) * relux_max_limit; #endif return out; } diff --git a/mace/ops/opencl/cl/reduce.cl b/mace/ops/opencl/cl/reduce.cl new file mode 100644 index 0000000000000000000000000000000000000000..92afeb6e9e102cf67dfe18d715d104e8378845e8 --- /dev/null +++ b/mace/ops/opencl/cl/reduce.cl @@ -0,0 +1,104 @@ +#include + +__kernel void reduce(OUT_OF_RANGE_PARAMS + GLOBAL_WORK_GROUP_SIZE_DIM3 + __read_only image2d_t input, + __local float4 *group_sum, + __private const int group_size, + __private const int partial_len, + __private const int remain_index, + __private const int batch, + __private const int in_height, + __private const int in_width, + __private const float image_size_reciprocal, + __private const int channel_blocks, + __write_only image2d_t output) { + const int i = get_local_id(0); + const int j = get_local_id(1); + const int k = get_global_id(2); + +#ifndef NON_UNIFORM_WORK_GROUP + if (k >= global_size_dim2) + return; +#endif + const int dim0_size = get_local_size(0); + const int index = mad24(j, dim0_size, i); + const int b = k / channel_blocks; + const int ch = mad24(b, -channel_blocks, k); + + DATA_TYPE4 in; + +#if REDUCE_TYPE == 1 + float4 tmp = (float4){MAXFLOAT, MAXFLOAT, MAXFLOAT, MAXFLOAT}; +#elif REDUCE_TYPE == 2 + float4 tmp = (float4){-MAXFLOAT, -MAXFLOAT, -MAXFLOAT, -MAXFLOAT}; +#elif REDUCE_TYPE == 3 + float4 tmp = (float4){1, 1, 1, 1}; +#else + float4 tmp = (float4){0, 0, 0, 0}; +#endif + + const int valid_part_len = select(partial_len, + partial_len - 1, + remain_index > 0 && index >= remain_index); + const int full_offset = mul24(index, partial_len); + const int base_offset = select(full_offset, + full_offset - (index - remain_index), + valid_part_len < partial_len); +#pragma unroll + for (int l = 0; l < valid_part_len; ++l) { + int offset = base_offset + l; + int h_id = offset / in_width; + int w_id = mad24(h_id, -in_width, offset); + int pos_x = mad24(ch, in_width, w_id); + int pos_y = mad24(b, in_height, h_id); + in = READ_IMAGET(input, SAMPLER, (int2)(pos_x, pos_y)); +// MIN +#if REDUCE_TYPE == 1 + tmp = fmin(tmp, in); +// MAX 
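+// (REDUCE_TYPE is set at kernel build time via -DREDUCE_TYPE from the ReduceType
+//  enum in mace/ops/reduce.h: MEAN = 0, MIN = 1, MAX = 2, PROD = 3.)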
+#elif REDUCE_TYPE == 2 + tmp = fmax(tmp, in); +// PROD +#elif REDUCE_TYPE == 3 + tmp = tmp * in; +// MEAN +#else + tmp = tmp + in; +#endif + } + +#if REDUCE_TYPE == 0 + tmp = tmp * image_size_reciprocal; +#endif + group_sum[index] = tmp; + +#ifdef NON_QUALCOMM_ADRENO + barrier(CLK_LOCAL_MEM_FENCE); +#endif + + if (i == 0 && j == 0) { +#if REDUCE_TYPE == 1 + DATA_TYPE4 out = (DATA_TYPE4){MAXFLOAT, MAXFLOAT, MAXFLOAT, MAXFLOAT}; +#elif REDUCE_TYPE == 2 + DATA_TYPE4 out = (DATA_TYPE4){-MAXFLOAT, -MAXFLOAT, -MAXFLOAT, -MAXFLOAT}; +#elif REDUCE_TYPE == 3 + DATA_TYPE4 out = (DATA_TYPE4){1, 1, 1, 1}; +#else + DATA_TYPE4 out = (DATA_TYPE4){0, 0, 0, 0}; +#endif +#pragma unroll + for (int l = 0; l < group_size; ++l) { +#if REDUCE_TYPE == 1 + out = fmin(out, group_sum[l]); +#elif REDUCE_TYPE == 2 + out = fmax(out, group_sum[l]); +#elif REDUCE_TYPE == 3 + out = out * group_sum[l]; +#else + out = out + group_sum[l]; +#endif + } + WRITE_IMAGET(output, (int2)(ch, b), out); + } +} diff --git a/mace/ops/opencl/cl/reduce_mean.cl b/mace/ops/opencl/cl/reduce_mean.cl deleted file mode 100644 index c2810f4876e57f3fb82836c231ed4277ff186055..0000000000000000000000000000000000000000 --- a/mace/ops/opencl/cl/reduce_mean.cl +++ /dev/null @@ -1,62 +0,0 @@ -#include - -__kernel void reduce_mean(OUT_OF_RANGE_PARAMS - GLOBAL_WORK_GROUP_SIZE_DIM3 - __read_only image2d_t input, - __local float4 *group_sum, - __private const int group_size, - __private const int partial_len, - __private const int remain_index, - __private const int batch, - __private const int in_height, - __private const int in_width, - __private const float image_size_reciprocal, - __private const int channel_blocks, - __write_only image2d_t output) { - const int i = get_local_id(0); - const int j = get_local_id(1); - const int k = get_global_id(2); - -#ifndef NON_UNIFORM_WORK_GROUP - if (k >= global_size_dim2) - return; -#endif - const int dim0_size = get_local_size(0); - float4 tmp = (float4){0, 0, 0, 0}; - const int index = mad24(j, dim0_size, i); - const int b = k / channel_blocks; - const int ch = mad24(b, -channel_blocks, k); - - DATA_TYPE4 in; - const int valid_part_len = select(partial_len, - partial_len - 1, - remain_index > 0 && index >= remain_index); - const int full_offset = mul24(index, partial_len); - const int base_offset = select(full_offset, - full_offset - (index - remain_index), - valid_part_len < partial_len); -#pragma unroll - for (int l = 0; l < valid_part_len; ++l) { - int offset = base_offset + l; - int h_id = offset / in_width; - int w_id = mad24(h_id, -in_width, offset); - int pos_x = mad24(ch, in_width, w_id); - int pos_y = mad24(b, in_height, h_id); - in = READ_IMAGET(input, SAMPLER, (int2)(pos_x, pos_y)); - tmp = tmp + in; - } - group_sum[index] = tmp * image_size_reciprocal; - -#ifdef NON_QUALCOMM_ADRENO - barrier(CLK_LOCAL_MEM_FENCE); -#endif - - if (i == 0 && j == 0) { - DATA_TYPE4 out = (DATA_TYPE4){0, 0, 0, 0}; -#pragma unroll - for (int l = 0; l < group_size; ++l) { - out = out + group_sum[l]; - } - WRITE_IMAGET(output, (int2)(ch, b), out); - } -} diff --git a/mace/ops/opencl/image/activation.h b/mace/ops/opencl/image/activation.h index 80713c36977b495ae857f0af75c031a424c933ea..e8448fe0016d218f2e986fcad3627c6aef17c9b3 100644 --- a/mace/ops/opencl/image/activation.h +++ b/mace/ops/opencl/image/activation.h @@ -99,6 +99,10 @@ MaceStatus ActivationKernel::Compute( tuning_key_prefix_ = "sigmoid_opencl_kernel"; built_options.emplace("-DUSE_SIGMOID"); break; + case LEAKYRELU: + tuning_key_prefix_ = 
"leakyrelu_opencl_kernel"; + built_options.emplace("-DUSE_LEAKYRELU"); + break; default: LOG(FATAL) << "Unknown activation type: " << activation_; } diff --git a/mace/ops/opencl/image/pooling.h b/mace/ops/opencl/image/pooling.h index 1af677403bfa3160aedf8266bc24cf45baf04b37..1d9a4df647571e113fe043f2e7555c72084e08fd 100644 --- a/mace/ops/opencl/image/pooling.h +++ b/mace/ops/opencl/image/pooling.h @@ -69,6 +69,7 @@ class PoolingKernel : public OpenCLPoolingKernel { const Padding &padding_type, const std::vector &padding_data, const int *dilations, + const RoundType round_type, Tensor *output) override; private: @@ -87,6 +88,7 @@ MaceStatus PoolingKernel::Compute( const Padding &padding_type, const std::vector &padding_data, const int *dilations, + const RoundType round_type, Tensor *output) { MACE_CHECK(dilations[0] == 1 && dilations[1] == 1) << "Pooling opencl kernel not support dilation yet"; @@ -103,7 +105,7 @@ MaceStatus PoolingKernel::Compute( } else { paddings = padding_data; CalcOutputSize(input->shape().data(), filter_shape.data(), - padding_data.data(), dilations, strides, RoundType::CEIL, + padding_data.data(), dilations, strides, round_type, output_shape.data()); } diff --git a/mace/ops/opencl/image/reduce_mean.h b/mace/ops/opencl/image/reduce.h similarity index 86% rename from mace/ops/opencl/image/reduce_mean.h rename to mace/ops/opencl/image/reduce.h index 3280691c9d303d08048bcb23bce2ab040c72b9e7..98368f913048f193c98ec6ca4db0e7787f531e3b 100644 --- a/mace/ops/opencl/image/reduce_mean.h +++ b/mace/ops/opencl/image/reduce.h @@ -11,10 +11,10 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. -#ifndef MACE_OPS_OPENCL_IMAGE_REDUCE_MEAN_H_ -#define MACE_OPS_OPENCL_IMAGE_REDUCE_MEAN_H_ +#ifndef MACE_OPS_OPENCL_IMAGE_REDUCE_H_ +#define MACE_OPS_OPENCL_IMAGE_REDUCE_H_ -#include "mace/ops/opencl/reduce_mean.h" +#include "mace/ops/opencl/reduce.h" #include #include @@ -24,6 +24,7 @@ #include "mace/core/op_context.h" #include "mace/core/tensor.h" #include "mace/ops/opencl/helper.h" +#include "mace/ops/reduce.h" namespace mace { namespace ops { @@ -31,11 +32,12 @@ namespace opencl { namespace image { template -class ReduceMeanKernel : public OpenCLReduceMeanKernel { +class ReduceKernel : public OpenCLReduceKernel { public: - ReduceMeanKernel(const std::vector axis, - const bool keep_dims) - : axis_(axis), keep_dims_(keep_dims) {} + ReduceKernel(ReduceType type, + const std::vector axis, + const bool keep_dims) + : reduce_type_(type), axis_(axis), keep_dims_(keep_dims) {} MaceStatus Compute( OpContext *context, @@ -43,6 +45,7 @@ class ReduceMeanKernel : public OpenCLReduceMeanKernel { Tensor *output) override; private: + ReduceType reduce_type_; const std::vector axis_; bool keep_dims_; cl::Kernel kernel_; @@ -51,16 +54,16 @@ class ReduceMeanKernel : public OpenCLReduceMeanKernel { }; template -MaceStatus ReduceMeanKernel::Compute( +MaceStatus ReduceKernel::Compute( OpContext *context, const Tensor *input, Tensor *output) { MACE_CHECK_NOTNULL(input); MACE_CHECK(keep_dims_, "reduce mean gpu only support keep dims."); MACE_CHECK(input->dim_size() == 4, - "reduce mean gpu only support 4-dim input"); + "reduce gpu only support 4-dim input"); MACE_CHECK(axis_.size() == 2 && axis_[0] == 1 && axis_[1] == 2, - "reduce mean gpu only support 1,2-axis reduce"); + "reduce gpu only support 1,2-axis reduce"); index_t batch = input->dim(0); const index_t in_height = 
input->dim(1); const index_t in_width = input->dim(2); @@ -84,14 +87,15 @@ MaceStatus ReduceMeanKernel::Compute( std::set built_options; MACE_OUT_OF_RANGE_CONFIG; MACE_NON_UNIFORM_WG_CONFIG; - std::string kernel_name = MACE_OBFUSCATE_SYMBOL("reduce_mean"); - built_options.emplace("-Dreduce_mean=" + kernel_name); + std::string kernel_name = MACE_OBFUSCATE_SYMBOL("reduce"); + built_options.emplace("-Dreduce=" + kernel_name); built_options.emplace("-DDATA_TYPE=" + DtToUpCompatibleCLDt(dt)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpCompatibleCLCMDDt(dt)); + built_options.emplace(MakeString("-DREDUCE_TYPE=", reduce_type_)); if (runtime->gpu_type() != GPUType::QUALCOMM_ADRENO) { built_options.emplace("-DNON_QUALCOMM_ADRENO"); } - MACE_RETURN_IF_ERROR(runtime->BuildKernel("reduce_mean", + MACE_RETURN_IF_ERROR(runtime->BuildKernel("reduce", kernel_name, built_options, &kernel_)); @@ -170,4 +174,4 @@ MaceStatus ReduceMeanKernel::Compute( } // namespace ops } // namespace mace -#endif // MACE_OPS_OPENCL_IMAGE_REDUCE_MEAN_H_ +#endif // MACE_OPS_OPENCL_IMAGE_REDUCE_H_ diff --git a/mace/ops/opencl/pooling.h b/mace/ops/opencl/pooling.h index fc41a4746ae47c521694be237bb66cf11e15dd72..d24669a5215c7e37adeb68212c5ed0d91b036cd3 100644 --- a/mace/ops/opencl/pooling.h +++ b/mace/ops/opencl/pooling.h @@ -36,6 +36,7 @@ class OpenCLPoolingKernel { const Padding &padding_type, const std::vector &padding_data, const int *dilations, + const RoundType round_type, Tensor *output) = 0; MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLPoolingKernel); }; diff --git a/mace/ops/opencl/reduce_mean.h b/mace/ops/opencl/reduce.h similarity index 81% rename from mace/ops/opencl/reduce_mean.h rename to mace/ops/opencl/reduce.h index 9e279a2a12596a3b40998373f541cf8c91b463e9..32649ab3c790ea51b900eaa41c6613a6afc1b878 100644 --- a/mace/ops/opencl/reduce_mean.h +++ b/mace/ops/opencl/reduce.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#ifndef MACE_OPS_OPENCL_REDUCE_MEAN_H_ -#define MACE_OPS_OPENCL_REDUCE_MEAN_H_ +#ifndef MACE_OPS_OPENCL_REDUCE_H_ +#define MACE_OPS_OPENCL_REDUCE_H_ #include "mace/public/mace.h" #include "mace/utils/utils.h" @@ -24,16 +24,16 @@ class OpContext; class Tensor; namespace ops { -class OpenCLReduceMeanKernel { +class OpenCLReduceKernel { public: virtual MaceStatus Compute( OpContext *context, const Tensor *input, Tensor *output) = 0; - MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLReduceMeanKernel); + MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLReduceKernel); }; } // namespace ops } // namespace mace -#endif // MACE_OPS_OPENCL_REDUCE_MEAN_H_ +#endif // MACE_OPS_OPENCL_REDUCE_H_ diff --git a/mace/ops/ops_registry.cc b/mace/ops/ops_registry.cc index 7407683d6464ea2559eca1d55ee548bd4e3c75dc..cd958705a094794ce92d194c30d8a83da906a716 100644 --- a/mace/ops/ops_registry.cc +++ b/mace/ops/ops_registry.cc @@ -44,7 +44,7 @@ extern void RegisterLocalResponseNorm(OpRegistryBase *op_registry); extern void RegisterMatMul(OpRegistryBase *op_registry); extern void RegisterPad(OpRegistryBase *op_registry); extern void RegisterPooling(OpRegistryBase *op_registry); -extern void RegisterReduceMean(OpRegistryBase *op_registry); +extern void RegisterReduce(OpRegistryBase *op_registry); extern void RegisterReshape(OpRegistryBase *op_registry); extern void RegisterResizeBicubic(OpRegistryBase *op_registry); extern void RegisterResizeBilinear(OpRegistryBase *op_registry); @@ -102,7 +102,7 @@ OpRegistry::OpRegistry() : OpRegistryBase() { ops::RegisterMatMul(this); ops::RegisterPad(this); ops::RegisterPooling(this); - ops::RegisterReduceMean(this); + ops::RegisterReduce(this); ops::RegisterReshape(this); ops::RegisterResizeBicubic(this); ops::RegisterResizeBilinear(this); diff --git a/mace/ops/pooling.cc b/mace/ops/pooling.cc index 50372c3cf1f1603d80eec28bce0d701535b9467d..9228548718befc6720c87bf38d93a264d754ad21 100644 --- a/mace/ops/pooling.cc +++ b/mace/ops/pooling.cc @@ -43,11 +43,14 @@ class PoolingOpBase : public ConvPool2dOpBase { kernels_(Operation::GetRepeatedArgs("kernels")), pooling_type_( static_cast(Operation::GetOptionalArg( - "pooling_type", static_cast(AVG)))) {} + "pooling_type", static_cast(AVG)))), + round_type_(static_cast(Operation::GetOptionalArg( + "round_mode", static_cast(CEIL)))) {} protected: std::vector kernels_; PoolingType pooling_type_; + RoundType round_type_; MACE_OP_INPUT_TAGS(INPUT); MACE_OP_OUTPUT_TAGS(OUTPUT); @@ -82,7 +85,7 @@ class PoolingOp : public PoolingOpBase { paddings_.data(), dilations_.data(), strides_.data(), - RoundType::CEIL, + round_type_, output_shape.data()); } MACE_RETURN_IF_ERROR(output_tensor->Resize(output_shape)); @@ -255,7 +258,7 @@ class PoolingOp : public PoolingOpBase { paddings_.data(), dilations_.data(), strides_.data(), - RoundType::CEIL, + round_type_, output_shape.data()); } MACE_RETURN_IF_ERROR(output_tensor->Resize(output_shape)); @@ -442,7 +445,7 @@ class PoolingOp : public PoolingOpBase { return kernel_->Compute(context, input, pooling_type_, kernels_.data(), strides_.data(), padding_type_, paddings_, - dilations_.data(), output); + dilations_.data(), round_type_, output); } private: diff --git a/mace/ops/reduce.cc b/mace/ops/reduce.cc new file mode 100644 index 0000000000000000000000000000000000000000..c14cd48cdad15dd4bdd29a935aa071d5ba84a7e8 --- /dev/null +++ b/mace/ops/reduce.cc @@ -0,0 +1,574 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/reduce.h" + +#include +#include +#include + +#include "mace/core/future.h" +#include "mace/core/operator.h" +#include "mace/core/runtime/cpu/cpu_runtime.h" +#include "mace/core/tensor.h" +#ifdef MACE_ENABLE_OPENCL +#include "mace/ops/opencl/image/reduce.h" +#endif // MACE_ENABLE_OPENCL + +namespace mace { +namespace ops { + +class ReduceOpBase : public Operation { + public: + explicit ReduceOpBase(OpConstructContext *context) + : Operation(context), + reduce_type_( + static_cast(Operation::GetOptionalArg( + "reduce_type", static_cast(MEAN)))), + axis_(Operation::GetRepeatedArgs("axis")), + keep_dims_(Operation::GetOptionalArg("keepdims", false)) { + } + + protected: + inline void Validate() { + const Tensor *input = this->Input(0); + const int left = static_cast(input->dim_size() * -1); + const int right = static_cast(input->dim_size()); + if (axis_.size()) { + for (unsigned int i = 0; i < axis_.size(); ++i) { + MACE_CHECK(axis_[i] > left && axis_[i] < right, "Axis is over range."); + } + } + } + + protected: + ReduceType reduce_type_; + std::vector axis_; + bool keep_dims_; +}; + +template +class ReduceOp; + +template +class ReduceOp : public ReduceOpBase { + public: + explicit ReduceOp(OpConstructContext *context) + : ReduceOpBase(context) {} + + MaceStatus Run(OpContext *context) override { + MACE_UNUSED(context); + Validate(); + const Tensor *input = this->Input(0); + Tensor *output = this->Output(0); + Simplify(input); + output->Resize(out_shape_); + Compute(input, output); + return MaceStatus::MACE_SUCCESS; + } + + private: + void Simplify(const Tensor *input) { + std::vector bitmap(static_cast(input->dim_size()), false); + if (axis_.size() == 0) { + for (int i = 0; i < input->dim_size(); ++i) { + bitmap[i] = true; + } + } else { + for (unsigned int i = 0; i < axis_.size(); ++i) { + int index = axis_[i] >= 0 ? 
+ axis_[i] : + axis_[i] + input->dim_size(); + if (input->dim_size() == 4) { + if (index == 1 || index == 2) index = index + 1; + else if (index == 3) index = 1; + } + bitmap[index] = true; + } + } + out_shape_.clear(); + for (unsigned int i = 0; i < input->dim_size(); ++i) { + if (!bitmap[i]) { + out_shape_.push_back(input->dim(i)); + } else if (keep_dims_) { + out_shape_.push_back(1); + } + } + data_reshape_.clear(); + unsigned int dim_index = 0; + for (; dim_index < input->dim_size(); ++dim_index) { + if (input->dim(dim_index) != 1) break; + } + if (dim_index >= input->dim_size()) { + reduce_first_axis_ = true; + } else { + reduce_first_axis_ = bitmap[dim_index]; + data_reshape_.push_back(input->dim(dim_index)); + ++dim_index; + for (; dim_index < input->dim_size(); ++dim_index) { + const int n = input->dim(dim_index); + if (n == 1) { + bitmap[dim_index] = bitmap[dim_index - 1]; + } + if (bitmap[dim_index-1] != bitmap[dim_index]) { + data_reshape_.push_back(n); + } else { + data_reshape_.back() *= n; + } + } + } + } + + void compute_reduce_1(const T *input, ReduceType type, T *output) { + if (reduce_first_axis_) { + if (type == ReduceType::MEAN) { + T tmp = 0; + for (int i = 0; i < data_reshape_[0]; ++i) { + tmp = tmp + input[i]; + } + output[0] = tmp / data_reshape_[0]; + } else if (type == ReduceType::MIN) { + T tmp = input[0]; + for (int i = 1; i < data_reshape_[0]; ++i) { + tmp = std::min(tmp, input[i]); + } + output[0] = tmp; + } else if (type == ReduceType::MAX) { + T tmp = input[0]; + for (int i = 1; i < data_reshape_[0]; ++i) { + tmp = std::max(tmp, input[i]); + } + output[0] = tmp; + } else if (type == ReduceType::PROD) { + T tmp = input[0]; + for (int i = 1; i < data_reshape_[0]; ++i) { + tmp = tmp * input[i]; + } + output[0] = tmp; + } else { + MACE_NOT_IMPLEMENTED; + } + } else { + memcpy(output, input, data_reshape_[0] * sizeof(T)); + } + } + + void compute_reduce_2(const T *input, ReduceType type, T *output) { + if (reduce_first_axis_) { + if (type == ReduceType::MEAN) { +#pragma omp parallel for schedule(runtime) + for (int i = 0; i < data_reshape_[1]; ++i) { + T tmp = 0; + for (int j = 0; j < data_reshape_[0]; ++j) { + tmp += input[j * data_reshape_[1] + i]; + } + output[i] = tmp / data_reshape_[0]; + } + } else if (type == ReduceType::MIN) { +#pragma omp parallel for schedule(runtime) + for (int i = 0; i < data_reshape_[1]; ++i) { + T tmp = input[i]; + for (int j = 1; j < data_reshape_[0]; ++j) { + tmp = std::min(tmp, input[j * data_reshape_[1] + i]); + } + output[i] = tmp; + } + } else if (type == ReduceType::MAX) { +#pragma omp parallel for schedule(runtime) + for (int i = 0; i < data_reshape_[1]; ++i) { + T tmp = input[i]; + for (int j = 1; j < data_reshape_[0]; ++j) { + tmp = std::max(tmp, input[j * data_reshape_[1] + i]); + } + output[i] = tmp; + } + } else if (type == ReduceType::PROD) { +#pragma omp parallel for schedule(runtime) + for (int i = 0; i < data_reshape_[1]; ++i) { + T tmp = input[i]; + for (int j = 1; j < data_reshape_[0]; ++j) { + tmp = tmp * input[j * data_reshape_[1] + i]; + } + output[i] = tmp; + } + } else { + MACE_NOT_IMPLEMENTED; + } + } else { + if (type == ReduceType::MEAN) { +#pragma omp parallel for schedule(runtime) + for (int i = 0; i < data_reshape_[0]; ++i) { + T tmp = 0; + for (int j = 0; j < data_reshape_[1]; ++j) { + tmp += input[i * data_reshape_[1] + j]; + } + output[i] = tmp / data_reshape_[1]; + } + } else if (type == ReduceType::MIN) { +#pragma omp parallel for schedule(runtime) + for (int i = 0; i < data_reshape_[0]; ++i) { + T 
tmp = input[i * data_reshape_[1]]; + for (int j = 1; j < data_reshape_[1]; ++j) { + tmp = std::min(tmp, input[i * data_reshape_[1] + j]); + } + output[i] = tmp; + } + } else if (type == ReduceType::MAX) { +#pragma omp parallel for schedule(runtime) + for (int i = 0; i < data_reshape_[0]; ++i) { + T tmp = input[i * data_reshape_[1]]; + for (int j = 1; j < data_reshape_[1]; ++j) { + tmp = std::max(tmp, input[i * data_reshape_[1] + j]); + } + output[i] = tmp; + } + } else if (type == ReduceType::PROD) { +#pragma omp parallel for schedule(runtime) + for (int i = 0; i < data_reshape_[0]; ++i) { + T tmp = input[i * data_reshape_[1]]; + for (int j = 1; j < data_reshape_[1]; ++j) { + tmp = tmp * input[i * data_reshape_[1] + j]; + } + output[i] = tmp; + } + } else { + MACE_NOT_IMPLEMENTED; + } + } + } + + void compute_reduce_3(const T *input, ReduceType type, T *output) { + if (reduce_first_axis_) { + if (type == ReduceType::MEAN) { +#pragma omp parallel for collapse(1) schedule(runtime) + for (int i = 0; i < data_reshape_[1]; ++i) { + for (int j = 0; j < data_reshape_[2]; ++j) { + for (int k = 0; k < data_reshape_[0]; ++k) { + output[i] += + input[(k * data_reshape_[1] + i) * data_reshape_[2] + + j]; + } + } + output[i] /= (data_reshape_[0] * data_reshape_[2]); + } + } else if (type == ReduceType::MIN) { +#pragma omp parallel for collapse(1) schedule(runtime) + for (int i = 0; i < data_reshape_[1]; ++i) { + T tmp = input[i * data_reshape_[2]]; + for (int j = 0; j < data_reshape_[2]; ++j) { + for (int k = 0; k < data_reshape_[0]; ++k) { + tmp = std::min(tmp, + input[(k * data_reshape_[1] + i) * data_reshape_[2] + + j]); + } + } + output[i] = tmp; + } + } else if (type == ReduceType::MAX) { +#pragma omp parallel for collapse(1) schedule(runtime) + for (int i = 0; i < data_reshape_[1]; ++i) { + T tmp = input[i * data_reshape_[2]]; + for (int j = 0; j < data_reshape_[2]; ++j) { + for (int k = 0; k < data_reshape_[0]; ++k) { + tmp = + std::max(tmp, + input[(k * data_reshape_[1] + i) + * data_reshape_[2] + j]); + } + } + output[i] = tmp; + } + } else if (type == ReduceType::PROD) { +#pragma omp parallel for schedule(runtime) + for (int i = 0; i < data_reshape_[1]; ++i) { + T tmp = 1; + for (int j = 0; j < data_reshape_[2]; ++j) { + for (int k = 0; k < data_reshape_[0]; ++k) { + tmp *= + input[(k * data_reshape_[1] + i) * data_reshape_[2] + + j]; + } + } + output[i] = tmp; + } + } else { + MACE_NOT_IMPLEMENTED; + } + } else { + if (type == ReduceType::MEAN) { +#pragma omp parallel for collapse(2) schedule(runtime) + for (int i = 0; i < data_reshape_[0]; ++i) { + for (int j = 0; j < data_reshape_[2]; ++j) { + for (int k = 0; k < data_reshape_[1]; ++k) { + output[i * data_reshape_[2] + j] += + input[(i * data_reshape_[1] + k) * data_reshape_[2] + + j]; + } + output[i * data_reshape_[2] + j] /= data_reshape_[1]; + } + } + } else if (type == ReduceType::MIN) { +#pragma omp parallel for collapse(2) schedule(runtime) + for (int i = 0; i < data_reshape_[0]; ++i) { + for (int j = 0; j < data_reshape_[2]; ++j) { + T tmp = input[i * data_reshape_[1] * data_reshape_[2] + j]; + for (int k = 1; k < data_reshape_[1]; ++k) { + tmp = std::min(tmp, + input[(i * data_reshape_[1] + k) * + data_reshape_[2] + j]); + } + output[i * data_reshape_[2] + j] = tmp; + } + } + } else if (type == ReduceType::MAX) { +#pragma omp parallel for collapse(2) schedule(runtime) + for (int i = 0; i < data_reshape_[0]; ++i) { + for (int j = 0; j < data_reshape_[2]; ++j) { + T tmp = input[i * data_reshape_[1] * data_reshape_[2] + j]; + for (int 
k = 1; k < data_reshape_[1]; ++k) { + tmp = std::max(tmp, + input[(i * data_reshape_[1] + k) * + data_reshape_[2] + j]); + } + output[i * data_reshape_[2] + j] = tmp; + } + } + } else if (type == ReduceType::PROD) { +#pragma omp parallel for schedule(runtime) + for (int i = 0; i < data_reshape_[0]; ++i) { + for (int j = 0; j < data_reshape_[2]; ++j) { + T tmp = input[i * data_reshape_[1] * data_reshape_[2] + j]; + for (int k = 1; k < data_reshape_[1]; ++k) { + tmp *= input[(i * data_reshape_[1] + k) * + data_reshape_[2] + j]; + } + output[i * data_reshape_[2] + j] = tmp; + } + } + } else { + MACE_NOT_IMPLEMENTED; + } + } + } + + void compute_reduce_4(const T *input, ReduceType type, T *output) { + if (reduce_first_axis_) { + if (type == ReduceType::MEAN) { +#pragma omp parallel for collapse(2) schedule(runtime) + for (int i = 0; i < data_reshape_[1]; ++i) { + for (int j = 0; j < data_reshape_[3]; ++j) { + for (int k = 0; k < data_reshape_[2]; ++k) { + for (int t = 0; t < data_reshape_[0]; ++t) { + output[i * data_reshape_[3] + j] += + input[((t * data_reshape_[1] + i) * + data_reshape_[2] + k)*data_reshape_[3] + j]; + } + } + output[i * data_reshape_[3] + j] /= + (data_reshape_[0] * data_reshape_[2]); + } + } + } else if (type == ReduceType::MIN) { +#pragma omp parallel for collapse(2) schedule(runtime) + for (int i = 0; i < data_reshape_[1]; ++i) { + for (int j = 0; j < data_reshape_[3]; ++j) { + T tmp = input[i * data_reshape_[2] * data_reshape_[3] + j]; + for (int k = 0; k < data_reshape_[2]; ++k) { + for (int t = 0; t < data_reshape_[0]; ++t) { + tmp = std::min(tmp, + input[((t * data_reshape_[1] + i) * + data_reshape_[2] + k)*data_reshape_[3] + j]); + } + } + output[i * data_reshape_[3] + j] = tmp; + } + } + } else if (type == ReduceType::MAX) { +#pragma omp parallel for collapse(2) schedule(runtime) + for (int i = 0; i < data_reshape_[1]; ++i) { + for (int j = 0; j < data_reshape_[3]; ++j) { + T tmp = input[i * data_reshape_[2] * data_reshape_[3] + j]; + for (int k = 0; k < data_reshape_[2]; ++k) { + for (int t = 0; t < data_reshape_[0]; ++t) { + tmp = std::max(tmp, + input[((t * data_reshape_[1] + i) * + data_reshape_[2] + k)*data_reshape_[3] + j]); + } + } + output[i * data_reshape_[3] + j] = tmp; + } + } + } else if (type == ReduceType::PROD) { +#pragma omp parallel for collapse(2) schedule(runtime) + for (int i = 0; i < data_reshape_[1]; ++i) { + for (int j = 0; j < data_reshape_[3]; ++j) { + T tmp = 1; + for (int k = 0; k < data_reshape_[2]; ++k) { + for (int t = 0; t < data_reshape_[0]; ++t) { + tmp = tmp * input[((t * data_reshape_[1] + i) * + data_reshape_[2] + k)*data_reshape_[3] + j]; + } + } + output[i * data_reshape_[3] + j] = tmp; + } + } + } else { + MACE_NOT_IMPLEMENTED; + } + } else { + if (type == ReduceType::MEAN) { +#pragma omp parallel for collapse(2) schedule(runtime) + for (int i = 0; i < data_reshape_[0]; ++i) { + for (int j = 0; j < data_reshape_[2]; ++j) { + for (int k = 0; k < data_reshape_[1]; ++k) { + for (int t = 0; t < data_reshape_[3]; ++t) { + output[i * data_reshape_[2] + j] += + input[((i * data_reshape_[1] + k) * + data_reshape_[2] + j)*data_reshape_[3] + t]; + } + } + output[i * data_reshape_[2] + j] /= + (data_reshape_[1] * data_reshape_[3]); + } + } + } else if (type == ReduceType::MIN) { +#pragma omp parallel for collapse(2) schedule(runtime) + for (int i = 0; i < data_reshape_[0]; ++i) { + for (int j = 0; j < data_reshape_[2]; ++j) { + T tmp = input[(i * data_reshape_[1] * + data_reshape_[2] + j)*data_reshape_[3]]; + for (int k = 0; k < 
data_reshape_[1]; ++k) { + for (int t = 0; t < data_reshape_[3]; ++t) { + tmp = + std::min(tmp, + input[((i * data_reshape_[1] + k) * + data_reshape_[2] + j)*data_reshape_[3] + t]); + } + } + output[i * data_reshape_[2] + j] = tmp; + } + } + } else if (type == ReduceType::MAX) { +#pragma omp parallel for collapse(2) schedule(runtime) + for (int i = 0; i < data_reshape_[0]; ++i) { + for (int j = 0; j < data_reshape_[2]; ++j) { + T tmp = input[(i * data_reshape_[1] * + data_reshape_[2] + j)*data_reshape_[3]]; + for (int k = 0; k < data_reshape_[1]; ++k) { + for (int t = 0; t < data_reshape_[3]; ++t) { + tmp = + std::max(tmp, + input[((i * data_reshape_[1] + k) * + data_reshape_[2] + j)*data_reshape_[3] + t]); + } + } + output[i * data_reshape_[2] + j] = tmp; + } + } + } else if (type == ReduceType::PROD) { +#pragma omp parallel for schedule(runtime) + for (int i = 0; i < data_reshape_[0]; ++i) { + for (int j = 0; j < data_reshape_[2]; ++j) { + T tmp = 1; + for (int k = 0; k < data_reshape_[1]; ++k) { + for (int t = 0; t < data_reshape_[3]; ++t) { + tmp = tmp * input[((i * data_reshape_[1] + k) * + data_reshape_[2] + j)*data_reshape_[3] + t]; + } + } + output[i * data_reshape_[2] + j] = tmp; + } + } + } else { + MACE_NOT_IMPLEMENTED; + } + } + } + + + void Compute(const Tensor *input, Tensor *output) { + Tensor::MappingGuard input_mapper(input); + const T *input_ptr = input->data(); + Tensor::MappingGuard output_map(output); + T *output_ptr = output->mutable_data(); + memset(output_ptr, 0, output->size() * sizeof(T)); + switch (data_reshape_.size()) { + case 1: + compute_reduce_1(input_ptr, reduce_type_, output_ptr); + break; + case 2: + compute_reduce_2(input_ptr, reduce_type_, output_ptr); + break; + case 3: + compute_reduce_3(input_ptr, reduce_type_, output_ptr); + break; + case 4: + compute_reduce_4(input_ptr, reduce_type_, output_ptr); + break; + default: + MACE_CHECK(false, "not implemented in mace") + << "data reshape size" << data_reshape_.size() + << "reduce first axis:" << reduce_first_axis_; + break; + } + } + + private: + bool reduce_first_axis_; + std::vector data_reshape_; + std::vector out_shape_; +}; + +#ifdef MACE_ENABLE_OPENCL +template +class ReduceOp : public ReduceOpBase { + public: + explicit ReduceOp(OpConstructContext *context) + : ReduceOpBase(context) { + if (context->device()->gpu_runtime()->UseImageMemory()) { + kernel_.reset(new opencl::image::ReduceKernel(reduce_type_, + axis_, + keep_dims_)); + } else { + MACE_NOT_IMPLEMENTED; + } + } + MaceStatus Run(OpContext *context) override { + Validate(); + const Tensor *input = this->Input(0); + Tensor *output = this->Output(0); + + return kernel_->Compute(context, input, output); + } + + private: + std::unique_ptr kernel_; +}; +#endif // MACE_ENABLE_OPENCL + +void RegisterReduce(OpRegistryBase *op_registry) { + MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp, + DeviceType::CPU, float); + +#ifdef MACE_ENABLE_OPENCL + MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp, + DeviceType::GPU, float); + + MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp, + DeviceType::GPU, half); +#endif // MACE_ENABLE_OPENCL +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/reduce.h b/mace/ops/reduce.h new file mode 100644 index 0000000000000000000000000000000000000000..460aae809d4532f1240c97d5874cc646f852e3e7 --- /dev/null +++ b/mace/ops/reduce.h @@ -0,0 +1,31 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MACE_OPS_REDUCE_H_ +#define MACE_OPS_REDUCE_H_ + + +namespace mace { +enum ReduceType { +// SUM = 0, + MEAN = 0, + MIN = 1, + MAX = 2, + PROD = 3, +// SUM_SQR = 4, +// SQR_MEAN = 5, +}; +} // namespace mace + +#endif // MACE_OPS_REDUCE_H_ diff --git a/mace/ops/reduce_mean_benchmark.cc b/mace/ops/reduce_benchmark.cc similarity index 71% rename from mace/ops/reduce_mean_benchmark.cc rename to mace/ops/reduce_benchmark.cc index 60a255009c3b614c90aeb2607dc3c5e78ef2472e..ec8807b0488892de1ac22eb5136dc5b524482c27 100644 --- a/mace/ops/reduce_mean_benchmark.cc +++ b/mace/ops/reduce_benchmark.cc @@ -21,7 +21,7 @@ namespace test { namespace { template -void ReduceMean(int iters, int batch, int channels, +void Reduce(int iters, int batch, int channels, int height, int width) { mace::testing::StopTiming(); @@ -34,7 +34,7 @@ void ReduceMean(int iters, int batch, int channels, net.AddRandomInput("Input", {batch, channels, height, width}); } - OpDefBuilder("ReduceMean", "ReduceMeanBM") + OpDefBuilder("Reduce", "ReduceBM") .Input("Input") .AddIntsArg("axis", axis) .Output("OutputImage") @@ -55,30 +55,30 @@ void ReduceMean(int iters, int batch, int channels, } } // namespace -#define MACE_BM_REDUCE_MEAN_MACRO(N, C, H, W, TYPE, DEVICE) \ +#define MACE_BM_REDUCE_MACRO(N, C, H, W, TYPE, DEVICE) \ static void \ - MACE_BM_REDUCE_MEAN_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE(\ + MACE_BM_REDUCE_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE(\ int iters) { \ const int64_t tot = static_cast(iters) * N * C * H * W; \ mace::testing::MaccProcessed(tot); \ mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ - ReduceMean(iters, N, C, H, W); \ + Reduce(iters, N, C, H, W); \ } \ MACE_BENCHMARK( \ - MACE_BM_REDUCE_MEAN_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) + MACE_BM_REDUCE_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE) -#define MACE_BM_REDUCE_MEAN(N, C, H, W) \ - MACE_BM_REDUCE_MEAN_MACRO(N, C, H, W, float, GPU); \ - MACE_BM_REDUCE_MEAN_MACRO(N, C, H, W, half, GPU); \ - MACE_BM_REDUCE_MEAN_MACRO(N, C, H, W, float, CPU); +#define MACE_BM_REDUCE(N, C, H, W) \ + MACE_BM_REDUCE_MACRO(N, C, H, W, float, GPU); \ + MACE_BM_REDUCE_MACRO(N, C, H, W, half, GPU); \ + MACE_BM_REDUCE_MACRO(N, C, H, W, float, CPU); -MACE_BM_REDUCE_MEAN(1, 1, 512, 512); -MACE_BM_REDUCE_MEAN(4, 3, 128, 128); -MACE_BM_REDUCE_MEAN(4, 1, 512, 512); -MACE_BM_REDUCE_MEAN(16, 32, 112, 112); -MACE_BM_REDUCE_MEAN(8, 64, 256, 256); -MACE_BM_REDUCE_MEAN(1, 32, 480, 640); +MACE_BM_REDUCE(1, 1, 512, 512); +MACE_BM_REDUCE(4, 3, 128, 128); +MACE_BM_REDUCE(4, 1, 512, 512); +MACE_BM_REDUCE(16, 32, 112, 112); +MACE_BM_REDUCE(8, 64, 256, 256); +MACE_BM_REDUCE(1, 32, 480, 640); } // namespace test diff --git a/mace/ops/reduce_mean.cc b/mace/ops/reduce_mean.cc deleted file mode 100644 index 863103b28fc607aa4003840ee72aefa88b917312..0000000000000000000000000000000000000000 --- a/mace/ops/reduce_mean.cc +++ /dev/null @@ -1,282 +0,0 @@ -// Copyright 2018 Xiaomi, Inc. All rights reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include - -#include "mace/core/future.h" -#include "mace/core/operator.h" -#include "mace/core/tensor.h" -#ifdef MACE_ENABLE_OPENCL -#include "mace/ops/opencl/image/reduce_mean.h" -#endif // MACE_ENABLE_OPENCL - -namespace mace { -namespace ops { - -class ReduceMeanOpBase : public Operation { - public: - explicit ReduceMeanOpBase(OpConstructContext *context) - : Operation(context), - axis_(Operation::GetRepeatedArgs("axis")), - keep_dims_(Operation::GetOptionalArg("keepdims", false)) { - } - - protected: - inline void Validate() { - const Tensor *input = this->Input(0); - const int left = static_cast(input->dim_size() * -1); - const int right = static_cast(input->dim_size()); - if (axis_.size()) { - for (unsigned int i = 0; i < axis_.size(); ++i) { - MACE_CHECK(axis_[i] > left && axis_[i] < right, "Axis is over range."); - } - } - } - - protected: - std::vector axis_; - bool keep_dims_; -}; - -template -class ReduceMeanOp; - -template -class ReduceMeanOp : public ReduceMeanOpBase { - public: - explicit ReduceMeanOp(OpConstructContext *context) - : ReduceMeanOpBase(context) { - } - - MaceStatus Run(OpContext *context) override { - MACE_UNUSED(context); - Validate(); - const Tensor *input = this->Input(0); - Tensor *output = this->Output(0); - Simplify(input); - output->Resize(out_shape_); - Compute(input, output); - return MaceStatus::MACE_SUCCESS; - } - - private: - void Simplify(const Tensor *input) { - std::vector bitmap(static_cast(input->dim_size()), false); - if (axis_.size() == 0) { - for (int i = 0; i < input->dim_size(); ++i) { - bitmap[i] = true; - } - } else { - for (unsigned int i = 0; i < axis_.size(); ++i) { - int index = axis_[i] >= 0 ? 
- axis_[i] : - axis_[i] + input->dim_size(); - // axis format is NHWC - if (input->dim_size() == 4) { - if (index == 1) index = 2; - else if (index == 2) index = 3; - else if (index == 3) index = 1; - } - bitmap[index] = true; - } - } - out_shape_.clear(); - for (unsigned int i = 0; i < input->dim_size(); ++i) { - if (!bitmap[i]) { - out_shape_.push_back(input->dim(i)); - } else if (keep_dims_) { - out_shape_.push_back(1); - } - } - data_reshape_.clear(); - unsigned int dim_index = 0; - for (; dim_index < input->dim_size(); ++dim_index) { - if (input->dim(dim_index) != 1) break; - } - if (dim_index >= input->dim_size()) { - reduce_first_axis_ = true; - } else { - reduce_first_axis_ = bitmap[dim_index]; - data_reshape_.push_back(input->dim(dim_index)); - ++dim_index; - for (; dim_index < input->dim_size(); ++dim_index) { - const int n = input->dim(dim_index); - if (n == 1) { - bitmap[dim_index] = bitmap[dim_index - 1]; - } - if (bitmap[dim_index-1] != bitmap[dim_index]) { - data_reshape_.push_back(n); - } else { - data_reshape_.back() *= n; - } - } - } - } - - void Compute(const Tensor *input, Tensor *output) { - Tensor::MappingGuard input_mapper(input); - const T *input_ptr = input->data(); - Tensor::MappingGuard output_map(output); - T *output_ptr = output->mutable_data(); - memset(output_ptr, 0, output->size() * sizeof(T)); - switch (data_reshape_.size()) { - case 1: - if (reduce_first_axis_) { - T sum = 0; - for (int i = 0; i < data_reshape_[0]; ++i) { - sum = sum + input_ptr[i]; - } - output_ptr[0] = sum / data_reshape_[0]; - } else { -#pragma omp parallel for schedule(runtime) - for (int i = 0; i < data_reshape_[0]; ++i) { - output_ptr[i] = input_ptr[i]; - } - } - break; - case 2: - if (reduce_first_axis_) { -#pragma omp parallel for schedule(runtime) - for (int i = 0; i < data_reshape_[1]; ++i) { - for (int j = 0; j < data_reshape_[0]; ++j) { - output_ptr[i] += input_ptr[j * data_reshape_[1] + i]; - } - output_ptr[i] /= data_reshape_[0]; - } - } else { -#pragma omp parallel for schedule(runtime) - for (int i = 0; i < data_reshape_[0]; ++i) { - for (int j = 0; j < data_reshape_[1]; ++j) { - output_ptr[i] += input_ptr[i * data_reshape_[1] + j]; - } - output_ptr[i] /= data_reshape_[1]; - } - } - break; - case 3: - if (reduce_first_axis_) { -#pragma omp parallel for schedule(runtime) - for (int i = 0; i < data_reshape_[1]; ++i) { - for (int j = 0; j < data_reshape_[2]; ++j) { - for (int k = 0; k < data_reshape_[0]; ++k) { - output_ptr[i] += - input_ptr[(k * data_reshape_[1] + i) * data_reshape_[2] - + j]; - } - } - output_ptr[i] /= (data_reshape_[0] * data_reshape_[2]); - } - } else { -#pragma omp parallel for collapse(2) schedule(runtime) - for (int i = 0; i < data_reshape_[0]; ++i) { - for (int j = 0; j < data_reshape_[2]; ++j) { - for (int k = 0; k < data_reshape_[1]; ++k) { - output_ptr[i * data_reshape_[2] + j] += - input_ptr[(i * data_reshape_[1] + k) * data_reshape_[2] - + j]; - } - output_ptr[i * data_reshape_[2] + j] /= data_reshape_[1]; - } - } - } - break; - case 4: - if (reduce_first_axis_) { -#pragma omp parallel for collapse(2) schedule(runtime) - for (int i = 0; i < data_reshape_[1]; ++i) { - for (int j = 0; j < data_reshape_[3]; ++j) { - for (int k = 0; k < data_reshape_[2]; ++k) { - for (int t = 0; t < data_reshape_[0]; ++t) { - output_ptr[i * data_reshape_[3] + j] += - input_ptr[((t * data_reshape_[1] + i) * - data_reshape_[2] + k)*data_reshape_[3] + j]; - } - } - output_ptr[i * data_reshape_[3] + j] /= - (data_reshape_[0] * data_reshape_[2]); - } - } - } else { 
-#pragma omp parallel for collapse(2) schedule(runtime) - for (int i = 0; i < data_reshape_[0]; ++i) { - for (int j = 0; j < data_reshape_[2]; ++j) { - for (int k = 0; k < data_reshape_[1]; ++k) { - for (int t = 0; t < data_reshape_[3]; ++t) { - output_ptr[i * data_reshape_[2] + j] += - input_ptr[((i * data_reshape_[1] + k) * - data_reshape_[2] + j)*data_reshape_[3] + t]; - } - } - output_ptr[i * data_reshape_[2] + j] /= - (data_reshape_[1] * data_reshape_[3]); - } - } - } - break; - default: - MACE_CHECK(false, "not implemented in mace") - << "data reshape size" << data_reshape_.size() - << "reduce first axis:" << reduce_first_axis_; - break; - } - } - - private: - bool reduce_first_axis_; - std::vector data_reshape_; - std::vector out_shape_; -}; - -#ifdef MACE_ENABLE_OPENCL -template -class ReduceMeanOp : public ReduceMeanOpBase { - public: - explicit ReduceMeanOp(OpConstructContext *context) - : ReduceMeanOpBase(context) { - if (context->device()->gpu_runtime()->UseImageMemory()) { - kernel_.reset(new opencl::image::ReduceMeanKernel(axis_, keep_dims_)); - } else { - MACE_NOT_IMPLEMENTED; - } - } - MaceStatus Run(OpContext *context) override { - Validate(); - const Tensor *input = this->Input(0); - Tensor *output = this->Output(0); - - return kernel_->Compute(context, input, output); - } - - private: - std::unique_ptr kernel_; -}; -#endif // MACE_ENABLE_OPENCL - -void RegisterReduceMean(OpRegistryBase *op_registry) { - MACE_REGISTER_OP(op_registry, "ReduceMean", ReduceMeanOp, - DeviceType::CPU, float); - -#ifdef MACE_ENABLE_OPENCL - MACE_REGISTER_OP(op_registry, "ReduceMean", ReduceMeanOp, - DeviceType::GPU, float); - - MACE_REGISTER_OP(op_registry, "ReduceMean", ReduceMeanOp, - DeviceType::GPU, half); -#endif // MACE_ENABLE_OPENCL -} - -} // namespace ops -} // namespace mace diff --git a/mace/ops/reduce_mean_test.cc b/mace/ops/reduce_mean_test.cc deleted file mode 100644 index ef455f85a4cf0961fb24975b47fe88640d2e7150..0000000000000000000000000000000000000000 --- a/mace/ops/reduce_mean_test.cc +++ /dev/null @@ -1,425 +0,0 @@ -// Copyright 2018 Xiaomi, Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "mace/ops/ops_test_util.h" - -namespace mace { -namespace ops { -namespace test { - -class ReduceMeanOpTest : public OpsTestBase {}; - -namespace { -template -void Simple(const std::vector &input_shape, - const std::vector &input, - const std::vector &axis, - const std::vector &output_shape, - const std::vector &output, - const bool keepdims = true) { - // Construct graph - OpsTestNet net; - // Add input data - net.AddInputFromArray("Input", input_shape, input); - - if (D == DeviceType::CPU) { - net.TransformDataFormat("Input", NHWC, "InputNCHW", NCHW); - OpDefBuilder("ReduceMean", "ReduceMeanTest") - .Input("InputNCHW") - .AddIntsArg("axis", axis) - .AddIntArg("keepdims", keepdims ? 
1 : 0) - .Output("OutputNCHW") - .Finalize(net.NewOperatorDef()); - // Run - net.RunOp(D); - net.TransformDataFormat("OutputNCHW", NCHW, "Output", NHWC); - } else { - OpDefBuilder("ReduceMean", "ReduceMeanTest") - .Input("Input") - .AddIntsArg("axis", axis) - .AddIntArg("keepdims", keepdims ? 1 : 0) - .Output("Output") - .Finalize(net.NewOperatorDef()); - // Run - net.RunOp(D); - } - auto expected = net.CreateTensor(output_shape, output); - ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5, 1e-3); -} - -template -void Simple3D(const std::vector &input_shape, - const std::vector &input, - const std::vector &axis, - const std::vector &output_shape, - const std::vector &output, - const bool keepdims = true) { - // Construct graph - OpsTestNet net; - // Add input data - net.AddInputFromArray("Input", input_shape, input); - - OpDefBuilder("ReduceMean", "ReduceMeanTest") - .Input("Input") - .AddIntsArg("axis", axis) - .AddIntArg("keepdims", keepdims ? 1 : 0) - .Output("Output") - .Finalize(net.NewOperatorDef()); - // Run - net.RunOp(D); - auto expected = net.CreateTensor(output_shape, output); - ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5, 1e-3); -} - -template -void Simple12Test() { - Simple({2, 2, 3, 4}, - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, - {1, 2}, - {2, 1, 1, 4}, - {10, 11, 12, 13, - 10, 11, 12, 13}); -} - -template -void Simple1Axis() { - Simple({2, 2, 3, 4}, - {0, 1, 2, 3, - 4, 5, 6, 7, - 8, 9, 10, 11, - 12, 13, 14, 15, - 16, 17, 18, 19, - 20, 21, 22, 23, - 0, 1, 2, 3, - 4, 5, 6, 7, - 8, 9, 10, 11, - 12, 13, 14, 15, - 16, 17, 18, 19, - 20, 21, 22, 23}, - {1}, - {2, 1, 3, 4}, - {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, - 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}); - Simple({1, 2, 3, 4}, - {0, 1, 2, 3, - 4, 5, 6, 7, - 8, 9, 10, 11, - 12, 13, 14, 15, - 16, 17, 18, 19, - 20, 21, 22, 23}, - {-3}, - {1, 1, 3, 4}, - {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}); - Simple({1, 2, 3, 4}, - {0, 1, 2, 3, - 4, 5, 6, 7, - 8, 9, 10, 11, - 12, 13, 14, 15, - 16, 17, 18, 19, - 20, 21, 22, 23}, - {2}, - {1, 2, 1, 4}, - {4, 5, 6, 7, 16, 17, 18, 19}); - Simple({1, 2, 3, 4}, - {0, 1, 2, 3, - 4, 5, 6, 7, - 8, 9, 10, 11, - 12, 13, 14, 15, - 16, 17, 18, 19, - 20, 21, 22, 23}, - {-1}, - {1, 2, 3, 1}, - {1.5, 5.5, 9.5, 13.5, 17.5, 21.5}); - Simple({1, 3, 3, 3}, - {0, 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16, 17, - 18, 19, 20, 21, 22, 23, 24, 25, 26}, - {1}, - {1, 1, 3, 3}, - {9, 10, 11, 12, 13, 14, 15, 16, 17}); - Simple({1, 3, 3, 3}, - {0, 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16, 17, - 18, 19, 20, 21, 22, 23, 24, 25, 26}, - {-2}, - {1, 3, 1, 3}, - {3, 4, 5, 12, 13, 14, 21, 22, 23}); - Simple({1, 3, 3, 3}, - {0, 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16, 17, - 18, 19, 20, 21, 22, 23, 24, 25, 26}, - {3}, - {1, 3, 3, 1}, - {1, 4, 7, 10, 13, 16, 19, 22, 25}); -} - -template -void Simple2Axis() { - Simple({1, 2, 3, 4}, - {0, 1, 2, 3, - 4, 5, 6, 7, - 8, 9, 10, 11, - 12, 13, 14, 15, - 16, 17, 18, 19, - 20, 21, 22, 23}, - {0, 1}, - {1, 1, 3, 4}, - {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}); - Simple({1, 2, 3, 4}, - {0, 1, 2, 3, - 4, 5, 6, 7, - 8, 9, 10, 11, - 12, 13, 14, 15, - 16, 17, 18, 19, - 20, 21, 22, 23}, - {0, 2}, - {1, 2, 1, 4}, - {4, 5, 6, 7, 16, 17, 18, 19}); - Simple({1, 2, 3, 4}, - {0, 1, 2, 3, - 4, 5, 6, 7, - 8, 9, 10, 11, - 12, 13, 14, 15, - 16, 17, 18, 19, - 20, 21, 22, 23}, - {1, 3}, - 
{1, 1, 3, 1}, - {7.5, 11.5, 15.5}); - Simple({1, 3, 3, 3}, - {0, 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16, 17, - 18, 19, 20, 21, 22, 23, 24, 25, 26}, - {1, 2}, - {1, 1, 1, 3}, - {12, 13, 14}); - Simple({1, 3, 3, 3}, - {0, 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16, 17, - 18, 19, 20, 21, 22, 23, 24, 25, 26}, - {0, 1}, - {1, 1, 3, 3}, - {9, 10, 11, 12, 13, 14, 15, 16, 17}); - Simple({1, 3, 3, 3}, - {0, 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16, 17, - 18, 19, 20, 21, 22, 23, 24, 25, 26}, - {2, 3}, - {1, 3, 1, 1}, - {4, 13, 22}); -} - -template -void Simple2Axis3D() { - Simple3D({2, 3, 4}, - {0, 1, 2, 3, - 4, 5, 6, 7, - 8, 9, 10, 11, - 12, 13, 14, 15, - 16, 17, 18, 19, - 20, 21, 22, 23}, - {0, 1}, - {1, 1, 4}, - {10, 11, 12, 13}); - Simple3D({2, 3, 4}, - {0, 1, 2, 3, - 4, 5, 6, 7, - 8, 9, 10, 11, - 12, 13, 14, 15, - 16, 17, 18, 19, - 20, 21, 22, 23}, - {1, 2}, - {2, 1, 1}, - {5.5, 17.5}); -} - - -template -void Simple3Axis() { - Simple({1, 2, 3, 4}, - {0, 1, 2, 3, - 4, 5, 6, 7, - 8, 9, 10, 11, - 12, 13, 14, 15, - 16, 17, 18, 19, - 20, 21, 22, 23}, - {1, 2, 3}, - {1, 1, 1, 1}, - {11.5}); - Simple({1, 2, 3, 4}, - {0, 1, 2, 3, - 4, 5, 6, 7, - 8, 9, 10, 11, - 12, 13, 14, 15, - 16, 17, 18, 19, - 20, 21, 22, 23}, - {0, 2, 3}, - {1, 2, 1, 1}, - {5.5, 17.5}); - Simple({1, 2, 3, 4}, - {0, 1, 2, 3, - 4, 5, 6, 7, - 8, 9, 10, 11, - 12, 13, 14, 15, - 16, 17, 18, 19, - 20, 21, 22, 23}, - {0, 1, 3}, - {1, 1, 3, 1}, - {7.5, 11.5, 15.5}); - Simple({1, 2, 3, 4}, - {0, 1, 2, 3, - 4, 5, 6, 7, - 8, 9, 10, 11, - 12, 13, 14, 15, - 16, 17, 18, 19, - 20, 21, 22, 23}, - {0, 1, 2}, - {1, 1, 1, 4}, - {10, 11, 12, 13}); - Simple({1, 3, 3, 3}, - {0, 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16, 17, - 18, 19, 20, 21, 22, 23, 24, 25, 26}, - {1, 2, 3}, - {1, 1, 1, 1}, - {13}); - Simple({1, 3, 3, 3}, - {0, 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16, 17, - 18, 19, 20, 21, 22, 23, 24, 25, 26}, - {0, 2, 3}, - {1, 3, 1, 1}, - {4, 13, 22}); - Simple({1, 3, 3, 3}, - {0, 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16, 17, - 18, 19, 20, 21, 22, 23, 24, 25, 26}, - {0, 1, 3}, - {1, 1, 3, 1}, - {10, 13, 16}); - Simple({1, 3, 3, 3}, - {0, 1, 2, 3, 4, 5, 6, 7, 8, - 9, 10, 11, 12, 13, 14, 15, 16, 17, - 18, 19, 20, 21, 22, 23, 24, 25, 26}, - {0, 1, 2}, - {1, 1, 1, 3}, - {12, 13, 14}); -} - -} // namespace - -TEST_F(ReduceMeanOpTest, CPUSimple12) { - Simple12Test(); -} - -TEST_F(ReduceMeanOpTest, GPUSimple12) { - Simple12Test(); -} - -TEST_F(ReduceMeanOpTest, CPUSimple1Axis) { - Simple1Axis(); -} - -TEST_F(ReduceMeanOpTest, CPUSimple2Axis) { - Simple2Axis(); -} - -TEST_F(ReduceMeanOpTest, CPUSimple2Axis3D) { - Simple2Axis3D(); -} - -TEST_F(ReduceMeanOpTest, CPUSimple3Axis) { - Simple3Axis(); -} - -TEST_F(ReduceMeanOpTest, CPUSimpleReduceDims) { - Simple3D({2, 3, 4}, - {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, - 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, - {0, 1}, - {4}, - {10, 11, 12, 13}, - false); -} - -namespace { -template -void RandomTest(const std::vector &input_shape, - const std::vector &axis) { - testing::internal::LogToStderr(); - srand(time(NULL)); - // Construct graph - OpsTestNet net; - // Add input data - net.AddRandomInput("Input", input_shape); - - net.TransformDataFormat("Input", NHWC, "InputNCHW", - NCHW); - OpDefBuilder("ReduceMean", "ReduceMeanTest") - .Input("InputNCHW") - .AddIntsArg("axis", axis) - .AddIntArg("keepdims", 1) - .Output("OutputNCHW") - .Finalize(net.NewOperatorDef()); - // Run - net.RunOp(); - 
net.TransformDataFormat("OutputNCHW", NCHW, - "Output", NHWC); - OpDefBuilder("ReduceMean", "ReduceMeanTest") - .Input("Input") - .AddIntsArg("axis", axis) - .AddIntArg("keepdims", 1) - .Output("OPENCLOutput") - .Finalize(net.NewOperatorDef()); - // Run - net.RunOp(D); - if (DataTypeToEnum::value == DT_FLOAT) { - ExpectTensorNear(*net.GetTensor("Output"), - *net.GetOutput("OPENCLOutput"), 1e-5, 1e-4); - } else { - ExpectTensorNear(*net.GetTensor("Output"), - *net.GetOutput("OPENCLOutput"), 1e-2, 1e-2); - } -} -} // namespace - -TEST_F(ReduceMeanOpTest, GPURandomFloat) { - RandomTest({4, 64, 64, 3}, {1, 2}); - RandomTest({2, 64, 64, 4}, {1, 2}); - RandomTest({8, 128, 128, 64}, {1, 2}); - RandomTest({1, 640, 480, 64}, {1, 2}); - RandomTest({1, 480, 640, 32}, {1, 2}); - RandomTest({1, 512, 512, 16}, {1, 2}); - RandomTest({8, 117, 87, 33}, {1, 2}); - RandomTest({1, 619, 450, 61}, {1, 2}); - RandomTest({1, 511, 561, 11}, {1, 2}); -} - -TEST_F(ReduceMeanOpTest, GPURandomHalf) { - RandomTest({4, 64, 64, 3}, {1, 2}); - RandomTest({2, 64, 64, 4}, {1, 2}); - RandomTest({8, 128, 128, 64}, {1, 2}); - RandomTest({1, 640, 480, 64}, {1, 2}); - RandomTest({1, 480, 640, 32}, {1, 2}); - RandomTest({1, 512, 512, 16}, {1, 2}); - RandomTest({8, 117, 87, 33}, {1, 2}); - RandomTest({1, 619, 450, 61}, {1, 2}); - RandomTest({1, 511, 561, 11}, {1, 2}); -} - -} // namespace test -} // namespace ops -} // namespace mace diff --git a/mace/ops/reduce_test.cc b/mace/ops/reduce_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..e9e804e953270f8970fd8987a8a50fe58ea2831a --- /dev/null +++ b/mace/ops/reduce_test.cc @@ -0,0 +1,644 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "mace/ops/reduce.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class ReduceOpTest : public OpsTestBase {}; + +namespace { +template +void Simple(const std::vector &input_shape, + const std::vector &input, + const std::vector &axis, + const std::vector &output_shape, + const std::vector &output, + ReduceType type, + const bool keepdims = true) { + // Construct graph + OpsTestNet net; + // Add input data + net.AddInputFromArray("Input", input_shape, input); + + if (D == DeviceType::CPU) { + net.TransformDataFormat("Input", NHWC, "InputNCHW", NCHW); + OpDefBuilder("Reduce", "ReduceTest") + .Input("InputNCHW") + .AddIntsArg("axis", axis) + .AddIntArg("keepdims", keepdims ? 1 : 0) + .AddIntArg("reduce_type", type) + .Output("OutputNCHW") + .Finalize(net.NewOperatorDef()); + // Run + net.RunOp(D); + net.TransformDataFormat("OutputNCHW", NCHW, "Output", NHWC); + } else { + OpDefBuilder("Reduce", "ReduceTest") + .Input("Input") + .AddIntsArg("axis", axis) + .AddIntArg("keepdims", keepdims ? 
1 : 0) + .AddIntArg("reduce_type", type) + .Output("Output") + .Finalize(net.NewOperatorDef()); + // Run + net.RunOp(D); + } + auto expected = net.CreateTensor(output_shape, output); + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5, 1e-3); +} + +template +void Simple3D(const std::vector &input_shape, + const std::vector &input, + const std::vector &axis, + const std::vector &output_shape, + const std::vector &output, + ReduceType type, + const bool keepdims = true) { + // Construct graph + OpsTestNet net; + // Add input data + net.AddInputFromArray("Input", input_shape, input); + + OpDefBuilder("Reduce", "ReduceTest") + .Input("Input") + .AddIntsArg("axis", axis) + .AddIntArg("keepdims", keepdims ? 1 : 0) + .AddIntArg("reduce_type", type) + .Output("Output") + .Finalize(net.NewOperatorDef()); + // Run + net.RunOp(D); + auto expected = net.CreateTensor(output_shape, output); + ExpectTensorNear(*expected, *net.GetOutput("Output"), 1e-5, 1e-3); +} + +template +void SimpleMean12Test() { + Simple({2, 2, 3, 4}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + {1, 2}, + {2, 1, 1, 4}, + {10, 11, 12, 13, + 10, 11, 12, 13}, ReduceType::MEAN); +} + +// template +// void SimpleSum12Test() { +// Simple({2, 2, 3, 4}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, +// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, +// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, +// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, +// {1, 2}, +// {2, 1, 1, 4}, +// {60, 66, 72, 78, +// 60, 66, 72, 78}, ReduceType::SUM); +//} + +template +void SimpleMin12Test() { + Simple({2, 2, 3, 4}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + {1, 2}, + {2, 1, 1, 4}, + {0, 1, 2, 3, + 0, 1, 2, 3}, ReduceType::MIN); +} + +template +void SimpleMax12Test() { + Simple({2, 2, 3, 4}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + {1, 2}, + {2, 1, 1, 4}, + {20, 21, 22, 23, + 20, 21, 22, 23}, ReduceType::MAX); +} + +// template +// void SimpleSumSqr12Test() { +// Simple({2, 2, 3, 4}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, +// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, +// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, +// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, +// {1, 2}, +// {2, 1, 1, 4}, +// {880, 1006, 1144, 1294, +// 880, 1006, 1144, 1294}, ReduceType::SUM_SQR); +//} + + +template +void SimpleMean1Axis() { + Simple({2, 2, 3, 4}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + {1}, + {2, 1, 3, 4}, + {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, + 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN); +// Simple({1, 2, 3, 4}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, +// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, +// {-3}, +// {1, 1, 3, 4}, +// {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN); +// Simple({1, 2, 3, 4}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, +// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, +// {2}, +// {1, 2, 1, 4}, +// {4, 5, 6, 7, 16, 17, 18, 19}, ReduceType::MEAN); +// Simple({1, 2, 3, 4}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, +// 12, 
13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, +// {-1}, +// {1, 2, 3, 1}, +// {1.5, 5.5, 9.5, 13.5, 17.5, 21.5}, ReduceType::MEAN); +// Simple({1, 3, 3, 3}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8, +// 9, 10, 11, 12, 13, 14, 15, 16, 17, +// 18, 19, 20, 21, 22, 23, 24, 25, 26}, +// {1}, +// {1, 1, 3, 3}, +// {9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN); +// Simple({1, 3, 3, 3}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8, +// 9, 10, 11, 12, 13, 14, 15, 16, 17, +// 18, 19, 20, 21, 22, 23, 24, 25, 26}, +// {-2}, +// {1, 3, 1, 3}, +// {3, 4, 5, 12, 13, 14, 21, 22, 23}, ReduceType::MEAN); +// Simple({1, 3, 3, 3}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8, +// 9, 10, 11, 12, 13, 14, 15, 16, 17, +// 18, 19, 20, 21, 22, 23, 24, 25, 26}, +// {3}, +// {1, 3, 3, 1}, +// {1, 4, 7, 10, 13, 16, 19, 22, 25}, ReduceType::MEAN); +} + +// template +// void SimpleSum1Axis() { +// Simple({2, 2, 3, 4}, +// {0, 1, 2, 3, +// 4, 5, 6, 7, +// 8, 9, 10, 11, +// 12, 13, 14, 15, +// 16, 17, 18, 19, +// 20, 21, 22, 23, +// 0, 1, 2, 3, +// 4, 5, 6, 7, +// 8, 9, 10, 11, +// 12, 13, 14, 15, +// 16, 17, 18, 19, +// 20, 21, 22, 23}, +// {1}, +// {2, 1, 3, 4}, +// {12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, +// 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34}); +// Simple({1, 2, 3, 4}, +// {0, 1, 2, 3, +// 4, 5, 6, 7, +// 8, 9, 10, 11, +// 12, 13, 14, 15, +// 16, 17, 18, 19, +// 20, 21, 22, 23}, +// {2}, +// {1, 2, 1, 4}, +// {12, 15, 18, 21, 48, 51, 54, 57}, ReduceType::SUM); +// Simple({1, 2, 3, 4}, +// {0, 1, 2, 3, +// 4, 5, 6, 7, +// 8, 9, 10, 11, +// 12, 13, 14, 15, +// 16, 17, 18, 19, +// 20, 21, 22, 23}, +// {-1}, +// {1, 2, 3, 1}, +// {6, 22, 38, 54, 70, 86}, ReduceType::SUM); +// Simple({1, 3, 3, 3}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8, +// 9, 10, 11, 12, 13, 14, 15, 16, 17, +// 18, 19, 20, 21, 22, 23, 24, 25, 26}, +// {1}, +// {1, 1, 3, 3}, +// {27, 30, 33, 36, 39, 42, 45, 48, 51}, ReduceType::SUM); +// Simple({1, 3, 3, 3}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8, +// 9, 10, 11, 12, 13, 14, 15, 16, 17, +// 18, 19, 20, 21, 22, 23, 24, 25, 26}, +// {3}, +// {1, 3, 3, 1}, +// {3, 12, 21, 30, 39, 48, 57, 66, 75}, ReduceType::SUM); +//} + +template +void SimpleMin1Axis() { + Simple({2, 2, 3, 4}, + {0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + 12, 13, 14, 15, + 16, 17, 18, 19, + 20, 21, 22, 23, + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + 12, 13, 14, 15, + 16, 17, 18, 19, + 20, 21, 22, 23}, + {1}, + {2, 1, 3, 4}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, ReduceType::MIN); +// Simple({1, 2, 3, 4}, +// {0, 1, 2, 3, +// 4, 5, 6, 7, +// 8, 9, 10, 11, +// 12, 13, 14, 15, +// 16, 17, 18, 19, +// 20, 21, 22, 23}, +// {2}, +// {1, 2, 1, 4}, +// {0, 1, 2, 3, 12, 13, 14, 15}, ReduceType::MIN); +// Simple({1, 2, 3, 4}, +// {0, 1, 2, 3, +// 4, 5, 6, 7, +// 8, 9, 10, 11, +// 12, 13, 14, 15, +// 16, 17, 18, 19, +// 20, 21, 22, 23}, +// {-1}, +// {1, 2, 3, 1}, +// {0, 4, 8, 12, 16, 20}, ReduceType::MIN); +// Simple({1, 3, 3, 3}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8, +// 9, 10, 11, 12, 13, 14, 15, 16, 17, +// 18, 19, 20, 21, 22, 23, 24, 25, 26}, +// {1}, +// {1, 1, 3, 3}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8}, ReduceType::MIN); +} + +template +void SimpleMax1Axis() { + Simple({2, 2, 3, 4}, + {0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + 12, 13, 14, 15, + 16, 17, 18, 19, + 20, 21, 22, 23, + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + 12, 13, 14, 15, + 16, 17, 18, 19, + 20, 21, 22, 23}, + {1}, + {2, 1, 3, 4}, + {12, 13, 14, 15, + 16, 17, 18, 19, + 20, 21, 22, 23, + 12, 13, 14, 15, + 16, 17, 18, 19, + 20, 21, 22, 23}, ReduceType::MAX); +// 
Simple({1, 2, 3, 4}, +// {0, 1, 2, 3, +// 4, 5, 6, 7, +// 8, 9, 10, 11, +// 12, 13, 14, 15, +// 16, 17, 18, 19, +// 20, 21, 22, 23}, +// {2}, +// {1, 2, 1, 4}, +// {8, 9, 10, 11, 20, 21, 22, 23}, ReduceType::MAX); +// Simple({1, 2, 3, 4}, +// {0, 1, 2, 3, +// 4, 5, 6, 7, +// 8, 9, 10, 11, +// 12, 13, 14, 15, +// 16, 17, 18, 19, +// 20, 21, 22, 23}, +// {-1}, +// {1, 2, 3, 1}, +// {3, 7, 11, 15, 19, 23}, ReduceType::MAX); +// Simple({1, 3, 3, 3}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8, +// 9, 10, 11, 12, 13, 14, 15, 16, 17, +// 18, 19, 20, 21, 22, 23, 24, 25, 26}, +// {1}, +// {1, 1, 3, 3}, +// {18, 19, 20, 21, 22, 23, 24, 25, 26}, ReduceType::MAX); +} + +// template +// void SimpleSumSqr1Axis() { +// Simple({2, 2, 3, 4}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, +// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, +// 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, +// 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, +// {1}, +// {2, 1, 3, 4}, +// {144, 170, 200, 234, +// 272, 314, 360, 410, +// 464, 522, 584, 650, +// 144, 170, 200, 234, +// 272, 314, 360, 410, +// 464, 522, 584, 650}, ReduceType::SUM_SQR); +//} + + +template +void Simple2Axis() { + Simple({1, 2, 3, 4}, + {0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + 12, 13, 14, 15, + 16, 17, 18, 19, + 20, 21, 22, 23}, + {0, 1}, + {1, 1, 3, 4}, + {6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN); +// Simple3D({2, 3, 4}, +// {0, 1, 2, 3, +// 4, 5, 6, 7, +// 8, 9, 10, 11, +// 12, 13, 14, 15, +// 16, 17, 18, 19, +// 20, 21, 22, 23}, +// {0, 1}, +// {1, 1, 4}, +// {10, 11, 12, 13}, ReduceType::MEAN); + Simple3D({2, 3, 4}, + {0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + 12, 13, 14, 15, + 16, 17, 18, 19, + 20, 21, 22, 23}, + {1, 2}, + {2, 1, 1}, + {5.5, 17.5}, ReduceType::MEAN); + Simple({1, 2, 3, 4}, + {0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + 12, 13, 14, 15, + 16, 17, 18, 19, + 20, 21, 22, 23}, + {0, 2}, + {1, 2, 1, 4}, + {4, 5, 6, 7, 16, 17, 18, 19}, ReduceType::MEAN); +// Simple({1, 2, 3, 4}, +// {0, 1, 2, 3, +// 4, 5, 6, 7, +// 8, 9, 10, 11, +// 12, 13, 14, 15, +// 16, 17, 18, 19, +// 20, 21, 22, 23}, +// {1, 3}, +// {1, 1, 3, 1}, +// {7.5, 11.5, 15.5}, ReduceType::MEAN); +// Simple({1, 3, 3, 3}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8, +// 9, 10, 11, 12, 13, 14, 15, 16, 17, +// 18, 19, 20, 21, 22, 23, 24, 25, 26}, +// {1, 2}, +// {1, 1, 1, 3}, +// {12, 13, 14}, ReduceType::MEAN); +// Simple({1, 3, 3, 3}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8, +// 9, 10, 11, 12, 13, 14, 15, 16, 17, +// 18, 19, 20, 21, 22, 23, 24, 25, 26}, +// {0, 1}, +// {1, 1, 3, 3}, +// {9, 10, 11, 12, 13, 14, 15, 16, 17}, ReduceType::MEAN); +// Simple({1, 3, 3, 3}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8, +// 9, 10, 11, 12, 13, 14, 15, 16, 17, +// 18, 19, 20, 21, 22, 23, 24, 25, 26}, +// {2, 3}, +// {1, 3, 1, 1}, +// {4, 13, 22}, ReduceType::MEAN); +} + +template +void Simple3Axis() { + Simple({1, 2, 3, 4}, + {0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + 12, 13, 14, 15, + 16, 17, 18, 19, + 20, 21, 22, 23}, + {1, 2, 3}, + {1, 1, 1, 1}, + {11.5}, ReduceType::MEAN); +// Simple({1, 2, 3, 4}, +// {0, 1, 2, 3, +// 4, 5, 6, 7, +// 8, 9, 10, 11, +// 12, 13, 14, 15, +// 16, 17, 18, 19, +// 20, 21, 22, 23}, +// {0, 2, 3}, +// {1, 2, 1, 1}, +// {5.5, 17.5}, ReduceType::MEAN); +// Simple({1, 2, 3, 4}, +// {0, 1, 2, 3, +// 4, 5, 6, 7, +// 8, 9, 10, 11, +// 12, 13, 14, 15, +// 16, 17, 18, 19, +// 20, 21, 22, 23}, +// {0, 1, 3}, +// {1, 1, 3, 1}, +// {7.5, 11.5, 15.5}, ReduceType::MEAN); +// Simple({1, 2, 3, 4}, +// {0, 1, 2, 3, +// 4, 5, 6, 7, +// 8, 9, 10, 11, +// 12, 13, 14, 15, +// 16, 17, 18, 19, +// 20, 21, 
22, 23}, +// {0, 1, 2}, +// {1, 1, 1, 4}, +// {10, 11, 12, 13}, ReduceType::MEAN); +// Simple({1, 3, 3, 3}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8, +// 9, 10, 11, 12, 13, 14, 15, 16, 17, +// 18, 19, 20, 21, 22, 23, 24, 25, 26}, +// {1, 2, 3}, +// {1, 1, 1, 1}, +// {13}, ReduceType::MEAN); +// Simple({1, 3, 3, 3}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8, +// 9, 10, 11, 12, 13, 14, 15, 16, 17, +// 18, 19, 20, 21, 22, 23, 24, 25, 26}, +// {0, 2, 3}, +// {1, 3, 1, 1}, +// {4, 13, 22}, ReduceType::MEAN); +// Simple({1, 3, 3, 3}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8, +// 9, 10, 11, 12, 13, 14, 15, 16, 17, +// 18, 19, 20, 21, 22, 23, 24, 25, 26}, +// {0, 1, 3}, +// {1, 1, 3, 1}, +// {10, 13, 16}, ReduceType::MEAN); +// Simple({1, 3, 3, 3}, +// {0, 1, 2, 3, 4, 5, 6, 7, 8, +// 9, 10, 11, 12, 13, 14, 15, 16, 17, +// 18, 19, 20, 21, 22, 23, 24, 25, 26}, +// {0, 1, 2}, +// {1, 1, 1, 3}, +// {12, 13, 14}, ReduceType::MEAN); +} + +} // namespace + +TEST_F(ReduceOpTest, CPUSimple12) { + SimpleMean12Test(); + SimpleMin12Test(); + SimpleMax12Test(); +} + +TEST_F(ReduceOpTest, GPUSimple12) { + SimpleMean12Test(); + SimpleMin12Test(); + SimpleMax12Test(); +} + +TEST_F(ReduceOpTest, CPUSimple1Axis) { + SimpleMean1Axis(); + SimpleMin1Axis(); + SimpleMax1Axis(); +} + +TEST_F(ReduceOpTest, CPUSimple2Axis) { + Simple2Axis(); +} + +TEST_F(ReduceOpTest, CPUSimple3Axis) { + Simple3Axis(); +} + +TEST_F(ReduceOpTest, CPUSimpleReduceDims) { + Simple3D({2, 3, 4}, + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + {0, 1}, + {4}, + {10, 11, 12, 13}, ReduceType::MEAN, + false); +} + +namespace { +template +void RandomTest(const std::vector &input_shape, + const std::vector &axis) { + testing::internal::LogToStderr(); + srand(time(NULL)); + auto func = [&](ReduceType type) { + // Construct graph + OpsTestNet net; + // Add input data + net.AddRandomInput("Input", input_shape); + + net.TransformDataFormat("Input", NHWC, "InputNCHW", + NCHW); + OpDefBuilder("Reduce", "ReduceTest") + .Input("InputNCHW") + .AddIntsArg("axis", axis) + .AddIntArg("keepdims", 1) + .AddIntArg("reduce_type", type) + .Output("OutputNCHW") + .Finalize(net.NewOperatorDef()); + // Run + net.RunOp(); + net.TransformDataFormat("OutputNCHW", NCHW, + "Output", NHWC); + OpDefBuilder("Reduce", "ReduceTest") + .Input("Input") + .AddIntsArg("axis", axis) + .AddIntArg("keepdims", 1) + .AddIntArg("reduce_type", type) + .Output("OPENCLOutput") + .Finalize(net.NewOperatorDef()); + // Run + net.RunOp(D); + if (DataTypeToEnum::value == DT_FLOAT) { + ExpectTensorNear(*net.GetTensor("Output"), + *net.GetOutput("OPENCLOutput"), 1e-5, 1e-4); + } else { + ExpectTensorNear(*net.GetTensor("Output"), + *net.GetOutput("OPENCLOutput"), 1e-2, 1e-2); + } + }; + + for (ReduceType type : {MEAN, MIN, MAX, PROD}) { + func(type); + } +} +} // namespace + +TEST_F(ReduceOpTest, GPURandomFloat) { + RandomTest({4, 64, 64, 3}, {1, 2}); +// RandomTest({2, 64, 64, 4}, {1, 2}); + RandomTest({8, 128, 128, 64}, {1, 2}); +// RandomTest({1, 640, 480, 64}, {1, 2}); + RandomTest({1, 480, 640, 32}, {1, 2}); +// RandomTest({1, 512, 512, 16}, {1, 2}); + RandomTest({8, 117, 87, 33}, {1, 2}); +// RandomTest({1, 619, 450, 61}, {1, 2}); + RandomTest({1, 511, 561, 11}, {1, 2}); +} + +TEST_F(ReduceOpTest, GPURandomHalf) { + RandomTest({4, 64, 64, 3}, {1, 2}); +// RandomTest({2, 64, 64, 4}, {1, 2}); + RandomTest({8, 128, 128, 64}, {1, 2}); +// RandomTest({1, 640, 480, 64}, {1, 2}); + RandomTest({1, 480, 640, 32}, {1, 2}); +// RandomTest({1, 512, 512, 16}, {1, 2}); + RandomTest({8, 117, 
87, 33}, {1, 2}); +// RandomTest({1, 619, 450, 61}, {1, 2}); + RandomTest({1, 511, 561, 11}, {1, 2}); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/reshape.cc b/mace/ops/reshape.cc index 400d1cff051fe09dfee156330cc21d70e7121144..30c7ce890290139e22d319807ce19eb30afc928f 100644 --- a/mace/ops/reshape.cc +++ b/mace/ops/reshape.cc @@ -36,6 +36,7 @@ class ReshapeOp : public Operation { int unknown_idx = -1; index_t product = 1; std::vector out_shape; + index_t n = 0; for (int i = 0; i < num_dims; ++i) { if (shape_data[i] == -1) { @@ -45,8 +46,15 @@ class ReshapeOp : public Operation { } else { MACE_CHECK(shape_data[i] >= 0, "Shape must be non-negative: ", shape_data[i]); - out_shape.push_back(shape_data[i]); - product *= shape_data[i]; + if (shape_data[i] == 0) { + MACE_CHECK(i < input->dim_size(), + "dims:0 out of input dims' range."); + n = input->dim(i); + } else { + n = shape_data[i]; + } + out_shape.push_back(n); + product *= n; } } diff --git a/mace/python/tools/BUILD b/mace/python/tools/BUILD index 41f039476ee7f6b50a15ac8cac1dc30dc7738121..f89bacd5ec0afbeeeda1ae8b3590c50196354f86 100644 --- a/mace/python/tools/BUILD +++ b/mace/python/tools/BUILD @@ -13,6 +13,7 @@ py_library( "converter_tool/base_converter.py", "converter_tool/caffe_converter.py", "converter_tool/hexagon_converter.py", + "converter_tool/onnx_converter.py", "converter_tool/shape_inference.py", "converter_tool/tensorflow_converter.py", "converter_tool/tf_dsp_converter.py", diff --git a/mace/python/tools/converter.py b/mace/python/tools/converter.py index f53876d63d0809e05151e8e11f6dc23f0f54ca54..fe337e997f09ce9baff0f2bb7c357f7487b0831d 100644 --- a/mace/python/tools/converter.py +++ b/mace/python/tools/converter.py @@ -101,7 +101,7 @@ def main(unused_args): file=sys.stderr) sys.exit(-1) - if FLAGS.platform not in ['tensorflow', 'caffe']: + if FLAGS.platform not in ['tensorflow', 'caffe', 'onnx']: six.print_("platform %s is not supported." % FLAGS.platform, file=sys.stderr) sys.exit(-1) @@ -188,6 +188,9 @@ def main(unused_args): converter = caffe_converter.CaffeConverter(option, FLAGS.model_file, FLAGS.weight_file) + elif FLAGS.platform == 'onnx': + from mace.python.tools.converter_tool import onnx_converter + converter = onnx_converter.OnnxConverter(option, FLAGS.model_file) else: six.print_("Mace do not support platorm %s yet." 
% FLAGS.platform, file=sys.stderr) @@ -231,6 +234,7 @@ def parse_args(): type=str, default="", help="TensorFlow \'GraphDef\' file to load, " + "Onnx model file .onnx to load, " "Caffe prototxt file to load.") parser.add_argument( "--weight_file", type=str, default="", help="Caffe data file to load.") @@ -300,7 +304,10 @@ def parse_args(): parser.add_argument( "--check_shape", type=str, default="", help="check shape.") parser.add_argument( - "--platform", type=str, default="tensorflow", help="tensorflow/caffe") + "--platform", + type=str, + default="tensorflow", + help="tensorflow/caffe/onnx") parser.add_argument( "--embed_model_data", type=str2bool, diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index 3f8d7164b64ec4d64253f8562b54e9e7b31f377d..fa748ed474ac11ae2ed2476040ebad513d58e167 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -37,11 +37,14 @@ class FilterFormat(Enum): OHWI = 103 +# SAME_LOWER: if the amount of paddings to be added is odd, +# it will add the extra data to the right or bottom class PaddingMode(Enum): VALID = 0 SAME = 1 FULL = 2 - NA = 3 + SAME_LOWER = 3 + NA = 4 class PoolingType(Enum): @@ -49,6 +52,11 @@ class PoolingType(Enum): MAX = 2 +class RoundMode(Enum): + FLOOR = 0 + CEIL = 1 + + class ActivationType(Enum): NOOP = 0 RELU = 1 @@ -56,6 +64,7 @@ class ActivationType(Enum): PRELU = 3 TANH = 4 SIGMOID = 5 + LEAKYRELU = 6 class EltwiseType(Enum): @@ -72,9 +81,17 @@ class EltwiseType(Enum): EQUAL = 10 +class ReduceType(Enum): + MEAN = 0 + MIN = 1 + MAX = 2 + PROD = 3 + + class FrameworkType(Enum): TENSORFLOW = 0 CAFFE = 1 + ONNX = 2 MaceSupportedOps = [ @@ -108,7 +125,7 @@ MaceSupportedOps = [ 'Pooling', 'Proposal', 'Quantize', - 'ReduceMean', + 'Reduce', 'Reshape', 'ResizeBicubic', 'ResizeBilinear', @@ -184,6 +201,10 @@ class MaceKeyword(object): mace_group_str = "group" mace_wino_arg_str = "wino_block_size" mace_quantize_flag_arg_str = "quantize_flag" + mace_epsilon_str = 'epsilon' + mace_reduce_type_str = 'reduce_type' + mace_argmin_str = 'argmin' + mace_round_mode_str = 'round_mode' class TransformerRule(Enum): diff --git a/mace/python/tools/converter_tool/onnx_converter.py b/mace/python/tools/converter_tool/onnx_converter.py new file mode 100644 index 0000000000000000000000000000000000000000..a801285500f5f69d9b5044e3367c7495bdd469bb --- /dev/null +++ b/mace/python/tools/converter_tool/onnx_converter.py @@ -0,0 +1,947 @@ +# Copyright 2018 Xiaomi, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
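+# This converter loads an ONNX model, polishes it and runs ONNX shape inference, +# then maps each ONNX node to a MACE op through OnnxConverter._op_converters to +# build a mace_pb2.NetDef in NCHW layout.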
+ + +import sys +from enum import Enum +import six + +from mace.proto import mace_pb2 +from mace.python.tools.converter_tool import base_converter +from mace.python.tools.converter_tool.base_converter import PoolingType +from mace.python.tools.converter_tool.base_converter import PaddingMode +from mace.python.tools.converter_tool.base_converter import ActivationType +from mace.python.tools.converter_tool.base_converter import EltwiseType +from mace.python.tools.converter_tool.base_converter import ReduceType +from mace.python.tools.converter_tool.base_converter import FrameworkType +from mace.python.tools.converter_tool.base_converter import RoundMode +from mace.python.tools.converter_tool.base_converter import DataFormat +from mace.python.tools.converter_tool.base_converter import FilterFormat +from mace.python.tools.converter_tool.base_converter import MaceOp +from mace.python.tools.converter_tool.base_converter import MaceKeyword +from mace.python.tools.converter_tool.base_converter import ConverterUtil +from mace.python.tools.convert_util import mace_check + +import onnx +import onnx.utils +from onnx import helper, shape_inference, numpy_helper, optimizer +import numpy as np +from onnx import mapping +from onnx import TensorProto +from numbers import Number + + +OnnxSupportedOps = [ + 'Abs', + # 'Acos', + # 'Acosh', + 'Add', + # 'And', + 'ArgMax', + 'ArgMin', + # 'Asin', + # 'Asinh', + # 'Atan', + # 'Atanh', + 'AveragePool', + 'BatchNormalization', + 'Cast', + # 'Ceil', + # 'Clip', + # 'Compress', + 'Concat', + # 'Constant', + # 'ConstantLike', + 'Conv', + 'ConvTranspose', + # 'Cos', + # 'Cosh', + 'DepthToSpace', + 'Div', + 'Dropout', + 'Elu', + 'Equal', + # 'Exp', + # 'Expand', + # 'EyeLike', + # 'Flatten', + # 'Floor', + # 'GRU', + 'Gather', + 'Gemm', + 'GlobalAveragePool', + # 'GlobalLpPool', + 'GlobalMaxPool', + # 'Greater', + # 'HardSigmoid', + # 'Hardmax', + 'Identity', + # 'If', + 'ImageScaler', + # 'InstanceNormalization', + # 'LRN', + # 'LSTM', + 'LeakyRelu', + # 'Less', + # 'Log', + # 'LogSoftmax', + # 'Loop', + # 'LpNormalization', + # 'LpPool', + 'MatMul', + 'Max', + 'MaxPool', + # 'MaxRoiPool', + # 'MaxUnpool', + 'Mean', + 'Min', + 'Mul', + # 'Multinomial', + 'Neg', + # 'Not', + # 'OneHot', + # 'Or', + 'PRelu', + 'Pad', + 'Pow', + # 'RNN', + # 'RandomNormal', + # 'RandonNormalLike', + # 'RandonUniform', + # 'RandonUniformLike', + 'Reciprocal', + # 'ReduceL1', + # 'ReduceL2', + # 'ReduceLogSum', + # 'ReduceLogSumExp', + 'ReduceMax', + 'ReduceMean', + 'ReduceMin', + 'ReduceProd', + # 'ReduceSum', + # 'ReduceSumSquare', + 'Relu', + 'Reshape', + # 'Scan', + # 'Selu', + 'Shape', + 'Sigmoid', + # 'Sin', + # 'Sinh', + # 'Size', + # 'Slice', + 'Softmax', + # 'Softplus', + # 'Softsign', + 'SpaceToDepth', + 'Split', + 'Sqrt', + 'Squeeze', + 'Sub', + 'Sum', + # 'Tan', + 'Tanh', + # 'Tile', + # 'TopK', + 'Transpose', + # 'Unsqueeze', + # 'Upsample', + # 'Xor', +] + +OnnxOpType = Enum('OnnxOpType', + [(op, op) for op in OnnxSupportedOps], + type=str) + +onnx_attr_translator = { + "axis": lambda x: int(x), + "axes": lambda x: [int(a) for a in x], + "dtype": lambda x: data_type.onnx2tf(x), + "keepdims": lambda x: bool(x), + "to": lambda x: data_type.onnx2tf(x), +} + + +def translate_onnx(key, val): + return onnx_attr_translator.get(key, lambda x: x)(val) + + +def convert_onnx(attr): + return convert_onnx_attribute_proto(attr) + + +def convert_onnx_attribute_proto(attr_proto): + if attr_proto.HasField('f'): + return attr_proto.f + elif attr_proto.HasField('i'): + return attr_proto.i + 
elif attr_proto.HasField('s'): + return str(attr_proto.s, 'utf-8')\ + if sys.version_info.major == 3 else attr_proto.s + elif attr_proto.HasField('t'): + return attr_proto.t # this is a proto! + elif attr_proto.floats: + return list(attr_proto.floats) + elif attr_proto.ints: + return list(attr_proto.ints) + elif attr_proto.strings: + str_list = list(attr_proto.strings) + if IS_PYTHON3: + str_list = map(lambda x: str(x, 'utf-8'), str_list) + return str_list + else: + raise ValueError("Unsupported ONNX attribute: {}".format(attr_proto)) + + +def onnx_dtype(dtype): + if isinstance(dtype, Number): + onnx_dtype = dtype + elif isinstance(dtype, str): + onnx_dtype = TensorProto.DataType.Value(dtype) + else: + raise RuntimeError("dtype should be number or str.") + return mapping.TENSOR_TYPE_TO_NP_TYPE[onnx_dtype] + + +class OnnxNode(object): + def __init__(self, node): + self.name = str(node.name) + self.op_type = str(node.op_type) + self.domain = str(node.domain) + self.attrs = dict([(attr.name, + translate_onnx(attr.name, convert_onnx(attr))) + for attr in node.attribute]) + self.inputs = list(node.input) + self.outputs = list(node.output) + self.node_proto = node + + def print_info(self): + print "node: ", self.name + print " type: ", self.op_type + print " domain: ", self.domain + print " inputs: ", self.inputs + print " outputs: ", self.outputs + print " attrs:" + for arg in self.attrs: + print " %s: %s" % (arg, self.attrs[arg]) + + +class OnnxTensor(object): + def __init__(self, name, value, shape, dtype): + self._name = name + self._tensor_data = value + self._shape = shape + self._dtype = dtype + + +class OnnxConverter(base_converter.ConverterInterface): + pooling_type_mode = { + OnnxOpType.AveragePool.name: PoolingType.AVG, + OnnxOpType.MaxPool.name: PoolingType.MAX + } + + auto_pad_mode = { + 'NOTSET': PaddingMode.NA, + 'SAME_UPPER': PaddingMode.SAME, + 'SAME_LOWER': PaddingMode.SAME, + 'VALID': PaddingMode.VALID, + } + auto_pad_mode = {six.b(k): v for k, v in six.iteritems(auto_pad_mode)} + + eltwise_type = { + OnnxOpType.Mul.name: EltwiseType.PROD, + OnnxOpType.Add.name: EltwiseType.SUM, + OnnxOpType.Max.name: EltwiseType.MAX, + OnnxOpType.Min.name: EltwiseType.MIN, + OnnxOpType.Abs.name: EltwiseType.ABS, + OnnxOpType.Pow.name: EltwiseType.POW, + OnnxOpType.Sub.name: EltwiseType.SUB, + OnnxOpType.Div.name: EltwiseType.DIV, + OnnxOpType.Neg.name: EltwiseType.NEG, + OnnxOpType.Sum.name: EltwiseType.SUM, + OnnxOpType.Equal.name: EltwiseType.EQUAL, + OnnxOpType.Sqrt.name: EltwiseType.POW, + OnnxOpType.Reciprocal.name: EltwiseType.POW, + } + + reduce_type = { + OnnxOpType.GlobalAveragePool.name: ReduceType.MEAN, + OnnxOpType.GlobalMaxPool.name: ReduceType.MAX, + OnnxOpType.ReduceMax.name: ReduceType.MAX, + OnnxOpType.ReduceMean.name: ReduceType.MEAN, + OnnxOpType.ReduceMin.name: ReduceType.MIN, + OnnxOpType.ReduceProd.name: ReduceType.PROD, + } + + activation_type = { + OnnxOpType.Relu.name: ActivationType.RELU, + OnnxOpType.PRelu.name: ActivationType.PRELU, + OnnxOpType.Tanh.name: ActivationType.TANH, + OnnxOpType.Sigmoid.name: ActivationType.SIGMOID, + OnnxOpType.LeakyRelu.name: ActivationType.LEAKYRELU, + } + + def __init__(self, option, src_model_file): + self._op_converters = { + OnnxOpType.Abs.name: self.convert_eltwise, + OnnxOpType.Add.name: self.convert_eltwise, + OnnxOpType.ArgMax.name: self.convert_argmax, + OnnxOpType.ArgMin.name: self.convert_argmax, + OnnxOpType.AveragePool.name: self.convert_pooling, + OnnxOpType.BatchNormalization.name: self.convert_fused_batchnorm, + 
OnnxOpType.Cast.name: self.convert_cast, + OnnxOpType.Concat.name: self.convert_concat, + OnnxOpType.Conv.name: self.convert_conv2d, + OnnxOpType.ConvTranspose.name: self.convert_deconv, + OnnxOpType.DepthToSpace.name: self.convert_depth_space, + OnnxOpType.Dropout.name: self.convert_identity, + OnnxOpType.Div.name: self.convert_eltwise, + OnnxOpType.Equal.name: self.convert_eltwise, + OnnxOpType.Gather.name: self.convert_gather, + OnnxOpType.Gemm.name: self.convert_fully_connected, + OnnxOpType.GlobalAveragePool.name: self.convert_reduce, + OnnxOpType.GlobalMaxPool.name: self.convert_reduce, + OnnxOpType.Identity.name: self.convert_identity, + OnnxOpType.ImageScaler.name: self.convert_imagescaler, + OnnxOpType.LeakyRelu.name: self.convert_activation, + OnnxOpType.Max.name: self.convert_eltwise, + OnnxOpType.MaxPool.name: self.convert_pooling, + OnnxOpType.MatMul.name: self.convert_matmul, + OnnxOpType.Min.name: self.convert_eltwise, + OnnxOpType.Mul.name: self.convert_eltwise, + OnnxOpType.Neg.name: self.convert_eltwise, + OnnxOpType.Pad.name: self.convert_pad, + OnnxOpType.Pow.name: self.convert_eltwise, + OnnxOpType.PRelu.name: self.convert_activation, + OnnxOpType.Relu.name: self.convert_activation, + OnnxOpType.Reshape.name: self.convert_reshape, + OnnxOpType.Reciprocal.name: self.convert_eltwise, + OnnxOpType.Sigmoid.name: self.convert_activation, + OnnxOpType.Softmax.name: self.convert_softmax, + OnnxOpType.SpaceToDepth.name: self.convert_depth_space, + OnnxOpType.Split.name: self.convert_split, + OnnxOpType.Sqrt.name: self.convert_eltwise, + OnnxOpType.Squeeze.name: self.convert_squeeze, + OnnxOpType.Sub.name: self.convert_eltwise, + OnnxOpType.Sum.name: self.convert_eltwise, + OnnxOpType.Tanh.name: self.convert_activation, + OnnxOpType.Transpose.name: self.convert_transpose, + } + self._option = option + self._mace_net_def = mace_pb2.NetDef() + ConverterUtil.set_filter_format(self._mace_net_def, FilterFormat.OIHW) + onnx_model = onnx.load(src_model_file) + + polished_model = onnx.utils.polish_model(onnx_model) + + print "onnx model IR version: ", onnx_model.ir_version + print "onnx model opset import: ", onnx_model.opset_import + + self._onnx_model = shape_inference.infer_shapes(polished_model) + self._graph_shapes_dict = {} + self._consts = {} + self._replace_tensors = {} + + def print_graph_info(self, graph): + for value_info in graph.value_info: + print "value info:", value_info + for value_info in graph.input: + print "inputs info:", value_info + for value_info in graph.output: + print "outputs info:", value_info + + def extract_shape_info(self, graph): + def extract_value_info(shape_dict, value_info): + t = tuple([int(dim.dim_value) + for dim in value_info.type.tensor_type.shape.dim]) + if t: + shape_dict[value_info.name] = t + + for value_info in graph.value_info: + extract_value_info(self._graph_shapes_dict, value_info) + for value_info in graph.input: + extract_value_info(self._graph_shapes_dict, value_info) + for value_info in graph.output: + extract_value_info(self._graph_shapes_dict, value_info) + + def add_tensor(self, name, shape, data_type, value): + tensor = self._mace_net_def.tensors.add() + tensor.name = name + tensor.dims.extend(list(shape)) + tensor.data_type = data_type + tensor.float_data.extend(value.flat) + + def run(self): + graph_def = self._onnx_model.graph + self.extract_shape_info(graph_def) + self.convert_tensors(graph_def) + self.convert_ops(graph_def) + # self.print_graph_info(graph_def) + # shape_inferer = mace_shape_inference.ShapeInference( + # 
self._mace_net_def, + # self._option.input_nodes.values()) + # shape_inferer.run() + return self._mace_net_def + + def add_stride_pad_kernel_arg(self, attrs, op_def): + if 'strides' in attrs: + strides = attrs['strides'] + mace_check(len(strides) == 2, "strides should has 2 values.") + stride = [strides[0], strides[1]] + else: + stride = [1, 1] + + strides_arg = op_def.arg.add() + strides_arg.name = MaceKeyword.mace_strides_str + strides_arg.ints.extend(stride) + + if 'kernel_shape' in attrs: + kernel_shape = attrs['kernel_shape'] + mace_check(len(kernel_shape) == 2, + "kernel shape should has 2 values.") + kernel = [kernel_shape[0], kernel_shape[1]] + kernels_arg = op_def.arg.add() + kernels_arg.name = MaceKeyword.mace_kernel_str + kernels_arg.ints.extend(kernel) + + if 'pads' in attrs: + pads = attrs['pads'] + if len(pads) == 4: + pad = [pads[0] + pads[2], pads[1] + pads[3]] + else: + pad = [0, 0] + padding_arg = op_def.arg.add() + padding_arg.name = MaceKeyword.mace_padding_values_str + padding_arg.ints.extend(pad) + elif 'auto_pad' in attrs: + auto_pad_arg = op_def.arg.add() + auto_pad_arg.name = MaceKeyword.mace_padding_str + auto_pad_arg.i = self.auto_pad_mode[attrs['auto_pad']].value + else: + pad = [0, 0] + padding_arg = op_def.arg.add() + padding_arg.name = MaceKeyword.mace_padding_values_str + padding_arg.ints.extend(pad) + + def convert_ops(self, graph_def): + for n in graph_def.node: + node = OnnxNode(n) + mace_check(node.op_type in self._op_converters, + "Mace does not support onnx op type %s yet" + % node.op_type) + self._op_converters[node.op_type](node) + + def convert_tensors(self, graph_def): + initializer = graph_def.initializer + if initializer: + for init in initializer: + tensor = self._mace_net_def.tensors.add() + tensor.name = init.name + + onnx_tensor = numpy_helper.to_array(init) + tensor.dims.extend(list(init.dims)) + data_type = onnx_dtype(init.data_type) + + if data_type == np.float32 or data_type == np.float64: + tensor.data_type = mace_pb2.DT_FLOAT + tensor.float_data.extend( + onnx_tensor.astype(np.float32).flat) + elif data_type == np.int32: + tensor.data_type = mace_pb2.DT_INT32 + tensor.int32_data.extend( + onnx_tensor.astype(np.int32).flat) + elif data_type == np.int64: + tensor.data_type = mace_pb2.DT_INT32 + tensor.int32_data.extend( + onnx_tensor.astype(np.int32).flat) + else: + mace_check(False, + "Not supported tensor type: %s" % data_type) + self._consts[tensor.name] = tensor + + def convert_general_op(self, node): + op = self._mace_net_def.op.add() + op.name = node.name + + for input in node.inputs: + if input in self._replace_tensors: + input = self._replace_tensors[input] + op.input.append(input) + for output in node.outputs: + op.output.append(output) + output_shape = op.output_shape.add() + shape_info = self._graph_shapes_dict[output] + output_shape.dims.extend(shape_info) + + data_type_arg = op.arg.add() + data_type_arg.name = 'T' + data_type_arg.i = self._option.data_type + + framework_type_arg = op.arg.add() + framework_type_arg.name = MaceKeyword.mace_framework_type_str + framework_type_arg.i = FrameworkType.ONNX.value + + ConverterUtil.add_data_format_arg(op, DataFormat.NCHW) + return op + + def convert_fused_batchnorm(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.BatchNorm.name + + if "epsilon" in node.attrs: + epsilon_value = node.attrs["epsilon"] + else: + epsilon_value = 1e-5 + + mace_check(len(node.inputs) == 5, "batch norm should have 5 inputs.") + + gamma_value = 
np.array(self._consts[node.inputs[1]].float_data) + beta_value = np.array(self._consts[node.inputs[2]].float_data) + mean_value = np.array(self._consts[node.inputs[3]].float_data) + var_value = np.array(self._consts[node.inputs[4]].float_data) + + scale_name = node.name + 'scale' + offset_name = node.name + 'offset' + scale_value = ( + (1.0 / np.sqrt( + var_value + epsilon_value)) * gamma_value) + offset_value = (-mean_value * scale_value) + beta_value + self.add_tensor(scale_name, scale_value.shape, mace_pb2.DT_FLOAT, + scale_value) + self.add_tensor(offset_name, offset_value.shape, mace_pb2.DT_FLOAT, + offset_value) + del op.input[1:] + op.input.extend([scale_name, offset_name]) + del op.output[1:] + del op.output_shape[1:] + + def convert_conv2d(self, node): + op = self.convert_general_op(node) + self.add_stride_pad_kernel_arg(node.attrs, op) + group_arg = op.arg.add() + group_arg.name = MaceKeyword.mace_group_str + if 'group' in node.attrs: + group_val = node.attrs["group"] + else: + group_val = 1 + group_arg.i = group_val + + is_depthwise = False + if group_val > 1: + filter_shape = self._graph_shapes_dict[node.inputs[1]] + mace_check(group_val == filter_shape[0] and + filter_shape[1] == 1, + "Mace does not support group convolution yet") + filter_tensor = self._consts[node.inputs[1]] + new_shape = [filter_shape[1], filter_shape[0], + filter_shape[2], filter_shape[3]] + del filter_tensor.dims[:] + filter_tensor.dims.extend(new_shape) + is_depthwise = True + if is_depthwise: + op.type = MaceOp.DepthwiseConv2d.name + else: + op.type = MaceOp.Conv2D.name + + dilation_arg = op.arg.add() + dilation_arg.name = MaceKeyword.mace_dilations_str + if 'dilations' in node.attrs: + dilation_val = node.attrs["dilations"] + else: + dilation_val = [1, 1] + dilation_arg.ints.extend(dilation_val) + + def convert_biasadd(self, node): + self.convert_general_op(node) + + def convert_concat(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Concat.name + mace_check('axis' in node.attrs, + 'Concat op should have axis attribute.') + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_axis_str + axis_arg.i = node.attrs['axis'] + axis_arg.i = 4 + axis_arg.i if axis_arg.i < 0 else axis_arg.i + mace_check(axis_arg.i == 1, + "only support concat at channel dimension") + + def convert_activation(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Activation.name + + type_arg = op.arg.add() + type_arg.name = MaceKeyword.mace_activation_type_str + type_arg.s = six.b(self.activation_type[node.op_type].name) + + if "alpha" in node.attrs: + alpha_value = node.attrs["alpha"] + else: + alpha_value = 0 + alpha_arg = op.arg.add() + alpha_arg.name = MaceKeyword.mace_activation_max_limit_str + alpha_arg.f = alpha_value + + def convert_pooling(self, node): + op = self.convert_general_op(node) + + op.type = MaceOp.Pooling.name + self.add_stride_pad_kernel_arg(node.attrs, op) + pooling_type_arg = op.arg.add() + pooling_type_arg.name = MaceKeyword.mace_pooling_type_str + pooling_type_arg.i = self.pooling_type_mode[node.op_type].value + + round_mode_arg = op.arg.add() + round_mode_arg.name = MaceKeyword.mace_round_mode_str + round_mode_arg.i = RoundMode.FLOOR.value + + def convert_reshape(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Reshape.name + + def convert_flatten(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Reshape.name + + def remove_node(self, node): + input_name = node.inputs[0] + output_name = node.outputs[0] + 
self._replace_tensors[output_name] = input_name + + def convert_eltwise(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Eltwise.name + type_arg = op.arg.add() + type_arg.name = MaceKeyword.mace_element_type_str + type_arg.i = self.eltwise_type[node.op_type].value + + if node.op_type == OnnxOpType.Sqrt.name: + value_arg = op.arg.add() + value_arg.name = MaceKeyword.mace_scalar_input_str + value_arg.f = 0.5 + elif node.op_type == OnnxOpType.Reciprocal.name: + value_arg = op.arg.add() + value_arg.name = MaceKeyword.mace_scalar_input_str + value_arg.f = -1 + + def convert_reduce(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Reduce.name + + reduce_type_arg = op.arg.add() + reduce_type_arg.name = MaceKeyword.mace_reduce_type_str + reduce_type_arg.i = self.reduce_type[node.op_type].value + + if node.op_type in [OnnxOpType.GlobalAveragePool.name, + OnnxOpType.GlobalMaxPool.name]: + reduce_dims = [2, 3] + keep_dims = 1 + else: + if 'axes' in node.attrs: + reduce_dims = node.attrs['axes'] + else: + reduce_dims = [] + if 'keepdims' in node.attrs: + keep_dims = node.attrs['keepdims'] + else: + keep_dims = 1 + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_axis_str + axis_arg.ints.extend(reduce_dims) + + keep_dims_arg = op.arg.add() + keep_dims_arg.name = MaceKeyword.mace_keepdims_str + keep_dims_arg.i = keep_dims + + def convert_imagescaler(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.BatchNorm.name + + scale = node.attrs['scale'] + bias_value = np.array(node.attrs['bias']) + scale_value = scale * np.ones_like(bias_value) + + scale_name = node.name + "_scale" + bias_name = node.name + "_bias" + self.add_tensor(scale_name, scale_value.shape, mace_pb2.DT_FLOAT, + scale_value) + self.add_tensor(bias_name, bias_value.shape, mace_pb2.DT_FLOAT, + bias_value) + op.input.extend([scale_name, bias_name]) + + def convert_matmul(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.MatMul.name + + def convert_softmax(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Softmax.name + + def convert_argmax(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.ArgMax.name + + if 'axis' in node.attrs: + axis_value = node.attrs['axis'] + else: + axis_value = 0 + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_axis_str + axis_arg.i = axis_value + + if 'keepdims' in node.attrs: + keepdims = node.attrs['keepdims'] + else: + keepdims = 1 + keep_dims_arg = op.arg.add() + keep_dims_arg.name = MaceKeyword.mace_keepdims_str + keep_dims_arg.i = keepdims + + if node.op_type == OnnxOpType.ArgMin.name: + min_arg = op.arg.add() + min_arg.name = MaceKeyword.mace_argmin_str + min_arg.i = 1 + + def convert_cast(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Cast.name + + if 'to' in node.attrs: + dtype = node.attrs['to'] + if dtype == TensorProto.FLOAT: + op.output_type.extend([self._option.data_type]) + elif dtype == TensorProto.INT: + op.output_type.extend([mace_pb2.DT_INT32]) + else: + mace_check(False, "data type %s not supported" % dtype) + else: + op.output_type.extend([self._option.data_type]) + + def convert_depth_space(self, node): + op = self.convert_general_op(node) + if op.type == OnnxOpType.DepthToSpace.name: + op.type = MaceOp.DepthToSpace.name + else: + op.type = MaceOp.SpaceToDepth.name + mace_check(('block_size' in node.attrs), + "depth to space op should have block size attribute.") + block_size = node.attrs['block_size'] + size_arg = op.arg.add() + 
size_arg.name = MaceKeyword.mace_space_depth_block_size_str + size_arg.i = block_size + + def convert_deconv(self, node): + op = self.convert_general_op(node) + + self.add_stride_pad_kernel_arg(node.attrs, op) + + if 'group' in node.attrs: + group_val = node.attrs["group"] + else: + group_val = 1 + if group_val > 1: + op.type = MaceOp.DepthwiseDeconv2d.name + filter_shape = self._graph_shapes_dict[node.inputs[1]] + filter_tensor = self._consts[node.inputs[1]] + new_shape = [filter_shape[1], filter_shape[0], + filter_shape[2], filter_shape[3]] + del filter_tensor.dims[:] + filter_tensor.dims.extend(new_shape) + else: + op.type = MaceOp.Deconv2D.name + group_arg = op.arg.add() + group_arg.name = MaceKeyword.mace_group_str + group_arg.i = group_val + + dilation_arg = op.arg.add() + dilation_arg.name = MaceKeyword.mace_dilations_str + if 'dilations' in node.attrs: + dilation_val = node.attrs["dilations"] + else: + dilation_val = [1, 1] + dilation_arg.ints.extend(dilation_val) + mace_check(dilation_val == [1, 1], + "not support convtranspose with dilation != 1 yet.") + + mace_check('output_padding' not in node.attrs, + "not support convtranspose with output_padding yet.") + mace_check('output_shape' not in node.attrs, + "not support convtranspose with output_shape yet.") + # TODO: if output shape specified, calculate padding value + # if 'output_padding' in node.attrs: + # output_padding = node.attrs['output_padding'] + # output_padding_arg = op.arg.add() + # output_padding_arg.name = MaceKeyword.mace_output_padding_str + # output_padding_arg.ints.extend(output_padding) + # if 'output_shape' in node.attrs: + # output_shape = node.attrs['output_shape'] + # output_shape_arg = op.arg.add() + # output_shape_arg.name = MaceKeyword.mace_output_shape_str + # output_shape_arg.ints.extend(output_shape) + + def convert_nop(self, node): + pass + + def convert_identity(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Identity.name + + def convert_pad(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Pad.name + + if 'pads' in node.attrs: + paddings_arg = op.arg.add() + paddings_arg.name = MaceKeyword.mace_paddings_str + paddings_value = node.attrs['pads'] + paddings_arg.ints.extend(paddings_value) + + if 'value' in node.attrs: + constant_value_arg = op.arg.add() + constant_value_arg.name = MaceKeyword.mace_constant_value_str + constant_value_arg.i = node.attrs['value'] + + def convert_gather(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Gather.name + + if 'axis' in node.attrs: + value = node.attrs['axis'] + else: + value = 0 + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_axis_str + axis_arg.i = value + + def convert_split(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Split.name + + if 'axis' in node.attrs: + value = node.attrs['axis'] + else: + value = 0 + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_axis_str + axis_arg.i = value + + def convert_transpose(self, node): + op = self.convert_general_op(node) + op.type = MaceOp.Transpose.name + + if 'perm' in node.attrs: + perm = node.attrs['perm'] + ordered_perm = np.sort(perm) + if np.array_equal(perm, ordered_perm): + op.type = MaceOp.Identity.name + del op.input[1:] + else: + dims_arg = op.arg.add() + dims_arg.name = MaceKeyword.mace_dims_str + dims_arg.ints.extend(perm) + + @staticmethod + def squeeze_shape(shape, axis): + new_shape = [] + if len(axis) > 0: + for i in
range(len(shape)): + if i not in axis: + new_shape.append(shape[i]) + else: + new_shape = shape + return new_shape + + def convert_squeeze(self, node): + axis_value = node.attrs['axes'] + if node.inputs[0] in self._consts: + tensor = self._consts[node.inputs[0]] + shape = tensor.dims + new_shape = self.squeeze_shape(shape, axis_value) + del tensor.dims[:] + tensor.dims.extend(new_shape) + self.remove_node(node) + else: + op = self.convert_general_op(node) + op.type = MaceOp.Squeeze.name + axis_arg = op.arg.add() + axis_arg.name = MaceKeyword.mace_axis_str + if 'axis' in node.attrs: + axis_value = node.attrs['axis'] + else: + axis_value = [] + axis_arg.ints.extend(axis_value) + + @staticmethod + def transpose_const(tensor): + shape = tensor.dims + mace_check(len(shape) == 2, "gemm only supports 2-dim input.") + tensor_data = np.array(tensor.float_data).reshape( + shape[0], shape[1]) + tensor_data = tensor_data.transpose(1, 0) + tensor.float_data[:] = tensor_data.flat + tensor.dims[:] = tensor_data.shape + + def convert_fully_connected(self, node): + trans_a = node.attrs['transA'] if 'transA' in node.attrs else 0 + trans_b = node.attrs['transB'] if 'transB' in node.attrs else 0 + shape_a = self._graph_shapes_dict[node.inputs[0]] + shape_b = self._graph_shapes_dict[node.inputs[1]] + mace_check(trans_a == 0 and trans_b == 1, + "Do not support non-default transpose") + mace_check(len(shape_a) == 4, + "Unexpected fc input ndim.") + mace_check(node.inputs[1] in self._consts, "unexpect fc weight.") + if len(shape_b) == 4: + mace_check(list(shape_b[2:]) == [1, 1], + "Only support 4D weight with shape [*, *, 1, 1]") + elif len(shape_b) == 2: + tensor_b = self._consts[node.inputs[1]] + tensor_data = np.array(tensor_b.float_data).reshape( + shape_b[0], shape_b[1], 1, 1) + tensor_b.float_data[:] = tensor_data.flat + tensor_b.dims[:] = tensor_data.shape + else: + mace_check(False, "Unexpected fc weigth ndim.") + + op = self._mace_net_def.op.add() + op.name = node.name + op.type = MaceOp.FullyConnected.name + data_type_arg = op.arg.add() + data_type_arg.name = 'T' + data_type_arg.i = self._option.data_type + + framework_type_arg = op.arg.add() + framework_type_arg.name = MaceKeyword.mace_framework_type_str + framework_type_arg.i = FrameworkType.ONNX.value + + ConverterUtil.add_data_format_arg(op, DataFormat.NCHW) + + for input in node.inputs: + op.input.append(input) + for output in node.outputs: + op.output.append(output) + output_shape = op.output_shape.add() + shape_info = self._graph_shapes_dict[output] + mace_check(len(shape_info) in [2, 4], + "gemm output shape should be 2 or 4 dims.") + if len(shape_info) == 4: + mace_check(shape_info[2] == 1 and shape_info[3] == 1, + "gemm's 4-dim output shape should be [*, * , 1, 1]") + else: + shape_info = [shape_info[0], shape_info[1], 1, 1] + output_shape.dims.extend(shape_info) + + return op diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index 6248d56d39ac153fdac4f1b1d10f55311b05d65f..39bcd3b88eb86d7024f4ee9d20612d46ac6cf057 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -26,6 +26,7 @@ from mace.python.tools.converter_tool.base_converter import PaddingMode from mace.python.tools.converter_tool.base_converter import ActivationType from mace.python.tools.converter_tool.base_converter import EltwiseType from mace.python.tools.converter_tool.base_converter import FrameworkType +from 
mace.python.tools.converter_tool.base_converter import ReduceType from mace.python.tools.converter_tool.base_converter import DataFormat from mace.python.tools.converter_tool.base_converter import FilterFormat from mace.python.tools.converter_tool.base_converter import MaceOp @@ -465,15 +466,6 @@ class TensorflowConverter(base_converter.ConverterInterface): "Mace only supports dilation == 1 conv2d_transpose.") mace_check(len(tf_op.inputs) >= 3, "deconv should have (>=) 3 inputs.") - output_shape_arg = op.arg.add() - output_shape_arg.name = MaceKeyword.mace_output_shape_str - # if tf_op.inputs[0].op.type == TFOpType.Const.name: - # output_shape_value = \ - # tf_op.inputs[0].eval().astype(np.int32).flat - # output_shape_arg.ints.extend(output_shape_value) - # else: - # output_shape_value = {} - # output_shape_arg.ints.extend(output_shape_value) del op.input[:] op.input.extend([tf_op.inputs[2].name, tf_op.inputs[1].name, @@ -810,7 +802,12 @@ class TensorflowConverter(base_converter.ConverterInterface): op = self.convert_general_op(tf_op) del op.input[1:] - op.type = MaceOp.ReduceMean.name + op.type = MaceOp.Reduce.name + + reduce_type_arg = op.arg.add() + reduce_type_arg.name = MaceKeyword.mace_reduce_type_str + reduce_type_arg.i = ReduceType.MEAN + axis_arg = op.arg.add() axis_arg.name = MaceKeyword.mace_axis_str if len(tf_op.inputs) > 1: diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 0ad3dbd2ba2e86b8a088b4df685e700f4a9f5af1..98e54d0ecdcabc24bc125468220b92909e98ef17 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -352,21 +352,26 @@ class Transformer(base_converter.ConverterInterface): if elttype == EltwiseType.SQR_DIFF.value and\ self.consumer_count(op.output[0]) == 1: consumer_op = self._consumers[op.output[0]][0] - axis = ConverterUtil.get_arg( - consumer_op, - MaceKeyword.mace_axis_str).ints - keep_dims = ConverterUtil.get_arg( - consumer_op, - MaceKeyword.mace_keepdims_str).i - if consumer_op.type == MaceOp.ReduceMean.name and\ - len(consumer_op.input) == 1 and \ - axis[0] == 1 and axis[1] == 2 and keep_dims != 0: - print("Fold SquaredDiff ReduceMean: %s" % op.name) - op.type = MaceOp.SqrDiffMean.name - op.output[0] = consumer_op.output[0] - self.replace_quantize_info(op, consumer_op) - self.safe_remove_node(consumer_op, op) - return True + if consumer_op.type == MaceOp.Reduce.name: + axis = ConverterUtil.get_arg( + consumer_op, + MaceKeyword.mace_axis_str).ints + keep_dims = ConverterUtil.get_arg( + consumer_op, + MaceKeyword.mace_keepdims_str).i + reduce_type = ConverterUtil.get_arg( + consumer_op, + MaceKeyword.mace_reduce_type_str).i + if reduce_type == ReduceType.MEAN and\ + len(consumer_op.input) == 1 and\ + axis[0] == 1 and axis[1] == 2 and\ + keep_dims > 0: + print("Fold SquaredDiff Reduce: %s" % op.name) + op.type = MaceOp.SqrDiffMean.name + op.output[0] = consumer_op.output[0] + self.replace_quantize_info(op, consumer_op) + self.safe_remove_node(consumer_op, op) + return True return False @@ -1005,13 +1010,13 @@ class Transformer(base_converter.ConverterInterface): 'only support squeeze at at [2, 3]') arg.ints[:] = [1, 2] - elif op.type == MaceOp.ReduceMean.name: + elif op.type == MaceOp.Reduce.name: for arg in op.arg: if arg.name == MaceKeyword.mace_axis_str: if ConverterUtil.data_format( op) == DataFormat.NCHW \ and self._target_data_format == DataFormat.NHWC: # noqa - print("Transpose reduce mean args: %s(%s)" + print("Transpose reduce 
args: %s(%s)" % (op.name, op.type)) reduce_axises = list(arg.ints) new_axises = [] diff --git a/repository/opencl-kernel/opencl_kernel_configure.bzl b/repository/opencl-kernel/opencl_kernel_configure.bzl index 65cc5635712226aaed8da790b3b63deb6408d1f1..bab88f5398b02e922b9e3a03e93fd0e150635dad 100644 --- a/repository/opencl-kernel/opencl_kernel_configure.bzl +++ b/repository/opencl-kernel/opencl_kernel_configure.bzl @@ -48,7 +48,7 @@ def _opencl_encrypt_kernel_impl(repository_ctx): unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/pad.cl")) unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/pooling.cl")) unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/pooling_buffer.cl")) - unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/reduce_mean.cl")) + unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/reduce.cl")) unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/resize_bicubic.cl")) unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/resize_bilinear.cl")) unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/split.cl")) diff --git a/tools/common.py b/tools/common.py index 2e197d8fb634be21a1e8c2320be0ac66949be165..fff51e4d080ae27f119acae1bfc80a0e7a90f526 100644 --- a/tools/common.py +++ b/tools/common.py @@ -362,6 +362,7 @@ class YAMLKeyword(object): validation_threshold = 'validation_threshold' graph_optimize_options = 'graph_optimize_options' # internal use for now cl_mem_type = 'cl_mem_type' + backend = 'backend' ################################ diff --git a/tools/converter.py b/tools/converter.py index cc1377161f4c5014fefd02114d0fa6bc493aef9d..4af30403ad82550d86ec7b2d81c2129b732cdc2c 100644 --- a/tools/converter.py +++ b/tools/converter.py @@ -55,6 +55,7 @@ ModelFormatStrs = [ PlatformTypeStrs = [ "tensorflow", "caffe", + "onnx", ] PlatformType = Enum('PlatformType', [(ele, ele) for ele in PlatformTypeStrs], type=str) @@ -469,6 +470,10 @@ def format_model_config(flags): else: subgraph[YAMLKeyword.validation_inputs_data] = \ validation_inputs_data + + onnx_backend = subgraph.get( + YAMLKeyword.backend, "tensorflow") + subgraph[YAMLKeyword.backend] = onnx_backend input_ranges = subgraph.get( YAMLKeyword.input_ranges, []) if not isinstance(input_ranges, list): diff --git a/tools/device.py b/tools/device.py index 655d90d012e65a805c2655dc0fb413ae2ed9b8df..63359dbb00b766382739ae03d579072d4f911dc5 100644 --- a/tools/device.py +++ b/tools/device.py @@ -572,7 +572,8 @@ class DeviceWrapper: YAMLKeyword.input_data_types], caffe_env=flags.caffe_env, validation_threshold=subgraphs[0][ - YAMLKeyword.validation_threshold][validate_type] + YAMLKeyword.validation_threshold][validate_type], + backend=subgraphs[0][YAMLKeyword.backend] ) if flags.report and flags.round > 0: tuned = is_tuned and device_type == DeviceType.GPU diff --git a/tools/onnx_optimizer.py b/tools/onnx_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..38b271261178895737014a535c6a9063f3824178 --- /dev/null +++ b/tools/onnx_optimizer.py @@ -0,0 +1,50 @@ +# Copyright 2018 Xiaomi, Inc. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import onnx
+import sys
+from onnx import optimizer
+
+
+# Usage: python onnx_optimizer.py model.onnx model_opt.onnx
+
+
+def main():
+    if len(sys.argv) != 3:
+        print("Usage: python onnx_optimizer.py model.onnx model_opt.onnx")
+        sys.exit(1)
+    in_path = sys.argv[1]
+    out_path = sys.argv[2]
+    original_model = onnx.load(in_path)
+    print("Start optimizing the ONNX model for inference:")
+    passes = ['eliminate_identity',
+              'fuse_consecutive_squeezes',
+              'fuse_consecutive_transposes',
+              'eliminate_nop_pad',
+              'eliminate_nop_transpose',
+              'eliminate_unused_initializer',
+              'extract_constant_to_initializer',
+              'fuse_add_bias_into_conv',
+              'fuse_bn_into_conv',
+              'fuse_transpose_into_gemm']
+    for i in range(len(passes)):
+        print("%d. %s" % (i, passes[i]))
+    optimized_model = optimizer.optimize(original_model, passes)
+    onnx.save_model(optimized_model, out_path)
+    print("Optimization finished!")
+    print("Please check the new model in: %s" % out_path)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/tools/sh_commands.py b/tools/sh_commands.py
index da8e96054d5e4e7ea24b8e95dee58a511372c23d..4be75b79fdcb345b2460e9818bdd9670c929968a 100644
--- a/tools/sh_commands.py
+++ b/tools/sh_commands.py
@@ -621,7 +621,8 @@ def validate_model(abi,
                    caffe_env,
                    input_file_name="model_input",
                    output_file_name="model_out",
-                   validation_threshold=0.9):
+                   validation_threshold=0.9,
+                   backend="tensorflow"):
     six.print_("* Validate with %s" % platform)
     if abi != "host":
         for output_name in output_nodes:
@@ -638,7 +639,14 @@ def validate_model(abi,
                  "%s/%s" % (model_output_dir, output_file_name),
                  device_type, ":".join(input_shapes), ":".join(output_shapes),
                  ",".join(input_nodes), ",".join(output_nodes),
-                 validation_threshold, ",".join(input_data_types))
+                 validation_threshold, ",".join(input_data_types), backend)
+    elif platform == "onnx":
+        validate(platform, model_file_path, "",
+                 "%s/%s" % (model_output_dir, input_file_name),
+                 "%s/%s" % (model_output_dir, output_file_name), device_type,
+                 ":".join(input_shapes), ":".join(output_shapes),
+                 ",".join(input_nodes), ",".join(output_nodes),
+                 validation_threshold, ",".join(input_data_types), backend)
     elif platform == "caffe":
         image_name = "mace-caffe:latest"
         container_name = "mace_caffe_validator"
@@ -654,7 +662,7 @@ def validate_model(abi,
                      device_type, ":".join(input_shapes),
                      ":".join(output_shapes),
                      ",".join(input_nodes), ",".join(output_nodes),
-                     validation_threshold, ",".join(input_data_types))
+                     validation_threshold, ",".join(input_data_types), backend)
     elif caffe_env == common.CaffeEnvType.DOCKER:
         docker_image_id = sh.docker("images", "-q", image_name)
         if not docker_image_id:
@@ -720,6 +728,7 @@ def validate_model(abi,
                    "--output_shape=%s" % ":".join(output_shapes),
                    "--validation_threshold=%f" % validation_threshold,
                    "--input_data_type=%s" % ",".join(input_data_types),
+                   "--backend=%s" % backend,
                    _fg=True)
         six.print_("Validation done!\n")
diff --git a/tools/validate.py b/tools/validate.py
index 3c85fb553f4329931a8178adafeb94d49ebbb4ac..139a3ee9150f0cf31ea48514403befb1fad957d6 100644
--- a/tools/validate.py
+++ b/tools/validate.py
@@ -21,6 +21,10 @@ import re
 
 import common
 
+import onnx
+from onnx import helper
+from onnx import TensorProto
+
 # Validation Flow:
 # 1. Generate input data
 # 2. Use mace_run to run model on phone.
@@ -190,9 +194,64 @@ def validate_caffe_model(platform, device_type, model_file, input_file,
                        value, validation_threshold)
 
 
+def validate_onnx_model(platform, device_type, model_file, input_file,
+                        mace_out_file, input_names, input_shapes,
+                        output_names, output_shapes, validation_threshold,
+                        input_data_types, backend):
+    if backend == "tensorflow":
+        from onnx_tf.backend import prepare
+        print("Validate on the ONNX TensorFlow backend.")
+    elif backend == "caffe2" or backend == "pytorch":
+        from caffe2.python.onnx.backend import prepare
+        print("Validate on the ONNX Caffe2 backend.")
+    else:
+        common.MaceLogger.error(
+            VALIDATION_MODULE,
+            "onnx backend framework '" + backend + "' is invalid.")
+    if not os.path.isfile(model_file):
+        common.MaceLogger.error(
+            VALIDATION_MODULE,
+            "Input graph file '" + model_file + "' does not exist!")
+    model = onnx.load(model_file)
+    input_dict = {}
+    for i in range(len(input_names)):
+        input_value = load_data(common.formatted_file_name(input_file,
+                                                           input_names[i]),
+                                input_data_types[i])
+        # MACE dumps NHWC data; the ONNX backends expect NCHW
+        input_value = input_value.reshape(input_shapes[i]).transpose((0, 3,
+                                                                      1, 2))
+        input_dict[input_names[i]] = input_value
+    onnx_outputs = []
+    for i in range(len(output_names)):
+        out_shape = output_shapes[i]
+        if len(out_shape) == 4:
+            out_shape[1], out_shape[2], out_shape[3] = \
+                out_shape[3], out_shape[1], out_shape[2]
+        onnx_outputs.append(
+            helper.make_tensor_value_info(output_names[i],
+                                          TensorProto.FLOAT,
+                                          out_shape))
+    model.graph.output.extend(onnx_outputs)
+    rep = prepare(model)
+
+    output_values = rep.run(input_dict)
+    for i in range(len(output_names)):
+        out_name = output_names[i]
+        value = output_values[out_name].flatten()
+        out_shape = output_shapes[i]
+        if len(out_shape) == 4:
+            # convert the backend's NCHW output back to NHWC for comparison
+            value = value.reshape(out_shape).transpose((0, 2, 3, 1))
+        output_file_name = common.formatted_file_name(mace_out_file,
+                                                      output_names[i])
+        mace_out_value = load_data(output_file_name)
+        compare_output(platform, device_type, output_names[i],
+                       mace_out_value, value,
+                       validation_threshold)
+
+
 def validate(platform, model_file, weight_file, input_file, mace_out_file,
              device_type, input_shape, output_shape, input_node, output_node,
-             validation_threshold, input_data_type):
+             validation_threshold, input_data_type, backend):
     input_names = [name for name in input_node.split(',')]
     input_shape_strs = [shape for shape in input_shape.split(':')]
     input_shapes = [[int(x) for x in shape.split(',')]
@@ -217,6 +276,15 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
                              mace_out_file, weight_file, input_names,
                              input_shapes, output_names, output_shapes,
                              validation_threshold)
+    elif platform == 'onnx':
+        output_shape_strs = [shape for shape in output_shape.split(':')]
+        output_shapes = [[int(x) for x in shape.split(',')]
+                         for shape in output_shape_strs]
+        validate_onnx_model(platform, device_type, model_file, input_file,
+                            mace_out_file, input_names, input_shapes,
+                            output_names, output_shapes,
+                            validation_threshold,
+                            input_data_types, backend)
 
 
 def parse_args():
@@ -259,6 +327,11 @@ def parse_args():
     parser.add_argument(
         "--validation_threshold", type=float, default=0.995,
         help="validation similarity threshold")
+    parser.add_argument(
+        "--backend",
+        type=str,
+        default="tensorflow",
+        help="onnx backend framework")
     return parser.parse_known_args()
@@ -276,4 +349,5 @@ if __name__ == '__main__':
              FLAGS.input_node,
             FLAGS.output_node,
             FLAGS.validation_threshold,
-            FLAGS.input_data_type)
+            FLAGS.input_data_type,
+            FLAGS.backend)
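
For reference, the ONNX validation path added in tools/validate.py boils down to: load the model, register the expected output tensors, run it through the chosen ONNX backend, and compare the NCHW results (transposed back to NHWC) against the MACE output. Below is a minimal standalone sketch of that backend flow, assuming the onnx and onnx-tf packages are installed; the model path, the input name "input", the output name "prob" and the shapes are placeholders, not values taken from this patch.

    import numpy as np
    import onnx
    # for backend=caffe2/pytorch the patch imports prepare from
    # caffe2.python.onnx.backend instead
    from onnx_tf.backend import prepare

    # load a previously optimized model (placeholder path)
    model = onnx.load("model_opt.onnx")
    rep = prepare(model)

    # MACE dumps NHWC data, so transpose to the NCHW layout ONNX expects
    nhwc = np.random.rand(1, 224, 224, 3).astype(np.float32)
    nchw = nhwc.transpose((0, 3, 1, 2))

    # run the backend with a name-to-array dict, as validate_onnx_model does
    outputs = rep.run({"input": nchw})
    print(outputs["prob"].shape)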