diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index d3cbcd4ce6cccc0703c95ac6bb17b8a84f1f2cf8..3130fd697bf85fa6cb4ce7bea9571635a2bc1d5d 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -3,6 +3,7 @@ stages: - pycodestyle - platform_compitable_tests - ops_test + - api_test - ops_benchmark - extra_tests @@ -21,7 +22,13 @@ ops_test: stage: ops_test script: - if [ -z "$TARGET_SOCS" ]; then TARGET_SOCS=random; fi - - python tools/bazel_adb_run.py --target="//mace/ops:ops_test" --run_target=True --stdout_processor=unittest_stdout_processor --target_abis=armeabi-v7a,arm64-v8a --target_socs=$TARGET_SOCS + - python tools/bazel_adb_run.py --target="//mace/ops:ops_test" --run_target=True --stdout_processor=unittest_stdout_processor --target_abis=armeabi-v7a,arm64-v8a --target_socs=$TARGET_SOCS + +api_test: + stage: api_test + script: + - if [ -z "$TARGET_SOCS" ]; then TARGET_SOCS=random; fi + - python tools/bazel_adb_run.py --target="//mace/test:mace_api_test" --run_target=True --stdout_processor=unittest_stdout_processor --target_abis=armeabi-v7a,arm64-v8a --target_socs=$TARGET_SOCS ops_benchmark: stage: ops_benchmark diff --git a/mace/core/mace.cc b/mace/core/mace.cc index 76e243b138616b5dffc3cac8c7072a6bf3e18000..04f66bac031653f2c00daece8e47320b5208eb42 100644 --- a/mace/core/mace.cc +++ b/mace/core/mace.cc @@ -178,6 +178,9 @@ MaceStatus MaceEngine::Impl::Run( std::vector input_tensors; std::vector output_tensors; for (auto &input : inputs) { + MACE_CHECK(input.second.shape().size() == 4, + "The Inputs' shape must be 4-dimension with NHWC format," + " please use 1 to fill missing dimensions"); Tensor *input_tensor = ws_->GetTensor(MakeString("mace_input_node_", input.first, ":0")); input_tensor->Resize(input.second.shape()); @@ -190,6 +193,11 @@ MaceStatus MaceEngine::Impl::Run( input_tensors.push_back(input_tensor); } for (auto &output : *outputs) { + if (device_type_ == DeviceType::OPENCL) { + MACE_CHECK(output.second.shape().size() == 4, + "The outputs' shape must be 4-dimension with NHWC format," + " please use 1 to fill missing dimensions"); + } Tensor *output_tensor = ws_->GetTensor(MakeString("mace_output_node_", output.first + ":0")); output_tensors.push_back(output_tensor); diff --git a/mace/core/operator.cc b/mace/core/operator.cc index b40c29794d55397a5c17c317bd88791b7ed52a7e..1aedbe709f157b9aea61ea8e2a9dbf5e0a87a611 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -86,6 +86,7 @@ extern void Register_Conv2D(OperatorRegistry *op_registry); extern void Register_CWise(OperatorRegistry *op_registry); extern void Register_DepthToSpace(OperatorRegistry *op_registry); extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry); +extern void Register_Dequantize(OperatorRegistry *op_registry); extern void Register_Eltwise(OperatorRegistry *op_registry); extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry); extern void Register_FullyConnected(OperatorRegistry *op_registry); @@ -98,7 +99,9 @@ extern void Register_Pad(OperatorRegistry *op_registry); extern void Register_Pooling(OperatorRegistry *op_registry); extern void Register_Proposal(OperatorRegistry *op_registry); extern void Register_PSROIAlign(OperatorRegistry *op_registry); +extern void Register_Quantize(OperatorRegistry *op_registry); extern void Register_ReOrganize(OperatorRegistry *op_registry); +extern void Register_Requantize(OperatorRegistry *op_registry); extern void Register_Reshape(OperatorRegistry *op_registry); extern void Register_ResizeBilinear(OperatorRegistry *op_registry); extern void 
Register_Slice(OperatorRegistry *op_registry); @@ -124,6 +127,7 @@ OperatorRegistry::OperatorRegistry() { ops::Register_CWise(this); ops::Register_DepthToSpace(this); ops::Register_DepthwiseConv2d(this); + ops::Register_Dequantize(this); ops::Register_Eltwise(this); ops::Register_FoldedBatchNorm(this); ops::Register_FullyConnected(this); @@ -136,6 +140,8 @@ OperatorRegistry::OperatorRegistry() { ops::Register_Pooling(this); ops::Register_Proposal(this); ops::Register_PSROIAlign(this); + ops::Register_Quantize(this); + ops::Register_Requantize(this); ops::Register_ReOrganize(this); ops::Register_Reshape(this); ops::Register_ResizeBilinear(this); diff --git a/mace/core/operator.h b/mace/core/operator.h index 387a41effe54afbd909154e076b6eab8618909f5..037aa1e0d0873a0c6de896c501f7ea8cc6f6f79d 100644 --- a/mace/core/operator.h +++ b/mace/core/operator.h @@ -108,12 +108,25 @@ class Operator : public OperatorBase { inputs_.push_back(tensor); } - for (const std::string &output_str : operator_def.output()) { + for (size_t i = 0; i < operator_def.output().size(); ++i) { + const std::string output_str = operator_def.output()[i]; if (ws->HasTensor(output_str)) { outputs_.push_back(ws->GetTensor(output_str)); } else { + MACE_CHECK( + operator_def.output_type().size() == 0 + || operator_def.output().size() == operator_def.output_type().size(), + "operator output size != operator output type size", + operator_def.output().size(), + operator_def.output_type().size()); + DataType output_type; + if (i < operator_def.output_type().size()) { + output_type = operator_def.output_type()[i]; + } else { + output_type = DataTypeToEnum::v(); + } outputs_.push_back(MACE_CHECK_NOTNULL(ws->CreateTensor( - output_str, GetDeviceAllocator(D), DataTypeToEnum::v()))); + output_str, GetDeviceAllocator(D), output_type))); } } } diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc index 0c681b14b70d2df9c81773652413b0a140513358..7a3bd994fa8baaae98a5878f92c73c0ef6ca74ae 100644 --- a/mace/core/workspace.cc +++ b/mace/core/workspace.cc @@ -81,15 +81,19 @@ void Workspace::LoadModelTensor(const NetDef &net_def, DeviceType type) { } VLOG(3) << "Model data size: " << model_data_size; - if (type == DeviceType::CPU || type == DeviceType::NEON) { - tensor_buffer_ = std::unique_ptr( - new Buffer(GetDeviceAllocator(type), model_data_ptr, model_data_size)); - } else { - tensor_buffer_ = std::unique_ptr( - new Buffer(GetDeviceAllocator(type), model_data_size)); - tensor_buffer_->Map(nullptr); - tensor_buffer_->Copy(model_data_ptr, 0, model_data_size); - tensor_buffer_->UnMap(); + if (model_data_size > 0) { + if (type == DeviceType::CPU || type == DeviceType::NEON) { + tensor_buffer_ = std::unique_ptr( + new Buffer(GetDeviceAllocator(type), + model_data_ptr, + model_data_size)); + } else { + tensor_buffer_ = std::unique_ptr( + new Buffer(GetDeviceAllocator(type), model_data_size)); + tensor_buffer_->Map(nullptr); + tensor_buffer_->Copy(model_data_ptr, 0, model_data_size); + tensor_buffer_->UnMap(); + } } for (auto &const_tensor : net_def.tensors()) { diff --git a/mace/examples/example.cc b/mace/examples/example.cc index aa852fdab4ece6bb053e9efbe11030ba7164fec3..52809a4fc217de50ceca7df2635836625fc7cacf 100644 --- a/mace/examples/example.cc +++ b/mace/examples/example.cc @@ -163,6 +163,8 @@ bool RunModel(const std::vector &input_names, static_cast(FLAGS_gpu_priority_hint)); } + // DO NOT USE tmp directory. 
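+  // (the tmp path below is only usable by this command-line example;
+  //  a regular Android app cannot rely on /data/local/tmp being writable)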
+ // please use APP's own directory const std::string kernel_file_path = "/data/local/tmp/mace_run/cl"; diff --git a/mace/kernels/BUILD b/mace/kernels/BUILD index 4eb4b8508409e3c1a57a28cb4a1f198409573334..50ab5c954efdb6fb2b61dd62492b9a614b8a6fc4 100644 --- a/mace/kernels/BUILD +++ b/mace/kernels/BUILD @@ -28,9 +28,12 @@ cc_library( "opencl/*.h", "arm/*.h", ]), - copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + if_android_armv7(["-mfpu=neon -mfloat-abi=softfp"]) + if_android([ - "-DMACE_ENABLE_OPENCL", - ]) + if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]), + copts = if_openmp_enabled(["-fopenmp"]) + + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + + if_android_armv7(["-mfpu=neon"]) + + if_android_armv7(["-mfloat-abi=softfp"]) + + if_android(["-DMACE_ENABLE_OPENCL"]) + + if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]), linkopts = if_android(["-lm"]), deps = [ "//mace/core", @@ -48,9 +51,12 @@ cc_test( "opencl/*_test.cc", ], ), - copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + if_android_armv7(["-mfpu=neon -mfloat-abi=softfp"]) + if_android([ - "-DMACE_ENABLE_OPENCL", - ]) + if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]), + copts = if_openmp_enabled(["-fopenmp"]) + + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + + if_android_armv7(["-mfpu=neon"]) + + if_android_armv7(["-mfloat-abi=softfp"]) + + if_android(["-DMACE_ENABLE_OPENCL"]) + + if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]), linkopts = ["-fopenmp"], linkstatic = 1, deps = [ diff --git a/mace/kernels/arm/conv_2d.cc b/mace/kernels/arm/conv_2d.cc index 3b99890ea7aa85646ec14fd8d73f739af7894724..6bf952e222ac6738b8564cc940ce782b5eb6468b 100644 --- a/mace/kernels/arm/conv_2d.cc +++ b/mace/kernels/arm/conv_2d.cc @@ -362,14 +362,14 @@ void Conv2dFunctor::operator()(const Tensor *input, }; } else if (use_neon_1x1_s1) { conv_func = [=](const float *pad_input, float *pad_output) { - Conv2dNeonK1x1S1(input_data, + Conv2dNeonK1x1S1(pad_input, filter_data, batch, - height, - width, + extra_input_height, + extra_input_width, input_channels, channels, - output_data); + pad_output); }; } else { conv_func = [=](const float *pad_input, float *pad_output) { diff --git a/mace/kernels/arm/fully_connected.cc b/mace/kernels/arm/fully_connected.cc index 0944480e82403bccdd4c7a8bab16d71d864a2ddd..5df39ce5825afcbeca34be911e5edee0e2babdaf 100644 --- a/mace/kernels/arm/fully_connected.cc +++ b/mace/kernels/arm/fully_connected.cc @@ -34,10 +34,10 @@ void FullyConnectedFunctordata(); float *output_ptr = output->mutable_data(); + Gemv(weight_ptr, input_ptr, N, input_size, output_size, output_ptr); for (int i = 0; i < N; ++i) { - Gemv(weight_ptr, input_ptr, input_size, output_size, output_ptr); for (int j = 0; j < output_size; ++j) { - output_ptr[j] += bias_ptr[j]; + output_ptr[j + i * output_size] += bias_ptr[j]; } } diff --git a/mace/kernels/gemm.cc b/mace/kernels/gemm.cc index cb11fa5c741e3dac8dab603ada9f973f89f623f3..b252949af135a1238f926c20174f72268530567b 100644 --- a/mace/kernels/gemm.cc +++ b/mace/kernels/gemm.cc @@ -566,6 +566,7 @@ inline void GemmTile(const float *A, } } // namespace +// A: height x K, B: K x width, C: height x width void Gemm(const float *A, const float *B, const index_t batch, @@ -573,6 +574,12 @@ void Gemm(const float *A, const index_t K, const index_t width, float *C) { + if (width == 1) { + for (index_t b = 0; b < batch; ++b) { + Gemv(A + b * height * K, B + b * K, 1, K, height, C + b * height); + } + return; + } memset(C, 0, sizeof(float) * batch * height * width); @@ -628,6 +635,7 
@@ void Gemm(const float *A, } // n } +// A: height x K, B: K x width, C: height x width void GemmRef(const float *A, const float *B, const index_t height, @@ -647,19 +655,24 @@ void GemmRef(const float *A, void GemvRef(const float *m_ptr, const float *v_ptr, + const index_t batch, const index_t width, const index_t height, float *out_ptr) { - memset(out_ptr, 0, sizeof(float) * height); - for (int h = 0; h < height; ++h) { - for (int w = 0; w < width; ++w) { - out_ptr[h] += v_ptr[w] * m_ptr[h * width + w]; + memset(out_ptr, 0, sizeof(float) * height * batch); + for (int b = 0; b < batch; ++b) { + for (int h = 0; h < height; ++h) { + for (int w = 0; w < width; ++w) { + out_ptr[h + b * height] += v_ptr[w + b * width] * m_ptr[h * width + w]; + } } } } +// M: height x width, Vin: width x 1, Vout: height x 1 void Gemv(const float *m_ptr, const float *v_ptr, + const index_t batch, const index_t width, const index_t height, float *out_ptr) { @@ -669,88 +682,90 @@ void Gemv(const float *m_ptr, index_t remain_w = width - (width_d4 << 2); index_t remain_h = height - (height_d4 << 2); + for (index_t b = 0; b < batch; ++b) { #pragma omp parallel for - for (index_t h = 0; h < height_d4; ++h) { - const float *m_ptr0 = m_ptr + h * width * 4; - const float *m_ptr1 = m_ptr0 + width; - const float *m_ptr2 = m_ptr1 + width; - const float *m_ptr3 = m_ptr2 + width; - const float *v_ptr0 = v_ptr; - float *out_ptr0 = out_ptr + h * 4; - - float32x4_t vm0, vm1, vm2, vm3; - float32x4_t vv; - - float32x4_t vsum0 = vdupq_n_f32(0.f); - float32x4_t vsum1 = vdupq_n_f32(0.f); - float32x4_t vsum2 = vdupq_n_f32(0.f); - float32x4_t vsum3 = vdupq_n_f32(0.f); - - for (index_t w = 0; w < width_d4; ++w) { - vm0 = vld1q_f32(m_ptr0); - vm1 = vld1q_f32(m_ptr1); - vm2 = vld1q_f32(m_ptr2); - vm3 = vld1q_f32(m_ptr3); - vv = vld1q_f32(v_ptr0); - - vsum0 = vmlaq_f32(vsum0, vm0, vv); - vsum1 = vmlaq_f32(vsum1, vm1, vv); - vsum2 = vmlaq_f32(vsum2, vm2, vv); - vsum3 = vmlaq_f32(vsum3, vm3, vv); - - m_ptr0 += 4; - m_ptr1 += 4; - m_ptr2 += 4; - m_ptr3 += 4; - v_ptr0 += 4; - } - float sum0 = vaddvq_f32(vsum0); - float sum1 = vaddvq_f32(vsum1); - float sum2 = vaddvq_f32(vsum2); - float sum3 = vaddvq_f32(vsum3); - - // handle remaining w - for (index_t w = 0; w < remain_w; ++w) { - sum0 += m_ptr0[0] * v_ptr0[0]; - sum1 += m_ptr1[0] * v_ptr0[0]; - sum2 += m_ptr2[0] * v_ptr0[0]; - sum3 += m_ptr3[0] * v_ptr0[0]; - m_ptr0++; - m_ptr1++; - m_ptr2++; - m_ptr3++; - v_ptr0++; + for (index_t h = 0; h < height_d4; ++h) { + const float *m_ptr0 = m_ptr + h * width * 4; + const float *m_ptr1 = m_ptr0 + width; + const float *m_ptr2 = m_ptr1 + width; + const float *m_ptr3 = m_ptr2 + width; + const float *v_ptr0 = v_ptr + b * width; + float *out_ptr0 = out_ptr + h * 4 + b * height; + + float32x4_t vm0, vm1, vm2, vm3; + float32x4_t vv; + + float32x4_t vsum0 = vdupq_n_f32(0.f); + float32x4_t vsum1 = vdupq_n_f32(0.f); + float32x4_t vsum2 = vdupq_n_f32(0.f); + float32x4_t vsum3 = vdupq_n_f32(0.f); + + for (index_t w = 0; w < width_d4; ++w) { + vm0 = vld1q_f32(m_ptr0); + vm1 = vld1q_f32(m_ptr1); + vm2 = vld1q_f32(m_ptr2); + vm3 = vld1q_f32(m_ptr3); + vv = vld1q_f32(v_ptr0); + + vsum0 = vmlaq_f32(vsum0, vm0, vv); + vsum1 = vmlaq_f32(vsum1, vm1, vv); + vsum2 = vmlaq_f32(vsum2, vm2, vv); + vsum3 = vmlaq_f32(vsum3, vm3, vv); + + m_ptr0 += 4; + m_ptr1 += 4; + m_ptr2 += 4; + m_ptr3 += 4; + v_ptr0 += 4; + } + float sum0 = vaddvq_f32(vsum0); + float sum1 = vaddvq_f32(vsum1); + float sum2 = vaddvq_f32(vsum2); + float sum3 = vaddvq_f32(vsum3); + + // handle remaining w + 
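+      // (the vector loop above consumed width_d4 * 4 columns; this scalar
+      //  tail finishes the remaining width % 4 columns)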
for (index_t w = 0; w < remain_w; ++w) {
+          sum0 += m_ptr0[0] * v_ptr0[0];
+          sum1 += m_ptr1[0] * v_ptr0[0];
+          sum2 += m_ptr2[0] * v_ptr0[0];
+          sum3 += m_ptr3[0] * v_ptr0[0];
+          m_ptr0++;
+          m_ptr1++;
+          m_ptr2++;
+          m_ptr3++;
+          v_ptr0++;
+        }
+        *out_ptr0++ = sum0;
+        *out_ptr0++ = sum1;
+        *out_ptr0++ = sum2;
+        *out_ptr0++ = sum3;
     }
-    *out_ptr0++ = sum0;
-    *out_ptr0++ = sum1;
-    *out_ptr0++ = sum2;
-    *out_ptr0++ = sum3;
-  }
-  // handle remaining h
-  index_t remain_start_height = height_d4 << 2;
+    // handle remaining h
+    index_t remain_start_height = height_d4 << 2;
 #pragma omp parallel for
-  for (index_t h = 0; h < remain_h; ++h) {
-    float32x4_t vsum0 = vdupq_n_f32(0.f);
-    const float *m_ptr0 = m_ptr + (h + remain_start_height) * width;
-    const float *v_ptr0 = v_ptr;
-    for (index_t w = 0; w < width_d4; ++w) {
-      float32x4_t vm = vld1q_f32(m_ptr0);
-      float32x4_t vv = vld1q_f32(v_ptr0);
-      vsum0 = vmlaq_f32(vsum0, vm, vv);
-      m_ptr0 += 4;
-      v_ptr0 += 4;
-    }
-    float sum = vaddvq_f32(vsum0);
-    for (index_t w = 0; w < remain_w; ++w) {
-      sum += m_ptr0[0] * v_ptr0[0];
-      m_ptr0++;
-      v_ptr0++;
+    for (index_t h = 0; h < remain_h; ++h) {
+      float32x4_t vsum0 = vdupq_n_f32(0.f);
+      const float *m_ptr0 = m_ptr + (h + remain_start_height) * width;
+      const float *v_ptr0 = v_ptr + b * width;
+      for (index_t w = 0; w < width_d4; ++w) {
+        float32x4_t vm = vld1q_f32(m_ptr0);
+        float32x4_t vv = vld1q_f32(v_ptr0);
+        vsum0 = vmlaq_f32(vsum0, vm, vv);
+        m_ptr0 += 4;
+        v_ptr0 += 4;
+      }
+      float sum = vaddvq_f32(vsum0);
+      for (index_t w = 0; w < remain_w; ++w) {
+        sum += m_ptr0[0] * v_ptr0[0];
+        m_ptr0++;
+        v_ptr0++;
+      }
+      out_ptr[remain_start_height + h + b * height] = sum;
     }
-    out_ptr[remain_start_height + h] = sum;
   }
 #else
-  GemvRef(m_ptr, v_ptr, width, height, out_ptr);
+  GemvRef(m_ptr, v_ptr, batch, width, height, out_ptr);
 #endif
 }
diff --git a/mace/kernels/gemm.h b/mace/kernels/gemm.h
index ba4f812d74f11adbbbeba384b3b04c9746ed2309..e1fcfad6a51b9e4611e538709818cf5126311a5a 100644
--- a/mace/kernels/gemm.h
+++ b/mace/kernels/gemm.h
@@ -41,12 +41,14 @@ void GemmRef(const float *A,
 
 void Gemv(const float *m_ptr,
           const float *v_ptr,
+          const index_t batch,
           const index_t width,
           const index_t height,
           float *out_ptr);
 
 void GemvRef(const float *m_ptr,
              const float *v_ptr,
+             const index_t batch,
              const index_t width,
              const index_t height,
              float *out_ptr);
diff --git a/mace/kernels/gemm_test.cc b/mace/kernels/gemm_test.cc
index 217543ed9fd38cca186808471a7a652343f1a956..8400ca857dade6212f96ecfd4ff17b48316bbc81 100644
--- a/mace/kernels/gemm_test.cc
+++ b/mace/kernels/gemm_test.cc
@@ -70,8 +70,8 @@ TEST(GEMMTest, gemv) {
                 [&gen, &nd] { return nd(gen); });
 
-  kernels::Gemv(A.get(), B.get(), K, N, C.get());
-  kernels::GemvRef(A.get(), B.get(), K, N, C_ref.get());
+  kernels::Gemv(A.get(), B.get(), 1, K, N, C.get());
+  kernels::GemvRef(A.get(), B.get(), 1, K, N, C_ref.get());
 
   for (int i = 0; i < N; ++i) {
     EXPECT_NEAR(C_ref[i], C[i], 0.1);
diff --git a/mace/kernels/quantize.h b/mace/kernels/quantize.h
new file mode 100644
index 0000000000000000000000000000000000000000..1ffab4880df718266dc0fdc7cea7dd64fdc362da
--- /dev/null
+++ b/mace/kernels/quantize.h
@@ -0,0 +1,195 @@
+// Copyright 2018 Xiaomi, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_KERNELS_QUANTIZE_H_
+#define MACE_KERNELS_QUANTIZE_H_
+
+#include <algorithm>
+#include <cmath>
+#include <limits>
+
+#include "mace/core/future.h"
+#include "mace/core/tensor.h"
+
+namespace mace {
+namespace kernels {
+
+template <typename T>
+inline void AdjustRange(const float in_min_data,
+                        const float in_max_data,
+                        float *out_min_data,
+                        float *out_max_data) {
+  // Re-range so that the range covers 0.0f and
+  // 0.0f maps exactly onto an integer quantized value.
+  const float quantized_max = std::numeric_limits<T>::max();
+  float out_min = fminf(0.f, in_min_data);
+  float out_max = fmaxf(0.f, in_max_data);
+  if (out_min < 0.f) {
+    float stepsize = (in_max_data - in_min_data) / quantized_max;
+    float quantized_zero = -in_min_data / stepsize;
+    float quantized_zero_near_int = roundf(quantized_zero);
+    if (fabs(quantized_zero - quantized_zero_near_int) > 1e-6) {
+      if (quantized_zero < quantized_zero_near_int) {
+        // keep out_max fixed, and move out_min
+        stepsize = out_max / (quantized_max - quantized_zero_near_int);
+        out_min = out_max - quantized_max * stepsize;
+      } else {
+        // keep out_min fixed, and move out_max
+        stepsize = -out_min / quantized_zero_near_int;
+        out_max = out_min + quantized_max * stepsize;
+      }
+    }
+  }
+  *out_min_data = out_min;
+  *out_max_data = out_max;
+}
+
+template <typename T>
+inline T Saturate(float value) {
+  int rounded_value = static_cast<int>(value);
+  if (rounded_value <= std::numeric_limits<T>::lowest()) {
+    return std::numeric_limits<T>::lowest();
+  } else if (rounded_value >= std::numeric_limits<T>::max()) {
+    return std::numeric_limits<T>::max();
+  } else {
+    return static_cast<T>(rounded_value);
+  }
+}
+
+template <DeviceType D, typename T>
+struct QuantizeFunctor;
+
+template <>
+struct QuantizeFunctor<DeviceType::CPU, uint8_t> {
+  QuantizeFunctor() {}
+
+  void operator()(const Tensor *input,
+                  const Tensor *in_min,
+                  const Tensor *in_max,
+                  Tensor *output,
+                  Tensor *out_min,
+                  Tensor *out_max,
+                  StatsFuture *future) {
+    const float *input_data = input->data<float>();
+    const float in_min_data = in_min->data<float>()[0];
+    const float in_max_data = in_max->data<float>()[0];
+    uint8_t *output_data = output->mutable_data<uint8_t>();
+    float *out_min_data = out_min->mutable_data<float>();
+    float *out_max_data = out_max->mutable_data<float>();
+
+    AdjustRange<uint8_t>(in_min_data, in_max_data, out_min_data, out_max_data);
+    float recip_stepsize = 255.f / (out_max_data[0] - out_min_data[0]);
+    for (int i = 0; i < input->size(); ++i) {
+      output_data[i] = Saturate<uint8_t>(roundf(
+          (input_data[i] - in_min_data) * recip_stepsize));
+    }
+  }
+};
+
+template <DeviceType D, typename T>
+struct DequantizeFunctor;
+
+template <>
+struct DequantizeFunctor<DeviceType::CPU, uint8_t> {
+  DequantizeFunctor() {}
+
+  void operator()(const Tensor *input,
+                  const Tensor *in_min,
+                  const Tensor *in_max,
+                  Tensor *output,
+                  StatsFuture *future) {
+    const uint8_t *input_data = input->data<uint8_t>();
+    const float in_min_data = in_min->data<float>()[0];
+    const float in_max_data = in_max->data<float>()[0];
+    float *output_data = output->mutable_data<float>();
+
+    float stepsize = (in_max_data - in_min_data) / 255.0;
+    for (int i = 0; i < input->size(); ++i) {
+      output_data[i] = in_min_data + stepsize * input_data[i];
+    }
+  }
+};
+
+template <DeviceType D, typename T>
+struct RequantizeFunctor;
+
+template <>
+struct RequantizeFunctor<DeviceType::CPU, uint8_t> {
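+  // Requantize converts already-quantized int32 values (produced with input
+  // scale si) into uint8 over a freshly chosen output range; the scale
+  // algebra is spelled out in the comment inside operator() below.
+ 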
RequantizeFunctor() {} + + void operator()(const Tensor *input, + const Tensor *in_min, + const Tensor *in_max, + const Tensor *rerange_min, + const Tensor *rerange_max, + Tensor *output, + Tensor *out_min, + Tensor *out_max, + StatsFuture *future) { + const int *input_data = input->data(); + const float in_min_data = in_min->data()[0]; + const float in_max_data = in_max->data()[0]; + + float rerange_min_data; + float rerange_max_data; + int min_val = std::numeric_limits::max(); + int max_val = std::numeric_limits::lowest(); + double + si = (in_max_data - in_min_data) / std::numeric_limits::max(); + if (rerange_min == nullptr && rerange_max == nullptr) { + for (int i = 0; i < input->size(); ++i) { + min_val = std::min(min_val, input_data[i]); + max_val = std::max(max_val, input_data[i]); + } + rerange_min_data = min_val * si; + rerange_max_data = max_val * si; + } else { + rerange_min_data = rerange_min->data()[0]; + rerange_max_data = rerange_max->data()[0]; + } + + uint8_t *output_data = output->mutable_data(); + float *out_min_data = out_min->mutable_data(); + float *out_max_data = out_max->mutable_data(); + + AdjustRange(rerange_min_data, + rerange_max_data, + out_min_data, + out_max_data); + /** + * f = qi * si = min_o + qo * so + * => qo = (qi * si - min_o) / so + * = qi * (si/so) - min_o / so + * = qi * (si / so) + zo + * + * zo = -min_o / so + * + */ + float so = + (out_max_data[0] - out_min_data[0]) / std::numeric_limits::max(); + double step_ratio = si / so; + float quantized_out_zero = -out_min_data[0] / so; + + for (int i = 0; i < output->size(); ++i) { + output_data[i] = + Saturate(roundf( + quantized_out_zero + input_data[i] * step_ratio)); + } + } +}; + +} // namespace kernels +} // namespace mace + +#endif // MACE_KERNELS_QUANTIZE_H_ diff --git a/mace/kernels/transpose.h b/mace/kernels/transpose.h index 6854c5f91e69bd06e342eefe70b7f47d09ad4ffe..b5e029ed4e346b1487ac018e76ebb1356b23492d 100644 --- a/mace/kernels/transpose.h +++ b/mace/kernels/transpose.h @@ -37,31 +37,44 @@ struct TransposeFunctor { const T *input_data = input->data(); T *output_data = output->mutable_data(); - std::vector - in_stride{input_shape[1] * input_shape[2] * input_shape[3], - input_shape[2] * input_shape[3], input_shape[3], 1}; - std::vector - out_stride{output_shape[1] * output_shape[2] * output_shape[3], - output_shape[2] * output_shape[3], output_shape[3], 1}; + if (input->dim_size() == 2) { + MACE_CHECK(dims_[0] == 1 && dims_[1] == 0, "no need transform"); + index_t stride_i = input_shape[0]; + index_t stride_j = input_shape[1]; + for (int i = 0; i < input_shape[0]; ++i) { + for (int j = 0; j < input_shape[1]; ++j) { + output_data[j * stride_i + i] = input_data[i * stride_j + j]; + } + } + } else if (input->dim_size() == 4) { + std::vector + in_stride{input_shape[1] * input_shape[2] * input_shape[3], + input_shape[2] * input_shape[3], input_shape[3], 1}; + std::vector + out_stride{output_shape[1] * output_shape[2] * output_shape[3], + output_shape[2] * output_shape[3], output_shape[3], 1}; - std::vector idim(4, 0); - std::vector odim(4, 0); - for (odim[0] = 0; odim[0] < output_shape[0]; ++odim[0]) { - for (odim[1] = 0; odim[1] < output_shape[1]; ++odim[1]) { - for (odim[2] = 0; odim[2] < output_shape[2]; ++odim[2]) { - for (odim[3] = 0; odim[3] < output_shape[3]; ++odim[3]) { - idim[dims_[0]] = odim[0]; - idim[dims_[1]] = odim[1]; - idim[dims_[2]] = odim[2]; - idim[dims_[3]] = odim[3]; + std::vector idim(4, 0); + std::vector odim(4, 0); + for (odim[0] = 0; odim[0] < output_shape[0]; ++odim[0]) { 
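+        // dims_ maps each output axis to its source input axis, so idim
+        // (filled in the innermost loop) is the input coordinate that
+        // produces the current output coordinate odim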
+ for (odim[1] = 0; odim[1] < output_shape[1]; ++odim[1]) { + for (odim[2] = 0; odim[2] < output_shape[2]; ++odim[2]) { + for (odim[3] = 0; odim[3] < output_shape[3]; ++odim[3]) { + idim[dims_[0]] = odim[0]; + idim[dims_[1]] = odim[1]; + idim[dims_[2]] = odim[2]; + idim[dims_[3]] = odim[3]; - output_data[odim[0] * out_stride[0] + odim[1] * out_stride[1] - + odim[2] * out_stride[2] + odim[3]] = - input_data[idim[0] * in_stride[0] + idim[1] * in_stride[1] - + idim[2] * in_stride[2] + idim[3]]; + output_data[odim[0] * out_stride[0] + odim[1] * out_stride[1] + + odim[2] * out_stride[2] + odim[3]] = + input_data[idim[0] * in_stride[0] + idim[1] * in_stride[1] + + idim[2] * in_stride[2] + idim[3]]; + } } } } + } else { + MACE_NOT_IMPLEMENTED; } } diff --git a/mace/ops/BUILD b/mace/ops/BUILD index 131beceb222f34accb201e731205cb8b03425718..ba39f5af4e2caef921ea49184b9880abf1a14715 100644 --- a/mace/ops/BUILD +++ b/mace/ops/BUILD @@ -34,9 +34,12 @@ cc_library( ["*.h"], exclude = ["ops_test_util.h"], ), - copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + if_android_armv7(["-mfpu=neon -mfloat-abi=softfp"]) + if_android([ - "-DMACE_ENABLE_OPENCL", - ]) + if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]), + copts = if_openmp_enabled(["-fopenmp"]) + + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + + if_android_armv7(["-mfpu=neon"]) + + if_android_armv7(["-mfloat-abi=softfp"]) + + if_android(["-DMACE_ENABLE_OPENCL"]) + + if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]), deps = [ "//mace/kernels", ], @@ -49,9 +52,12 @@ cc_test( srcs = glob( ["*_test.cc"], ), - copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + if_android_armv7(["-mfpu=neon -mfloat-abi=softfp"]) + if_android([ - "-DMACE_ENABLE_OPENCL", - ]) + if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]), + copts = if_openmp_enabled(["-fopenmp"]) + + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + + if_android_armv7(["-mfpu=neon"]) + + if_android_armv7(["-mfloat-abi=softfp"]) + + if_android(["-DMACE_ENABLE_OPENCL"]) + + if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]), linkopts = ["-fopenmp"], linkstatic = 1, deps = [ @@ -65,9 +71,12 @@ cc_test( name = "ops_benchmark", testonly = 1, srcs = glob(["*_benchmark.cc"]), - copts = if_openmp_enabled(["-fopenmp"]) + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + if_android_armv7(["-mfpu=neon -mfloat-abi=softfp"]) + if_android([ - "-DMACE_ENABLE_OPENCL", - ]) + if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]), + copts = if_openmp_enabled(["-fopenmp"]) + + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + + if_android_armv7(["-mfpu=neon"]) + + if_android_armv7(["-mfloat-abi=softfp"]) + + if_android(["-DMACE_ENABLE_OPENCL"]) + + if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]), linkopts = ["-fopenmp"], linkstatic = 1, deps = [ diff --git a/mace/ops/batch_norm_test.cc b/mace/ops/batch_norm_test.cc index c1f0ca02dbe59510682f215ee194b3901db6243a..7f1bb0379865e8bb68b0fb86ceabcbea2090a4e0 100644 --- a/mace/ops/batch_norm_test.cc +++ b/mace/ops/batch_norm_test.cc @@ -429,7 +429,7 @@ TEST_F(BatchNormOpTest, NEONTest) { ExpectTensorNear(*net.GetOutput("OutputExptected"), *net.GetOutput("OutputNeon"), - 1e-5); + 1e-5, 1e-4); } } // namespace test diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index 219e4af343a34951f80c2838f8ba6d7ee6b17355..41a6546a3f3366ee2a406a04c3dd6039d7ee2192 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -826,7 +826,7 @@ static void TestNeonArbitraryPadConvNxN(const std::vector &shape, for (int kernel_size : {1, 3, 5}) { for (int stride 
: {1, 2}) { - if (stride < kernel_size) { + if (stride <= kernel_size) { func(kernel_size, kernel_size, stride, stride); } } diff --git a/mace/ops/fully_connected_test.cc b/mace/ops/fully_connected_test.cc index e994213a5754ba3cf350106bd9a3cca4fa6f64ab..daef740297092bc37789effa08cbc225eacaf908 100644 --- a/mace/ops/fully_connected_test.cc +++ b/mace/ops/fully_connected_test.cc @@ -337,6 +337,8 @@ TEST_F(FullyConnectedOpTest, TestNEON) { FullyConnectedTestNEON(1, 7, 7, 32, 16); FullyConnectedTestNEON(1, 7, 7, 512, 128); FullyConnectedTestNEON(1, 1, 1, 2048, 1024); + FullyConnectedTestNEON(3, 1, 1, 16, 8); + FullyConnectedTestNEON(3, 7, 7, 32, 16); } } // namespace test diff --git a/mace/ops/fused_conv_2d_test.cc b/mace/ops/fused_conv_2d_test.cc index d02953b2b608cf882b8f9e925ea0c722bc28f1e4..afe889bedc9437782d830a0e3dddfeb13f380ef1 100644 --- a/mace/ops/fused_conv_2d_test.cc +++ b/mace/ops/fused_conv_2d_test.cc @@ -375,90 +375,92 @@ TEST_F(FusedConv2dOpTest, OPENCLUnalignedConvNxNS12) { namespace { template -void TestHalfComplexConvNxNS12(const std::vector &shape) { +void TestHalfComplexConvNxNS12(const std::vector &shape, + const int kernel, const int stride, + Padding type) { testing::internal::LogToStderr(); - auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w, - Padding type) { - // generate random input - static unsigned int seed = time(NULL); - index_t batch = 3 + (rand_r(&seed) % 10); - index_t height = shape[0]; - index_t width = shape[1]; - index_t input_channels = shape[2] + (rand_r(&seed) % 10); - index_t output_channels = shape[3] + (rand_r(&seed) % 10); - // Construct graph - OpsTestNet net; - OpDefBuilder("FusedConv2D", "FusedConv2dTest") + // generate random input + srand(time(NULL)); + index_t batch = 3; + index_t height = shape[0]; + index_t width = shape[1]; + index_t input_channels = shape[2]; + index_t output_channels = shape[3]; + // Construct graph + OpsTestNet net; + OpDefBuilder("FusedConv2D", "FusedConv2dTest") .Input("Input") .Input("Filter") .Input("Bias") .Output("Output") - .AddIntsArg("strides", {stride_h, stride_w}) + .AddIntsArg("strides", {stride, stride}) .AddIntArg("padding", type) .AddIntsArg("dilations", {1, 1}) .Finalize(net.NewOperatorDef()); - std::vector float_input_data; - GenerateRandomRealTypeData({batch, height, width, input_channels}, - &float_input_data); - std::vector float_filter_data; - GenerateRandomRealTypeData( - {kernel_h, kernel_w, output_channels, input_channels}, + std::vector float_input_data; + GenerateRandomRealTypeData({batch, height, width, input_channels}, + &float_input_data); + std::vector float_filter_data; + GenerateRandomRealTypeData( + {kernel, kernel, output_channels, input_channels}, &float_filter_data); - std::vector float_bias_data; - GenerateRandomRealTypeData({output_channels}, &float_bias_data); - // Add input data - net.AddInputFromArray( + std::vector float_bias_data; + GenerateRandomRealTypeData({output_channels}, &float_bias_data); + // Add input data + net.AddInputFromArray( "Input", {batch, height, width, input_channels}, float_input_data); - net.AddInputFromArray( - "Filter", {kernel_h, kernel_w, output_channels, input_channels}, + net.AddInputFromArray( + "Filter", {kernel, kernel, output_channels, input_channels}, float_filter_data); - net.AddInputFromArray("Bias", {output_channels}, float_bias_data); - - // run on cpu - net.RunOp(); - // Check - Tensor expected; - expected.Copy(*net.GetOutput("Output")); + net.AddInputFromArray("Bias", {output_channels}, float_bias_data); - // run on 
gpu - BufferToImage(&net, "Input", "InputImage", - kernels::BufferType::IN_OUT_CHANNEL); - BufferToImage(&net, "Filter", "FilterImage", - kernels::BufferType::CONV2D_FILTER); - BufferToImage(&net, "Bias", "BiasImage", - kernels::BufferType::ARGUMENT); - - OpDefBuilder("FusedConv2D", "FusedConv2dTest") + // run on cpu + net.RunOp(); + // Check + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + // run on gpu + BufferToImage(&net, "Input", "InputImage", + kernels::BufferType::IN_OUT_CHANNEL); + BufferToImage(&net, "Filter", "FilterImage", + kernels::BufferType::CONV2D_FILTER); + BufferToImage(&net, "Bias", "BiasImage", + kernels::BufferType::ARGUMENT); + + OpDefBuilder("FusedConv2D", "FusedConv2dTest") .Input("InputImage") .Input("FilterImage") .Input("BiasImage") .Output("OutputImage") - .AddIntsArg("strides", {stride_h, stride_w}) + .AddIntsArg("strides", {stride, stride}) .AddIntArg("padding", type) .AddIntsArg("dilations", {1, 1}) .AddIntArg("T", static_cast(DataType::DT_HALF)) .Finalize(net.NewOperatorDef()); - // Run on device - net.RunOp(D); + // Run on device + net.RunOp(D); - ImageToBuffer(&net, "OutputImage", "OPENCLOutput", - kernels::BufferType::IN_OUT_CHANNEL); - - ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), - 1e-2, 1e-1); - }; + ImageToBuffer(&net, "OutputImage", "OPENCLOutput", + kernels::BufferType::IN_OUT_CHANNEL); - for (int kernel_size : {1, 3}) { - for (int stride : {1, 2}) { - func(kernel_size, kernel_size, stride, stride, VALID); - } - } + ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), + 1e-2, 1e-1); } } // namespace -TEST_F(FusedConv2dOpTest, OPENCLHalfAlignedConvNxNS12) { - TestHalfComplexConvNxNS12({32, 32, 32, 64}); +TEST_F(FusedConv2dOpTest, OPENCLHalfAlignedConv1x1S12) { + TestHalfComplexConvNxNS12({32, 32, 32, 64}, 1, 1, VALID); + TestHalfComplexConvNxNS12({31, 37, 31, 37}, 1, 1, SAME); + TestHalfComplexConvNxNS12({32, 32, 32, 64}, 1, 2, VALID); + TestHalfComplexConvNxNS12({31, 37, 31, 37}, 1, 2, SAME); +} +TEST_F(FusedConv2dOpTest, OPENCLHalfAlignedConv3x3S12) { + TestHalfComplexConvNxNS12({32, 32, 32, 64}, 3, 1, VALID); + TestHalfComplexConvNxNS12({31, 37, 31, 37}, 3, 1, SAME); + TestHalfComplexConvNxNS12({32, 32, 32, 64}, 3, 2, VALID); + TestHalfComplexConvNxNS12({31, 37, 31, 37}, 3, 2, SAME); } namespace { diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h index 3235c9027f16ffa0b250beaf9f64073f620ace78..09b2fa109d1aa25d6c7fcaeaa44dffcead3cab99 100644 --- a/mace/ops/ops_test_util.h +++ b/mace/ops/ops_test_util.h @@ -52,6 +52,11 @@ class OpDefBuilder { return *this; } + OpDefBuilder &OutputType(const std::vector &output_type) { + op_def_.set_output_type(output_type); + return *this; + } + OpDefBuilder AddIntArg(const std::string &name, const int value) { auto arg = op_def_.add_arg(); arg->set_name(name); @@ -283,6 +288,16 @@ class OpsTestNet { return RunOp(DeviceType::CPU); } + bool RunNet(const NetDef &net_def, const DeviceType device) { + device_ = device; + net_ = CreateNet(op_registry_, net_def, &ws_, device, NetMode::INIT); + if (!net_->Run()) { + return false; + } + net_ = CreateNet(op_registry_, net_def, &ws_, device); + return net_->Run(); + } + Tensor *GetOutput(const char *output_name) { return ws_.GetTensor(output_name); } @@ -451,7 +466,7 @@ struct Expector { auto a = x.data(); auto b = y.data(); for (int i = 0; i < x.size(); ++i) { - ExpectEqual(a(i), b(i)); + ExpectEqual(a[i], b[i]); } } @@ -489,12 +504,35 @@ struct Expector { } }; +template +struct Expector { + static void Equal(const 
EXP_TYPE &a, const RES_TYPE &b) { ExpectEqual(a, b); } + + static void Equal(const Tensor &x, const Tensor &y) { + ASSERT_EQ(x.dtype(), DataTypeToEnum::v()); + ASSERT_EQ(y.dtype(), DataTypeToEnum::v()); + AssertSameDims(x, y); + Tensor::MappingGuard x_mapper(&x); + Tensor::MappingGuard y_mapper(&y); + auto a = x.data(); + auto b = y.data(); + for (int i = 0; i < x.size(); ++i) { + ExpectEqual(a[i], b[i]); + } + } + + static void Near(const Tensor &x, const Tensor &y, + const double rel_err, + const double abs_err) { + Equal(x, y); + } +}; + + template void ExpectTensorNear(const Tensor &x, const Tensor &y, const double rel_err = 1e-5, const double abs_err = 1e-8) { - static_assert(is_floating_point_type::value, - "T is not a floating point type"); Expector::Near(x, y, rel_err, abs_err); } @@ -502,9 +540,6 @@ template void ExpectTensorNear(const Tensor &x, const Tensor &y, const double rel_err = 1e-5, const double abs_err = 1e-8) { - static_assert(is_floating_point_type::value && - is_floating_point_type::value, - "T is not a floating point type"); Expector::Near(x, y, rel_err, abs_err); } diff --git a/mace/ops/quantize.cc b/mace/ops/quantize.cc new file mode 100644 index 0000000000000000000000000000000000000000..49695fde74fa375bafeb117bfebe0306d5c4e5b0 --- /dev/null +++ b/mace/ops/quantize.cc @@ -0,0 +1,60 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/ops/quantize.h" + +namespace mace { +namespace ops { + +void Register_Quantize(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("Quantize") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + QuantizeOp); + REGISTER_OPERATOR(op_registry, OpKeyBuilder("Quantize") + .Device(DeviceType::NEON) + .TypeConstraint("T") + .Build(), + QuantizeOp); +} + +void Register_Dequantize(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("Dequantize") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + DequantizeOp); + REGISTER_OPERATOR(op_registry, OpKeyBuilder("Dequantize") + .Device(DeviceType::NEON) + .TypeConstraint("T") + .Build(), + DequantizeOp); +} + +void Register_Requantize(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("Requantize") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + RequantizeOp); + REGISTER_OPERATOR(op_registry, OpKeyBuilder("Requantize") + .Device(DeviceType::NEON) + .TypeConstraint("T") + .Build(), + RequantizeOp); +} + +} // namespace ops +} // namespace mace diff --git a/mace/ops/quantize.h b/mace/ops/quantize.h new file mode 100644 index 0000000000000000000000000000000000000000..cee215f1ab5e764787cf612c6900979b3de4ad4e --- /dev/null +++ b/mace/ops/quantize.h @@ -0,0 +1,144 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MACE_OPS_QUANTIZE_H_
+#define MACE_OPS_QUANTIZE_H_
+
+#include "mace/core/operator.h"
+#include "mace/kernels/quantize.h"
+
+namespace mace {
+namespace ops {
+
+template <DeviceType D, class T>
+class QuantizeOp : public Operator<D, T> {
+ public:
+  QuantizeOp(const OperatorDef &operator_def, Workspace *ws)
+      : Operator<D, T>(operator_def, ws) {
+  }
+
+  bool Run(StatsFuture *future) override {
+    const Tensor *input = this->Input(INPUT);
+    const Tensor *in_min = this->Input(IN_MIN);
+    const Tensor *in_max = this->Input(IN_MAX);
+
+    MACE_CHECK(in_min->size() == 1, "min val tensor has more than 1 value");
+    MACE_CHECK(in_max->size() == 1, "max val tensor has more than 1 value");
+
+    Tensor *output = this->Output(OUTPUT);
+    Tensor *out_min = this->Output(OUT_MIN);
+    Tensor *out_max = this->Output(OUT_MAX);
+    output->ResizeLike(input);
+    out_min->ResizeLike(in_min);
+    out_max->ResizeLike(in_max);
+
+    functor_(input, in_min, in_max, output, out_min, out_max, future);
+    return true;
+  }
+
+ private:
+  kernels::QuantizeFunctor<D, T> functor_;
+
+ protected:
+  OP_INPUT_TAGS(INPUT, IN_MIN, IN_MAX);
+  OP_OUTPUT_TAGS(OUTPUT, OUT_MIN, OUT_MAX);
+};
+
+template <DeviceType D, class T>
+class DequantizeOp : public Operator<D, T> {
+ public:
+  DequantizeOp(const OperatorDef &operator_def, Workspace *ws)
+      : Operator<D, T>(operator_def, ws) {
+  }
+
+  bool Run(StatsFuture *future) override {
+    const Tensor *input = this->Input(INPUT);
+    const Tensor *in_min = this->Input(IN_MIN);
+    const Tensor *in_max = this->Input(IN_MAX);
+
+    MACE_CHECK(in_min->size() == 1, "min val tensor has more than 1 value");
+    MACE_CHECK(in_max->size() == 1, "max val tensor has more than 1 value");
+
+    Tensor *output = this->Output(OUTPUT);
+    output->ResizeLike(input);
+
+    functor_(input, in_min, in_max, output, future);
+    return true;
+  }
+
+ private:
+  kernels::DequantizeFunctor<D, T> functor_;
+
+ protected:
+  OP_INPUT_TAGS(INPUT, IN_MIN, IN_MAX);
+  OP_OUTPUT_TAGS(OUTPUT);
+};
+
+template <DeviceType D, class T>
+class RequantizeOp : public Operator<D, T> {
+ public:
+  RequantizeOp(const OperatorDef &operator_def, Workspace *ws)
+      : Operator<D, T>(operator_def, ws) {
+  }
+
+  bool Run(StatsFuture *future) override {
+    const Tensor *input = this->Input(INPUT);
+    const Tensor *in_min = this->Input(IN_MIN);
+    const Tensor *in_max = this->Input(IN_MAX);
+    const Tensor *rerange_min = nullptr;
+    const Tensor *rerange_max = nullptr;
+
+    MACE_CHECK(in_min->size() == 1, "min val tensor has more than 1 value");
+    MACE_CHECK(in_max->size() == 1, "max val tensor has more than 1 value");
+
+    if (this->InputSize() >= 5) {
+      rerange_min = this->Input(RERANGE_MIN);
+      rerange_max = this->Input(RERANGE_MAX);
+      MACE_CHECK(rerange_min->size() == 1,
+                 "rerange min val tensor has more than 1 value");
+      MACE_CHECK(rerange_max->size() == 1,
+                 "rerange max val tensor has more than 1 value");
+    }
+
+    Tensor *output = this->Output(OUTPUT);
+    Tensor *out_min = this->Output(OUT_MIN);
+    Tensor *out_max = this->Output(OUT_MAX);
+    output->ResizeLike(input);
+    out_min->ResizeLike(in_min);
+    out_max->ResizeLike(in_max);
+
+    functor_(input,
+             in_min,
+             in_max,
+             rerange_min,
+             rerange_max,
+             output,
+             out_min,
+             out_max,
+             future);
+    return true;
+  }
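+
+  // Note: the rerange min/max (inputs 3 and 4) are optional; when they are
+  // absent, the functor derives the output range from the int32 data itself.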
+ private: + kernels::RequantizeFunctor functor_; + + protected: + OP_INPUT_TAGS(INPUT, IN_MIN, IN_MAX, RERANGE_MIN, RERANGE_MAX); + OP_OUTPUT_TAGS(OUTPUT, OUT_MIN, OUT_MAX); +}; + +} // namespace ops +} // namespace mace + +#endif // MACE_OPS_QUANTIZE_H_ diff --git a/mace/ops/quantize_test.cc b/mace/ops/quantize_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1672ac53d9d1df71f7a53b939aec76c409404e45 --- /dev/null +++ b/mace/ops/quantize_test.cc @@ -0,0 +1,224 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "mace/core/operator.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +class QuantizeTest : public OpsTestBase {}; + +TEST_F(QuantizeTest, TestQuantize) { + // Construct graph + OpsTestNet net; + + // Add input data + net.AddInputFromArray("Input", {1, 2, 3, 1}, { + -2, -1, 1, 2, 3, 4 + }); + net.AddInputFromArray("InputMin", {1}, {-3}); + net.AddInputFromArray("InputMax", {1}, {5}); + + OpDefBuilder("Quantize", "QuantizeTest") + .Input("Input") + .Input("InputMin") + .Input("InputMax") + .Output("Output") + .Output("OutputMin") + .Output("OutputMax") + .OutputType({DT_UINT8, DT_FLOAT, DT_FLOAT}) + .AddIntArg("T", DT_UINT8) + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(); + + auto output = net.GetTensor("Output"); + auto output_min = net.GetTensor("OutputMin"); + auto output_max = net.GetTensor("OutputMax"); + + auto expected_output = CreateTensor({1, 2, 3, 1}, + { + 32, 64, 127, 159, 191, 223 + }); + auto expected_min = CreateTensor({1}, {-3.01887}); + auto expected_max = CreateTensor({1}, {5}); + + ExpectTensorNear(*expected_output, *output); + ExpectTensorNear(*expected_min, *output_min); + ExpectTensorNear(*expected_max, *output_max); +} + +TEST_F(QuantizeTest, TestQuantizeTrend) { + // Construct graph + OpsTestNet net; + + // Add input data + net.AddRandomInput("Input", {100}); + const float *input_data = net.GetTensor("Input")->data(); + net.AddInputFromArray("InputMin", + {1}, + {*std::min_element(input_data, + input_data + + net.GetTensor("Input")->size())}); + net.AddInputFromArray("InputMax", + {1}, + {*std::max_element(input_data, + input_data + + net.GetTensor("Input")->size())}); + + OpDefBuilder("Quantize", "QuantizeTest") + .Input("Input") + .Input("InputMin") + .Input("InputMax") + .Output("Output") + .Output("OutputMin") + .Output("OutputMax") + .OutputType({DT_UINT8, DT_FLOAT, DT_FLOAT}) + .AddIntArg("T", DT_UINT8) + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(); + + auto output = net.GetTensor("Output"); + auto output_min = net.GetTensor("OutputMin"); + auto output_max = net.GetTensor("OutputMax"); + + const uint8_t *output_data = net.GetTensor("Output")->data(); + for (int i = 1; i < output->size(); ++i) { + if (input_data[i] > input_data[i - 1]) { + EXPECT_GE(output_data[i], output_data[i - 1]); + } else if (input_data[i] == input_data[i - 1]) { + EXPECT_EQ(output_data[i], output_data[i - 1]); 
+ } else { + EXPECT_LE(output_data[i], output_data[i - 1]); + } + } +} + +TEST_F(QuantizeTest, TestDequantize) { + // Construct graph + OpsTestNet net; + + // Add input data + net.AddInputFromArray("Input", {1, 2, 3, 1}, { + 32, 64, 127, 159, 191, 223 + }); + net.AddInputFromArray("InputMin", {1}, {-3.01887}); + net.AddInputFromArray("InputMax", {1}, {5}); + + OpDefBuilder("Dequantize", "DequantizeTest") + .Input("Input") + .Input("InputMin") + .Input("InputMax") + .Output("Output") + .OutputType({DT_FLOAT}) + .AddIntArg("T", DT_UINT8) + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(); + + auto output = net.GetTensor("Output"); + auto expected_output = CreateTensor({1, 2, 3, 1}, + { + -2, -1, 1, 2, 3, 4 + }); + auto expected_min = CreateTensor({1}, {-3.01887}); + auto expected_max = CreateTensor({1}, {5}); + + ExpectTensorNear(*expected_output, *output, 0.1, 0.01); +} + +TEST_F(QuantizeTest, TestRequantizeWithMinMax) { + // Construct graph + OpsTestNet net; + + // Add input data + net.AddInputFromArray("Input", {1, 2, 3, 1}, { + -1073741824, -536870912, 536870912, 1073741824, 1610612736, 2147483647 + }); + net.AddInputFromArray("InputMin", {1}, {-3}); + net.AddInputFromArray("InputMax", {1}, {5}); + net.AddInputFromArray("RerangeMin", {1}, {-3.01887}); + net.AddInputFromArray("RerangeMax", {1}, {5}); + + OpDefBuilder("Requantize", "RequantizeTest") + .Input("Input") + .Input("InputMin") + .Input("InputMax") + .Input("RerangeMin") + .Input("RerangeMax") + .Output("Output") + .Output("OutputMin") + .Output("OutputMax") + .OutputType({DT_UINT8, DT_FLOAT, DT_FLOAT}) + .AddIntArg("T", DT_UINT8) + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(); + + auto output = net.GetTensor("Output"); + auto expected_output = CreateTensor({1, 2, 3, 1}, + { + 32, 64, 128, 160, 191, 223 + }); + auto expected_min = CreateTensor({1}, {-3.01887}); + auto expected_max = CreateTensor({1}, {5}); + + ExpectTensorNear(*expected_output, *output); +} + +TEST_F(QuantizeTest, TestRequantizeWithoutMinMax) { + // Construct graph + OpsTestNet net; + + // Add input data + net.AddInputFromArray("Input", {1, 2, 3, 1}, { + -1073741824, -536870912, 536870912, 1073741824, 1610612736, 2147483647 + }); + net.AddInputFromArray("InputMin", {1}, {-3}); + net.AddInputFromArray("InputMax", {1}, {5}); + + OpDefBuilder("Requantize", "RequantizeTest") + .Input("Input") + .Input("InputMin") + .Input("InputMax") + .Output("Output") + .Output("OutputMin") + .Output("OutputMax") + .OutputType({DT_UINT8, DT_FLOAT, DT_FLOAT}) + .AddIntArg("T", DT_UINT8) + .Finalize(net.NewOperatorDef()); + + // Run + net.RunOp(); + + auto output = net.GetTensor("Output"); + auto expected_output = CreateTensor({1, 2, 3, 1}, + { + 0, 43, 128, 170, 213, 255 + }); + auto expected_min = CreateTensor({1}, {-3.01887}); + auto expected_max = CreateTensor({1}, {5}); + ExpectTensorNear(*expected_output, *output); +} + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/transpose.h b/mace/ops/transpose.h index 45e36fa353a5acdcc66418bf06ffe385f1a9e741..2ec9281c41021634f2e77c69eee4dcf97626fc2d 100644 --- a/mace/ops/transpose.h +++ b/mace/ops/transpose.h @@ -28,16 +28,16 @@ class TransposeOp : public Operator { public: TransposeOp(const OperatorDef &operator_def, Workspace *ws) : Operator(operator_def, ws), - dims_(OperatorBase::GetRepeatedArgument( - "dims")), + dims_(OperatorBase::GetRepeatedArgument("dims")), functor_(dims_) {} bool Run(StatsFuture *future) override { const Tensor *input = this->Input(INPUT); Tensor *output 
= this->Output(OUTPUT); const std::vector &input_shape = input->shape(); - MACE_CHECK(input_shape.size() == 4 && dims_.size() == 4, - "rank should be 4"); + MACE_CHECK(input_shape.size() == 4 && dims_.size() == 4 + || input_shape.size() == 2 && dims_.size() == 2, + "rank should be 2 or 4"); std::vector output_shape; for (int i = 0; i < dims_.size(); ++i) { output_shape.push_back(input_shape[dims_[i]]); diff --git a/mace/ops/transpose_benchmark.cc b/mace/ops/transpose_benchmark.cc new file mode 100644 index 0000000000000000000000000000000000000000..a86549ed9cc4206b00d9276df524e95d491acad7 --- /dev/null +++ b/mace/ops/transpose_benchmark.cc @@ -0,0 +1,93 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "mace/core/operator.h" +#include "mace/core/testing/test_benchmark.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace ops { +namespace test { + +namespace { +template +void TransposeBenchmark(int iters, + std::vector shape, + std::vector dims) { + mace::testing::StopTiming(); + + OpsTestNet net; + + // Add input data + net.AddRandomInput("Input", shape); + + OpDefBuilder("Transpose", "TransposeBM") + .Input("Input") + .Output("Output") + .AddIntsArg("dims", dims) + .Finalize(net.NewOperatorDef()); + + // Warm-up + for (int i = 0; i < 5; ++i) { + net.RunOp(D); + } + net.Sync(); + + mace::testing::StartTiming(); + while (iters--) { + net.RunOp(D); + } + net.Sync(); +} +} // namespace + +#define BM_TRANSPOSE2D_MACRO(H, W, TYPE, DEVICE) \ + static void BM_TRANSPOSE2D_##H##_##W##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * H * W; \ + mace::testing::MaccProcessed(tot); \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + TransposeBenchmark(iters, {H, W}, {1, 0}); \ + } \ + BENCHMARK(BM_TRANSPOSE2D_##H##_##W##_##TYPE##_##DEVICE) + +#define BM_TRANSPOSE2D(H, W) \ + BM_TRANSPOSE2D_MACRO(H, W, float, CPU); + +#define BM_TRANSPOSE4D_MACRO(N, C, H, W, D0, D1, D2, D3, TYPE, DEVICE) \ + static void \ + BM_TRANSPOSE4D_##N##_##C##_##H##_##W##_##D0##D1##D2##D3##_##TYPE##_##DEVICE( \ + int iters) { \ + const int64_t tot = static_cast(iters) * N * C * H * W; \ + mace::testing::MaccProcessed(tot); \ + mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ + TransposeBenchmark(iters, {N, C, H, W}, {D0, D1, D2, D3}); \ + } \ + BENCHMARK( \ + BM_TRANSPOSE4D_##N##_##C##_##H##_##W##_##D0##D1##D2##D3##_##TYPE##_##DEVICE) + +#define BM_TRANSPOSE4D(N, C, H, W, D0, D1, D2, D3) \ + BM_TRANSPOSE4D_MACRO(N, C, H, W, D0, D1, D2, D3, float, CPU); + +BM_TRANSPOSE4D(1, 64, 64, 512, 0, 3, 1, 2); +BM_TRANSPOSE4D(1, 512, 64, 64, 0, 2, 3, 1); +BM_TRANSPOSE2D(128, 128); +BM_TRANSPOSE2D(512, 512); + +} // namespace test +} // namespace ops +} // namespace mace diff --git a/mace/ops/transpose_test.cc b/mace/ops/transpose_test.cc index b1e8cd4e040f63aedd63087bad1c184a7a8fb160..0faacc9111c4e904a6bd2a95b44b835376ae9987 100644 --- a/mace/ops/transpose_test.cc +++ 
b/mace/ops/transpose_test.cc @@ -49,6 +49,29 @@ TEST_F(TransposeOpTest, NCHW) { TransposeNCHWTest({1, 64, 48, 128}); } +TEST_F(TransposeOpTest, Rank2) { + // Construct graph + OpsTestNet net; + // Add input data + net.AddInputFromArray("Input", {2, 3}, {1, 2, 3, 4, 5, 6}); + + OpDefBuilder("Transpose", "TransposeNCHWTest") + .Input("Input") + .Output("Output") + .AddIntsArg("dims", {1, 0}) + .Finalize(net.NewOperatorDef()); + + // Run on cpu + net.RunOp(); + + net.AddInputFromArray("ExpectedOutput", + {3, 2}, + {1, 4, 2, 5, 3, 6}); + + ExpectTensorNear(*net.GetOutput("ExpectedOutput"), + *net.GetOutput("Output")); +} + } // namespace test } // namespace ops } // namespace mace diff --git a/mace/python/tools/caffe_converter_lib.py b/mace/python/tools/caffe_converter_lib.py index 8246c2490135413340c133f37f60de36433865ce..070c75a123359bcbbcdc76d04b584d5f7d67d6dc 100644 --- a/mace/python/tools/caffe_converter_lib.py +++ b/mace/python/tools/caffe_converter_lib.py @@ -320,6 +320,13 @@ class CaffeConverter(object): arg.name = 'T' arg.i = self.dt + input_op = self.ops_map[name] + if input_op.layer is not None: + output_shape = input_op.output_shape_map[input_op.layer.top[0]] + else: + output_shape = input_op.output_shape_map[input_op.name] + self.add_output_shape(op_def, output_shape) + def add_output_transform(self, names): for name in names: output_name = MACE_OUTPUT_NODE_NAME + '_' + name + ":0" @@ -1091,15 +1098,15 @@ class CaffeConverter(object): dims_arg.ints.extend([0, 2, 3, 1]) # NCHW -> NHWC def convert(self, input_nodes, input_shapes, output_nodes): + assert self.ops[0].type == 'Input' + self.add_input_op_shape(input_nodes, input_shapes) + if self.device == 'gpu': self.add_input_transform(input_nodes) if self.device == 'neon': self.add_neon_input_transform(input_nodes) - assert self.ops[0].type == 'Input' - self.add_input_op_shape(input_nodes, input_shapes) - for op in self.ops: if op.name in self.resolved_ops: continue diff --git a/mace/python/tools/memory_optimizer.py b/mace/python/tools/memory_optimizer.py index 4eef9395f8648245b37f270bcf2c51faac84f888..fddb50e276d9f23f00ced9b666681467585283ee 100644 --- a/mace/python/tools/memory_optimizer.py +++ b/mace/python/tools/memory_optimizer.py @@ -46,7 +46,11 @@ class MemoryOptimizer(object): self.ref_counter[tensor_name] = 0 def is_buffer_image_op(self, op): - return op.type == 'BufferToImage' or op.type == 'ImageToBuffer' + if op.type == 'BufferToImage': + for arg in op.arg: + if arg.name == 'mode' and arg.i == 0: + return True + return op.type == 'ImageToBuffer' def get_mem_size(self, op_type, output_shape): mem_size = [0, 0] diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py index c1f87781c515537339adc98403b0105e948c537d..4079c953de0d26544cd3b63f060792d968b49b3d 100644 --- a/mace/python/tools/tf_converter_lib.py +++ b/mace/python/tools/tf_converter_lib.py @@ -155,6 +155,8 @@ class TFConverter(object): arg.name = 'T' arg.i = self.dt + self.add_output_shape(self.ops[name].outputs, op_def) + def add_neon_input_transform(self, names): for name in names: new_input_name = MACE_INPUT_NODE_NAME + '_' + name + ":0" diff --git a/mace/test/BUILD b/mace/test/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..f3345cfac12dedb66cbe9f7c4a5d02a120e2113e --- /dev/null +++ b/mace/test/BUILD @@ -0,0 +1,30 @@ +# Description: +# Mace operators. 
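+# (this target provides end-to-end tests of the public MaceEngine API)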
+# +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled", "if_android_armv7", "if_hexagon_enabled") + +cc_test( + name = "mace_api_test", + testonly = 1, + srcs = ["mace_api_test.cc"], + copts = if_openmp_enabled(["-fopenmp"]) + + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + + if_android_armv7(["-mfpu=neon"]) + + if_android_armv7(["-mfloat-abi=softfp"]) + + if_android(["-DMACE_ENABLE_OPENCL"]) + + if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]), + linkopts = ["-fopenmp"], + linkstatic = 1, + deps = [ + "//mace/ops:test", + "//mace/kernels:kernels", + "//mace/ops:ops", + "@gtest//:gtest_main", + ], +) diff --git a/mace/test/mace_api_test.cc b/mace/test/mace_api_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fc22a450edfc5a5971d6717e4db01c2bf2dc96ad --- /dev/null +++ b/mace/test/mace_api_test.cc @@ -0,0 +1,336 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include + +#include "mace/core/operator.h" +#include "mace/kernels/conv_pool_2d_util.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace test { + +class MaceAPITest : public ::testing::Test {}; + +namespace { + +void GenerateInputs(const std::vector &input_names, + const std::vector &input_shape, + std::map *inputs) { + size_t input_size = input_names.size(); + for (size_t i = 0; i < input_size; ++i) { + // Allocate input and output + int64_t input_size = + std::accumulate(input_shape.begin(), input_shape.end(), 1, + std::multiplies()); + auto buffer_in = std::shared_ptr(new float[input_size], + std::default_delete()); + // load input + std::vector input_data; + ops::test::GenerateRandomRealTypeData(input_shape, &input_data); + memcpy(buffer_in.get(), input_data.data(), input_size * sizeof(float)); + (*inputs)[input_names[i]] = mace::MaceTensor(input_shape, buffer_in); + } +} + +void GenerateOutputs(const std::vector &output_names, + const std::vector &output_shape, + std::map *outputs) { + size_t output_size = output_names.size(); + for (size_t i = 0; i < output_size; ++i) { + int64_t output_size = + std::accumulate(output_shape.begin(), output_shape.end(), 1, + std::multiplies()); + auto buffer_out = std::shared_ptr(new float[output_size], + std::default_delete()); + (*outputs)[output_names[i]] = mace::MaceTensor(output_shape, buffer_out); + } +} + +template +void BufferToImage(const std::string &input_name, + const std::string &output_name, + const int buffer_type, + const std::vector &mem_ids, + NetDef *net_def, + const int mode = NetMode::NORMAL) { + OperatorDef operator_def; + + ops::test::OpDefBuilder("BufferToImage", "BufferToImageOp") + .Input(input_name) + .Output(output_name) + .AddIntArg("buffer_type", buffer_type) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .AddIntArg("mode", mode) + .Finalize(&operator_def); + + operator_def.set_mem_id(mem_ids); + + net_def->add_op()->CopyFrom(operator_def); +} + +template +void ImageToBuffer(const std::string &input_name, + const std::string &output_name, + const int buffer_type, + NetDef *net_def) { + OperatorDef operator_def; + + ops::test::OpDefBuilder("ImageToBuffer", "ImageToBufferOp") + .Input(input_name) + .Output(output_name) + .AddIntArg("buffer_type", buffer_type) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(&operator_def); + + net_def->add_op()->CopyFrom(operator_def); +} + +template +void Conv3x3(const std::string &input_name, + const std::string &filter_name, + 
diff --git a/mace/test/BUILD b/mace/test/BUILD
new file mode 100644
index 0000000000000000000000000000000000000000..f3345cfac12dedb66cbe9f7c4a5d02a120e2113e
--- /dev/null
+++ b/mace/test/BUILD
@@ -0,0 +1,30 @@
+# Description:
+# Mace tests.
+#
+package(
+    default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"])  # Apache 2.0
+
+load("//mace:mace.bzl", "if_android", "if_neon_enabled", "if_openmp_enabled", "if_android_armv7", "if_hexagon_enabled")
+
+cc_test(
+    name = "mace_api_test",
+    testonly = 1,
+    srcs = ["mace_api_test.cc"],
+    copts = if_openmp_enabled(["-fopenmp"]) +
+        if_neon_enabled(["-DMACE_ENABLE_NEON"]) +
+        if_android_armv7(["-mfpu=neon"]) +
+        if_android_armv7(["-mfloat-abi=softfp"]) +
+        if_android(["-DMACE_ENABLE_OPENCL"]) +
+        if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]),
+    linkopts = ["-fopenmp"],
+    linkstatic = 1,
+    deps = [
+        "//mace/ops:test",
+        "//mace/kernels:kernels",
+        "//mace/ops:ops",
+        "@gtest//:gtest_main",
+    ],
+)
diff --git a/mace/test/mace_api_test.cc b/mace/test/mace_api_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fc22a450edfc5a5971d6717e4db01c2bf2dc96ad
--- /dev/null
+++ b/mace/test/mace_api_test.cc
@@ -0,0 +1,336 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include <functional>
+#include <numeric>
+
+#include "mace/core/operator.h"
+#include "mace/kernels/conv_pool_2d_util.h"
+#include "mace/ops/ops_test_util.h"
+
+namespace mace {
+namespace test {
+
+class MaceAPITest : public ::testing::Test {};
+
+namespace {
+
+void GenerateInputs(const std::vector<std::string> &input_names,
+                    const std::vector<int64_t> &input_shape,
+                    std::map<std::string, mace::MaceTensor> *inputs) {
+  size_t input_size = input_names.size();
+  for (size_t i = 0; i < input_size; ++i) {
+    // Allocate input and output
+    int64_t input_size =
+        std::accumulate(input_shape.begin(), input_shape.end(), 1,
+                        std::multiplies<int64_t>());
+    auto buffer_in = std::shared_ptr<float>(new float[input_size],
+                                            std::default_delete<float[]>());
+    // load input
+    std::vector<float> input_data;
+    ops::test::GenerateRandomRealTypeData(input_shape, &input_data);
+    memcpy(buffer_in.get(), input_data.data(), input_size * sizeof(float));
+    (*inputs)[input_names[i]] = mace::MaceTensor(input_shape, buffer_in);
+  }
+}
+
+void GenerateOutputs(const std::vector<std::string> &output_names,
+                     const std::vector<int64_t> &output_shape,
+                     std::map<std::string, mace::MaceTensor> *outputs) {
+  size_t output_size = output_names.size();
+  for (size_t i = 0; i < output_size; ++i) {
+    int64_t output_size =
+        std::accumulate(output_shape.begin(), output_shape.end(), 1,
+                        std::multiplies<int64_t>());
+    auto buffer_out = std::shared_ptr<float>(new float[output_size],
+                                             std::default_delete<float[]>());
+    (*outputs)[output_names[i]] = mace::MaceTensor(output_shape, buffer_out);
+  }
+}
+
+template <typename T>
+void BufferToImage(const std::string &input_name,
+                   const std::string &output_name,
+                   const int buffer_type,
+                   const std::vector<int> &mem_ids,
+                   NetDef *net_def,
+                   const int mode = NetMode::NORMAL) {
+  OperatorDef operator_def;
+
+  ops::test::OpDefBuilder("BufferToImage", "BufferToImageOp")
+      .Input(input_name)
+      .Output(output_name)
+      .AddIntArg("buffer_type", buffer_type)
+      .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+      .AddIntArg("mode", mode)
+      .Finalize(&operator_def);
+
+  operator_def.set_mem_id(mem_ids);
+
+  net_def->add_op()->CopyFrom(operator_def);
+}
+
+template <typename T>
+void ImageToBuffer(const std::string &input_name,
+                   const std::string &output_name,
+                   const int buffer_type,
+                   NetDef *net_def) {
+  OperatorDef operator_def;
+
+  ops::test::OpDefBuilder("ImageToBuffer", "ImageToBufferOp")
+      .Input(input_name)
+      .Output(output_name)
+      .AddIntArg("buffer_type", buffer_type)
+      .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+      .Finalize(&operator_def);
+
+  net_def->add_op()->CopyFrom(operator_def);
+}
+
+template <typename T>
+void Conv3x3(const std::string &input_name,
+             const std::string &filter_name,
+             const std::string &output_name,
+             const std::vector<int> &mem_ids,
+             NetDef *net_def) {
+  OperatorDef operator_def;
+  ops::test::OpDefBuilder("Conv2D", "Conv2dOp")
+      .Input(input_name)
+      .Input(filter_name)
+      .Output(output_name)
+      .AddIntsArg("strides", {1, 1})
+      .AddIntArg("padding", Padding::SAME)
+      .AddIntsArg("dilations", {1, 1})
+      .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+      .Finalize(&operator_def);
+
+  operator_def.set_mem_id(mem_ids);
+  net_def->add_op()->CopyFrom(operator_def);
+}
+
+template <typename T>
+void Relu(const std::string &input_name,
+          const std::string &output_name,
+          NetDef *net_def) {
+  OperatorDef operator_def;
+  ops::test::OpDefBuilder("Activation", "ReluTest")
+      .Input(input_name)
+      .Output(output_name)
+      .AddStringArg("activation", "RELU")
+      .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+      .Finalize(&operator_def);
+
+  net_def->add_op()->CopyFrom(operator_def);
+}
+
+template <typename T>
+void AddTensor(const std::string &name,
+               const std::vector<int64_t> &shape,
+               T *data,
+               NetDef *net_def) {
+  ConstTensor tensor(name,
+                     reinterpret_cast<const unsigned char *>(data),
+                     shape,
+                     DataTypeToEnum<T>::value);
+
+  net_def->mutable_tensors().push_back(tensor);
+}
+
+template <DeviceType D, typename T>
+void CheckOutputs(const NetDef &net_def,
+                  const std::map<std::string, mace::MaceTensor> &inputs,
+                  const std::map<std::string, mace::MaceTensor> &outputs) {
+  ops::test::OpsTestNet net;
+  for (auto input : inputs) {
+    auto input_shape = input.second.shape();
+    const int64_t data_size = std::accumulate(input_shape.begin(),
+                                              input_shape.end(), 1,
+                                              std::multiplies<int64_t>());
+    std::vector<float> input_data(data_size);
+    memcpy(input_data.data(), input.second.data().get(),
+           data_size * sizeof(float));
+    std::string input_name = MakeString("mace_input_node_",
+                                        input.first, ":0");
+    net.AddInputFromArray<D, float>(input_name, input.second.shape(),
+                                    input_data);
+  }
+  auto tensors = net_def.tensors();
+  for (auto tensor : tensors) {
+    auto shape = tensor.dims();
+    const int64_t data_size = std::accumulate(shape.begin(),
+                                              shape.end(), 1,
+                                              std::multiplies<int64_t>());
+    std::vector<T> data(data_size);
+    memcpy(data.data(), reinterpret_cast<const T *>(tensor.data()),
+           data_size * sizeof(T));
+    net.AddInputFromArray<D, T>(tensor.name(), shape, data);
+  }
+  net.RunNet(net_def, D);
+
+  for (auto output : outputs) {
+    std::unique_ptr<Tensor> tmp_tensor(
+        new Tensor(GetDeviceAllocator(DeviceType::CPU),
+                   DataTypeToEnum<float>::v()));
+    auto output_shape = output.second.shape();
+    const int64_t data_size = std::accumulate(output_shape.begin(),
+                                              output_shape.end(), 1,
+                                              std::multiplies<int64_t>());
+    tmp_tensor->Resize(output.second.shape());
+    float *data = tmp_tensor->mutable_data<float>();
+    memcpy(data, output.second.data().get(), data_size * sizeof(float));
+    std::string output_name = MakeString("mace_output_node_",
+                                         output.first, ":0");
+    ops::test::ExpectTensorNear<float>(*tmp_tensor,
+                                       *net.GetOutput(output_name.data()),
+                                       1e-5);
+  }
+}
+
+std::map<std::string, int> AddMemoryOptimization(
+    const std::vector<std::string> &input_names,
+    const std::vector<std::string> &output_names,
+    const std::vector<std::vector<int64_t>> &input_shapes,
+    const std::vector<std::vector<int64_t>> &output_shapes,
+    NetDef *net_def) {
+  std::map<std::string, int> res;
+  int mem_id = 0;
+  size_t input_shape_size = input_shapes.size();
+  uint32_t in_mem_block_x = 0;
+  uint32_t in_mem_block_y = 0;
+  for (size_t i = 0; i < input_shape_size; ++i) {
+    in_mem_block_x = std::max<uint32_t>(in_mem_block_x,
+                                        input_shapes[i][2] *
+                                            RoundUpDiv4(input_shapes[i][3]));
+    in_mem_block_y = std::max<uint32_t>(in_mem_block_y,
+                                        input_shapes[i][0] *
+                                            input_shapes[i][1]);
+  }
+  size_t input_size = input_names.size();
+  for (size_t i = 0; i < input_size; ++i) {
+    net_def->mutable_mem_arena().mutable_mem_block().push_back(
+        MemoryBlock(mem_id, in_mem_block_x, in_mem_block_y));
+    res[input_names[i]] = mem_id;
+    mem_id++;
+  }
+  size_t output_shape_size = output_shapes.size();
+  uint32_t out_mem_block_x = 0;
+  uint32_t out_mem_block_y = 0;
+  for (size_t i = 0; i < output_shape_size; ++i) {
+    out_mem_block_x = std::max<uint32_t>(out_mem_block_x,
+                                         output_shapes[i][2] *
+                                             RoundUpDiv4(output_shapes[i][3]));
+    out_mem_block_y = std::max<uint32_t>(out_mem_block_y,
+                                         output_shapes[i][0] *
+                                             output_shapes[i][1]);
+  }
+  size_t output_size = output_names.size();
+  for (size_t i = 0; i < output_size; ++i) {
+    net_def->mutable_mem_arena().mutable_mem_block().push_back(
+        MemoryBlock(mem_id, out_mem_block_x, out_mem_block_y));
+    res[output_names[i]] = mem_id;
+    mem_id++;
+  }
+  return res;
+}
+
+// The height and width of input and output must be equal.
+template <typename T>
+void MaceRun(const int in_out_size,
+             const std::vector<std::vector<int64_t>> &input_shapes,
+             const std::vector<std::vector<int64_t>> &output_shapes,
+             const std::vector<int64_t> &filter_shape) {
+  std::vector<std::string> input_names;
+  std::vector<std::string> output_names;
+  for (int i = 0; i < in_out_size; ++i) {
+    input_names.push_back(MakeString("input", i));
+    output_names.push_back(MakeString("output", i));
+  }
+  std::string filter_tensor_name = "filter";
+  std::string filter_tensor_img_name = filter_tensor_name + "_image";
+
+  const DeviceType device = DeviceType::OPENCL;
+
+  NetDef net_def;
+
+  // Add memory optimization
+  auto mem_map = AddMemoryOptimization(input_names, output_names,
+                                       input_shapes, output_shapes,
+                                       &net_def);
+
+  std::vector<T> data;
+  ops::test::GenerateRandomRealTypeData(filter_shape, &data);
+  AddTensor<T>(filter_tensor_name, filter_shape, data.data(), &net_def);
+
+  for (size_t i = 0; i < input_names.size(); ++i) {
+    std::string input_name = MakeString("mace_input_node_",
+                                        input_names[i], ":0");
+    BufferToImage<T>(input_name, input_names[i],
+                     mace::kernels::IN_OUT_CHANNEL,
+                     {mem_map[input_names[i]]},
+                     &net_def);
+  }
+  BufferToImage<T>(filter_tensor_name, filter_tensor_img_name,
+                   mace::kernels::CONV2D_FILTER, {},
+                   &net_def, NetMode::INIT);
+  for (size_t i = 0; i < output_names.size(); ++i) {
+    Conv3x3<T>(input_names[i], filter_tensor_img_name,
+               output_names[i], {mem_map[output_names[i]]},
+               &net_def);
+  }
+  for (size_t i = 0; i < output_names.size(); ++i) {
+    std::string output_name = MakeString("mace_output_node_",
+                                         output_names[i], ":0");
+    ImageToBuffer<float>(output_names[i], output_name,
+                         mace::kernels::IN_OUT_CHANNEL, &net_def);
+  }
+
+  MaceEngine engine(&net_def, device, input_names, output_names);
+
+  std::map<std::string, mace::MaceTensor> inputs;
+  std::map<std::string, mace::MaceTensor> outputs;
+
+  for (int i = 0; i < 5; ++i) {
+    size_t input_shape_size = input_shapes.size();
+    for (size_t j = 0; j < input_shape_size; ++j) {
+      inputs.clear();
+      outputs.clear();
+      GenerateInputs(input_names, input_shapes[j], &inputs);
+      GenerateOutputs(output_names, output_shapes[j], &outputs);
+      engine.Run(inputs, &outputs);
+    }
+  }
+
+  CheckOutputs<DeviceType::OPENCL, T>(net_def, inputs, outputs);
+}
+
+}  // namespace
+
+TEST_F(MaceAPITest, GPUSingleInputOutput) {
+  MaceRun<float>(1, {{1, 32, 32, 16}}, {{1, 32, 32, 16}}, {3, 3, 16, 16});
+  MaceRun<half>(1, {{1, 32, 32, 16}}, {{1, 32, 32, 16}}, {3, 3, 16, 16});
+}
+
+TEST_F(MaceAPITest, GPUMultipleInputOutput) {
+  MaceRun<float>(2,
+                 {{1, 16, 32, 16}},
+                 {{1, 16, 32, 16}},
+                 {3, 3, 16, 16});
+  MaceRun<half>(2,
+                {{1, 16, 32, 16}},
+                {{1, 16, 32, 16}},
+                {3, 3, 16, 16});
+}
+
+TEST_F(MaceAPITest, GPUVariableInputShape) {
+  MaceRun<float>(1,
+                 {{1, 16, 32, 16}, {1, 32, 64, 16}},
+                 {{1, 16, 32, 16}, {1, 32, 64, 16}},
+                 {3, 3, 16, 16});
+  MaceRun<half>(2,
+                {{1, 16, 32, 16}, {1, 32, 64, 16}},
+                {{1, 16, 32, 16}, {1, 32, 64, 16}},
+                {3, 3, 16, 16});
+}
+
+}  // namespace test
+}  // namespace mace
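For a sense of the buffer sizes the AddMemoryOptimization helper above reserves: for an NHWC shape, the OpenCL image block is (W * ceil(C / 4), N * H), maximized over all candidate shapes so a single arena entry can back every run. A quick check in Python, with RoundUpDiv4 reimplemented on the assumption that it is a ceiling divide by four:

    def round_up_div4(v):
        # Assumed behaviour of mace's RoundUpDiv4.
        return (v + 3) // 4

    def image_block(shape):
        n, h, w, c = shape  # NHWC, as the tests above use
        return (w * round_up_div4(c), n * h)

    print(image_block([1, 32, 32, 16]))  # (128, 32): GPUSingleInputOutput
    print(image_block([1, 16, 32, 16]))  # (128, 16)
    print(image_block([1, 32, 64, 16]))  # (256, 32)

The last two are the GPUVariableInputShape shapes; their per-dimension maximum (256, 32) is the one block that serves both runs.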
diff --git a/mace/utils/tuner.h b/mace/utils/tuner.h
index fa25f6daaa5fad5ea786e83fec88d53e5e92d82e..db4f25fa8288cc65c094017c08b63465345fd5be 100644
--- a/mace/utils/tuner.h
+++ b/mace/utils/tuner.h
@@ -94,8 +94,8 @@ class Tuner {
   Tuner &operator=(const Tuner &) = delete;
 
   inline void WriteRunParameters() {
-    VLOG(3) << "Write tuning result to " << path_;
     if (path_ != nullptr) {
+      VLOG(3) << "Write tuning result to " << path_;
       std::ofstream ofs(path_, std::ios::binary | std::ios::out);
       if (ofs.is_open()) {
         int64_t num_pramas = param_table_.size();
diff --git a/tools/validate.py b/tools/validate.py
index 18e54faf6b661746445ec7c47ffb063ece7314e1..2c76ecae5626ba8b8de6037de9becb3b5809094f 100644
--- a/tools/validate.py
+++ b/tools/validate.py
@@ -42,7 +42,7 @@ def load_data(file):
     return np.empty([0])
 
 
-def format_output_name(name):
+def format_name(name):
     return re.sub('[^0-9a-zA-Z]+', '_', name)
 
 
@@ -87,7 +87,7 @@ def validate_tf_model(platform, mace_runtime, model_file, input_file,
             input_dict = {}
            for i in range(len(input_names)):
                 input_value = load_data(
-                    input_file + "_" + input_names[i])
+                    input_file + "_" + format_name(input_names[i]))
                 input_value = input_value.reshape(input_shapes[i])
                 input_node = graph.get_tensor_by_name(
                     input_names[i] + ':0')
@@ -100,7 +100,7 @@ def validate_tf_model(platform, mace_runtime, model_file, input_file,
             output_values = session.run(output_nodes, feed_dict=input_dict)
             for i in range(len(output_names)):
                 output_file_name = mace_out_file + "_" + \
-                    format_output_name(output_names[i])
+                    format_name(output_names[i])
                 mace_out_value = load_data(output_file_name)
                 compare_output(platform, mace_runtime, output_names[i],
                                mace_out_value, output_values[i])
@@ -123,7 +123,7 @@ def validate_caffe_model(platform, mace_runtime, model_file, input_file,
     net = caffe.Net(model_file, caffe.TEST, weights=weight_file)
 
     for i in range(len(input_names)):
-        input_value = load_data(input_file + "_" + input_names[i])
+        input_value = load_data(input_file + "_" + format_name(input_names[i]))
         input_value = input_value.reshape(input_shapes[i]).transpose(
             (0, 3, 1, 2))
         input_blob_name = input_names[i]
@@ -142,7 +142,7 @@ def validate_caffe_model(platform, mace_runtime, model_file, input_file,
         out_shape[1], out_shape[2], out_shape[3] = out_shape[3], out_shape[
             1], out_shape[2]
         value = value.reshape(out_shape).transpose((0, 2, 3, 1))
-        output_file_name = mace_out_file + "_" + format_output_name(
+        output_file_name = mace_out_file + "_" + format_name(
             output_names[i])
         mace_out_value = load_data(output_file_name)
         compare_output(platform, mace_runtime, output_names[i], mace_out_value,
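For reference, the renamed format_name helper collapses every run of characters outside [0-9a-zA-Z] into one underscore; the substantive fix is that input files are now located with the same sanitized suffix that was already applied when output files were written. A quick illustration (the tensor names are made up):

    import re

    def format_name(name):
        return re.sub('[^0-9a-zA-Z]+', '_', name)

    print(format_name('input_node:0'))   # input_node_0
    print(format_name('block1/conv_1'))  # block1_conv_1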