diff --git a/mace/core/mace.cc b/mace/core/mace.cc index 519a6ef163bb8761ef9f91c3905e5d0bc34d7ebb..7177a42281b456ad074f413752c7ff86a7900f3f 100644 --- a/mace/core/mace.cc +++ b/mace/core/mace.cc @@ -106,6 +106,8 @@ class MaceEngine::Impl { DeviceType device_type_; std::unique_ptr ws_; std::unique_ptr net_; + std::map input_info_map_; + std::map output_info_map_; #ifdef MACE_ENABLE_HEXAGON std::unique_ptr hexagon_controller_; #endif @@ -131,12 +133,29 @@ MaceStatus MaceEngine::Impl::Init( const std::vector &output_nodes, const unsigned char *model_data) { LOG(INFO) << "Initializing MaceEngine"; + // Get input and output information. + for (auto &input_info : net_def->input_info()) { + input_info_map_[input_info.name()] = input_info; + } + for (auto &output_info : net_def->output_info()) { + output_info_map_[output_info.name()] = output_info; + } // Set storage path for internal usage for (auto input_name : input_nodes) { + if (input_info_map_.find(input_name) == input_info_map_.end()) { + LOG(FATAL) << "'" << input_name + << "' does not belong to the model's inputs: " + << MakeString(MapKeys(input_info_map_)); + } ws_->CreateTensor(MakeString("mace_input_node_", input_name), GetDeviceAllocator(device_type_), DT_FLOAT); } for (auto output_name : output_nodes) { + if (output_info_map_.find(output_name) == output_info_map_.end()) { + LOG(FATAL) << "'" << output_name + << "' does not belong to the model's outputs: " + << MakeString(MapKeys(output_info_map_)); + } ws_->CreateTensor(MakeString("mace_output_node_", output_name), GetDeviceAllocator(device_type_), DT_FLOAT); } @@ -193,6 +212,11 @@ MaceStatus MaceEngine::Impl::Run( std::vector input_tensors; std::vector output_tensors; for (auto &input : inputs) { + if (input_info_map_.find(input.first) == input_info_map_.end()) { + LOG(FATAL) << "'" << input.first + << "' does not belong to the model's inputs: " + << MakeString(MapKeys(input_info_map_)); + } MACE_CHECK(input.second.shape().size() == 4, "The Inputs' shape must be 4-dimension 
with NHWC format," " please use 1 to fill missing dimensions"); @@ -208,6 +232,11 @@ MaceStatus MaceEngine::Impl::Run( input_tensors.push_back(input_tensor); } for (auto &output : *outputs) { + if (output_info_map_.find(output.first) == output_info_map_.end()) { + LOG(FATAL) << "'" << output.first + << "' does not belong to the model's outputs: " + << MakeString(MapKeys(output_info_map_)); + } if (device_type_ == DeviceType::GPU) { MACE_CHECK(output.second.shape().size() == 4, "The outputs' shape must be 4-dimension with NHWC format," @@ -245,7 +274,7 @@ MaceStatus MaceEngine::Impl::Run( std::multiplies()); MACE_CHECK(!shape.empty()) << "Output's shape must greater than 0"; MACE_CHECK(shape == output.second.shape()) - << "Output shape mispatch: " + << "Output shape mismatch: " << MakeString(output.second.shape()) << " != " << MakeString(shape); std::memcpy(output.second.data().get(), output_tensor->data(), diff --git a/mace/core/runtime/opencl/opencl_wrapper.cc b/mace/core/runtime/opencl/opencl_wrapper.cc index fe1dc88c5b941787d6b4c19cc3c30bd702f8e12e..7a95b8a3c9960a1c7655c509c227793d0c9ba5b5 100644 --- a/mace/core/runtime/opencl/opencl_wrapper.cc +++ b/mace/core/runtime/opencl/opencl_wrapper.cc @@ -281,7 +281,9 @@ bool OpenCLLibraryImpl::Load() { } if (handle_ == nullptr) { - LOG(ERROR) << "Failed to load OpenCL library"; + LOG(ERROR) << "Failed to load OpenCL library, " "please make sure there exists an OpenCL library on your device, " "and your APP has the right to access the library."; return false; } diff --git a/mace/kernels/gemm.cc b/mace/kernels/gemm.cc index 40c90f58c0a4c6c8fda054194b6a5cced71cece6..0e05106fe0d6ef492c20f53b6afca9008445b062 100644 --- a/mace/kernels/gemm.cc +++ b/mace/kernels/gemm.cc @@ -50,7 +50,7 @@ inline void GemmBlock(const float *A, #if defined(MACE_ENABLE_NEON) #if defined(__aarch64__) -#define MACE_GEMM_PART_CAL(RC, RA, RAN) \ +#define MACE_GEMM_PART_CAL_8(RC, RA, RAN) \ c##RC = vfmaq_laneq_f32(c##RC, b0, a##RA, 0); \ c##RC = 
vfmaq_laneq_f32(c##RC, b1, a##RA, 1); \ c##RC = vfmaq_laneq_f32(c##RC, b2, a##RA, 2); \ @@ -60,7 +60,7 @@ inline void GemmBlock(const float *A, c##RC = vfmaq_laneq_f32(c##RC, b6, a##RAN, 2); \ c##RC = vfmaq_laneq_f32(c##RC, b7, a##RAN, 3); #else -#define MACE_GEMM_PART_CAL(RC, RA, RAN) \ +#define MACE_GEMM_PART_CAL_8(RC, RA, RAN) \ c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RA), 0); \ c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RA), 1); \ c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RA), 0); \ @@ -72,6 +72,283 @@ inline void GemmBlock(const float *A, #endif #endif +#if defined(MACE_ENABLE_NEON) +#if defined(__aarch64__) +#define MACE_GEMM_PART_CAL_4(RC) \ + c##RC = vfmaq_laneq_f32(c##RC, b0, a##RC, 0); \ + c##RC = vfmaq_laneq_f32(c##RC, b1, a##RC, 1); \ + c##RC = vfmaq_laneq_f32(c##RC, b2, a##RC, 2); \ + c##RC = vfmaq_laneq_f32(c##RC, b3, a##RC, 3); +#else +#define MACE_GEMM_PART_CAL_4(RC) \ + c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RC), 0); \ + c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RC), 1); \ + c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RC), 0); \ + c##RC = vmlaq_lane_f32(c##RC, b3, vget_high_f32(a##RC), 1); +#endif +#endif + +inline void Gemm144(const float *a_ptr, + const float *b_ptr, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, + float *c_ptr) { +#if defined(MACE_ENABLE_NEON) + MACE_UNUSED(stride_a); + MACE_UNUSED(stride_c); + float32x4_t a0; + float32x4_t b0, b1, b2, b3; + float32x4_t c0; + + a0 = vld1q_f32(a_ptr); + + b0 = vld1q_f32(b_ptr); + b1 = vld1q_f32(b_ptr + 1 * stride_b); + b2 = vld1q_f32(b_ptr + 2 * stride_b); + b3 = vld1q_f32(b_ptr + 3 * stride_b); + + c0 = vld1q_f32(c_ptr); + + MACE_GEMM_PART_CAL_4(0); + + vst1q_f32(c_ptr, c0); +#else + GemmBlock(a_ptr, b_ptr, 1, 4, 4, stride_a, stride_b, stride_c, c_ptr); +#endif +} + +inline void Gemm244(const float *a_ptr, + const float *b_ptr, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, + float 
*c_ptr) { +#if defined(MACE_ENABLE_NEON) + float32x4_t a0, a1; + float32x4_t b0, b1, b2, b3; + float32x4_t c0, c1; + + a0 = vld1q_f32(a_ptr); + a1 = vld1q_f32(a_ptr + 1 * stride_a); + + b0 = vld1q_f32(b_ptr); + b1 = vld1q_f32(b_ptr + 1 * stride_b); + b2 = vld1q_f32(b_ptr + 2 * stride_b); + b3 = vld1q_f32(b_ptr + 3 * stride_b); + + c0 = vld1q_f32(c_ptr); + c1 = vld1q_f32(c_ptr + 1 * stride_c); + + MACE_GEMM_PART_CAL_4(0); + MACE_GEMM_PART_CAL_4(1); + + vst1q_f32(c_ptr, c0); + vst1q_f32(c_ptr + 1 * stride_c, c1); +#else + GemmBlock(a_ptr, b_ptr, 2, 4, 4, stride_a, stride_b, stride_c, c_ptr); +#endif +} + +inline void Gemm344(const float *a_ptr, + const float *b_ptr, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, + float *c_ptr) { +#if defined(MACE_ENABLE_NEON) + float32x4_t a0, a1, a2; + float32x4_t b0, b1, b2, b3; + float32x4_t c0, c1, c2; + + a0 = vld1q_f32(a_ptr); + a1 = vld1q_f32(a_ptr + 1 * stride_a); + a2 = vld1q_f32(a_ptr + 2 * stride_a); + + b0 = vld1q_f32(b_ptr); + b1 = vld1q_f32(b_ptr + 1 * stride_b); + b2 = vld1q_f32(b_ptr + 2 * stride_b); + b3 = vld1q_f32(b_ptr + 3 * stride_b); + + c0 = vld1q_f32(c_ptr); + c1 = vld1q_f32(c_ptr + 1 * stride_c); + c2 = vld1q_f32(c_ptr + 2 * stride_c); + + MACE_GEMM_PART_CAL_4(0); + MACE_GEMM_PART_CAL_4(1); + MACE_GEMM_PART_CAL_4(2); + + vst1q_f32(c_ptr, c0); + vst1q_f32(c_ptr + 1 * stride_c, c1); + vst1q_f32(c_ptr + 2 * stride_c, c2); +#else + GemmBlock(a_ptr, b_ptr, 3, 4, 4, stride_a, stride_b, stride_c, c_ptr); +#endif +} + +inline void Gemm444(const float *a_ptr, + const float *b_ptr, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, + float *c_ptr) { +#if defined(MACE_ENABLE_NEON) + float32x4_t a0, a1, a2, a3; + float32x4_t b0, b1, b2, b3; + float32x4_t c0, c1, c2, c3; + + a0 = vld1q_f32(a_ptr); + a1 = vld1q_f32(a_ptr + 1 * stride_a); + a2 = vld1q_f32(a_ptr + 2 * stride_a); + a3 = vld1q_f32(a_ptr + 3 * stride_a); + + b0 = vld1q_f32(b_ptr); + b1 = vld1q_f32(b_ptr 
+ 1 * stride_b); + b2 = vld1q_f32(b_ptr + 2 * stride_b); + b3 = vld1q_f32(b_ptr + 3 * stride_b); + + c0 = vld1q_f32(c_ptr); + c1 = vld1q_f32(c_ptr + 1 * stride_c); + c2 = vld1q_f32(c_ptr + 2 * stride_c); + c3 = vld1q_f32(c_ptr + 3 * stride_c); + + MACE_GEMM_PART_CAL_4(0); + MACE_GEMM_PART_CAL_4(1); + MACE_GEMM_PART_CAL_4(2); + MACE_GEMM_PART_CAL_4(3); + + vst1q_f32(c_ptr, c0); + vst1q_f32(c_ptr + 1 * stride_c, c1); + vst1q_f32(c_ptr + 2 * stride_c, c2); + vst1q_f32(c_ptr + 3 * stride_c, c3); +#else + GemmBlock(a_ptr, b_ptr, 4, 4, 4, stride_a, stride_b, stride_c, c_ptr); +#endif +} + +inline void Gemm544(const float *a_ptr, + const float *b_ptr, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, + float *c_ptr) { +#if defined(MACE_ENABLE_NEON) + float32x4_t a0, a1, a2, a3, a4; + float32x4_t b0, b1, b2, b3; + float32x4_t c0, c1, c2, c3, c4; + + a0 = vld1q_f32(a_ptr); + a1 = vld1q_f32(a_ptr + 1 * stride_a); + a2 = vld1q_f32(a_ptr + 2 * stride_a); + a3 = vld1q_f32(a_ptr + 3 * stride_a); + a4 = vld1q_f32(a_ptr + 4 * stride_a); + + b0 = vld1q_f32(b_ptr); + b1 = vld1q_f32(b_ptr + 1 * stride_b); + b2 = vld1q_f32(b_ptr + 2 * stride_b); + b3 = vld1q_f32(b_ptr + 3 * stride_b); + + c0 = vld1q_f32(c_ptr); + c1 = vld1q_f32(c_ptr + 1 * stride_c); + c2 = vld1q_f32(c_ptr + 2 * stride_c); + c3 = vld1q_f32(c_ptr + 3 * stride_c); + c4 = vld1q_f32(c_ptr + 4 * stride_c); + + MACE_GEMM_PART_CAL_4(0); + MACE_GEMM_PART_CAL_4(1); + MACE_GEMM_PART_CAL_4(2); + MACE_GEMM_PART_CAL_4(3); + MACE_GEMM_PART_CAL_4(4); + + vst1q_f32(c_ptr, c0); + vst1q_f32(c_ptr + 1 * stride_c, c1); + vst1q_f32(c_ptr + 2 * stride_c, c2); + vst1q_f32(c_ptr + 3 * stride_c, c3); + vst1q_f32(c_ptr + 4 * stride_c, c4); +#else + GemmBlock(a_ptr, b_ptr, 5, 4, 4, stride_a, stride_b, stride_c, c_ptr); +#endif +} + +inline void Gemm644(const float *a_ptr, + const float *b_ptr, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, + float *c_ptr) { +#if 
defined(MACE_ENABLE_NEON) + float32x4_t a0, a1, a2, a3, a4, a5; + float32x4_t b0, b1, b2, b3; + float32x4_t c0, c1, c2, c3, c4, c5; + + a0 = vld1q_f32(a_ptr); + a1 = vld1q_f32(a_ptr + 1 * stride_a); + a2 = vld1q_f32(a_ptr + 2 * stride_a); + a3 = vld1q_f32(a_ptr + 3 * stride_a); + a4 = vld1q_f32(a_ptr + 4 * stride_a); + a5 = vld1q_f32(a_ptr + 5 * stride_a); + + b0 = vld1q_f32(b_ptr); + b1 = vld1q_f32(b_ptr + 1 * stride_b); + b2 = vld1q_f32(b_ptr + 2 * stride_b); + b3 = vld1q_f32(b_ptr + 3 * stride_b); + + c0 = vld1q_f32(c_ptr); + c1 = vld1q_f32(c_ptr + 1 * stride_c); + c2 = vld1q_f32(c_ptr + 2 * stride_c); + c3 = vld1q_f32(c_ptr + 3 * stride_c); + c4 = vld1q_f32(c_ptr + 4 * stride_c); + c5 = vld1q_f32(c_ptr + 5 * stride_c); + + MACE_GEMM_PART_CAL_4(0); + MACE_GEMM_PART_CAL_4(1); + MACE_GEMM_PART_CAL_4(2); + MACE_GEMM_PART_CAL_4(3); + MACE_GEMM_PART_CAL_4(4); + MACE_GEMM_PART_CAL_4(5); + + vst1q_f32(c_ptr, c0); + vst1q_f32(c_ptr + 1 * stride_c, c1); + vst1q_f32(c_ptr + 2 * stride_c, c2); + vst1q_f32(c_ptr + 3 * stride_c, c3); + vst1q_f32(c_ptr + 4 * stride_c, c4); + vst1q_f32(c_ptr + 5 * stride_c, c5); +#else + GemmBlock(a_ptr, b_ptr, 6, 4, 4, stride_a, stride_b, stride_c, c_ptr); +#endif +} + +inline void GemmX44(const float *a_ptr, + const float *b_ptr, + const index_t stride_a, + const index_t stride_b, + const index_t stride_c, + float *c_ptr, + int row) { + switch (row) { + case 1: + Gemm144(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); + break; + case 2: + Gemm244(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); + break; + case 3: + Gemm344(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); + break; + case 4: + Gemm444(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); + break; + case 5: + Gemm544(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); + break; + case 6: + Gemm644(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); + break; + default: + MACE_NOT_IMPLEMENTED; + } +} + inline void Gemm884(const float *a_ptr, const float *b_ptr, const 
index_t stride_a, @@ -119,25 +396,14 @@ inline void Gemm884(const float *a_ptr, c6 = vld1q_f32(c_ptr + 6 * stride_c); c7 = vld1q_f32(c_ptr + 7 * stride_c); -#if defined(__aarch64__) - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); - MACE_GEMM_PART_CAL(3, 6, 7); - MACE_GEMM_PART_CAL(4, 8, 9); - MACE_GEMM_PART_CAL(5, 10, 11); - MACE_GEMM_PART_CAL(6, 12, 13); - MACE_GEMM_PART_CAL(7, 14, 15); -#else - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); - MACE_GEMM_PART_CAL(3, 6, 7); - MACE_GEMM_PART_CAL(4, 8, 9); - MACE_GEMM_PART_CAL(5, 10, 11); - MACE_GEMM_PART_CAL(6, 12, 13); - MACE_GEMM_PART_CAL(7, 14, 15); -#endif + MACE_GEMM_PART_CAL_8(0, 0, 1); + MACE_GEMM_PART_CAL_8(1, 2, 3); + MACE_GEMM_PART_CAL_8(2, 4, 5); + MACE_GEMM_PART_CAL_8(3, 6, 7); + MACE_GEMM_PART_CAL_8(4, 8, 9); + MACE_GEMM_PART_CAL_8(5, 10, 11); + MACE_GEMM_PART_CAL_8(6, 12, 13); + MACE_GEMM_PART_CAL_8(7, 14, 15); vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr + 1 * stride_c, c1); @@ -180,11 +446,7 @@ inline void Gemm184(const float *a_ptr, c0 = vld1q_f32(c_ptr); -#if defined(__aarch64__) - MACE_GEMM_PART_CAL(0, 0, 1); -#else - MACE_GEMM_PART_CAL(0, 0, 1); -#endif + MACE_GEMM_PART_CAL_8(0, 0, 1); vst1q_f32(c_ptr, c0); #else @@ -220,13 +482,8 @@ inline void Gemm284(const float *a_ptr, c0 = vld1q_f32(c_ptr); c1 = vld1q_f32(c_ptr + 1 * stride_c); -#if defined(__aarch64__) - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); -#else - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); -#endif + MACE_GEMM_PART_CAL_8(0, 0, 1); + MACE_GEMM_PART_CAL_8(1, 2, 3); vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr + 1 * stride_c, c1); @@ -266,15 +523,9 @@ inline void Gemm384(const float *a_ptr, c1 = vld1q_f32(c_ptr + 1 * stride_c); c2 = vld1q_f32(c_ptr + 2 * stride_c); -#if defined(__aarch64__) - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); -#else - MACE_GEMM_PART_CAL(0, 0, 1); - 
MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); -#endif + MACE_GEMM_PART_CAL_8(0, 0, 1); + MACE_GEMM_PART_CAL_8(1, 2, 3); + MACE_GEMM_PART_CAL_8(2, 4, 5); vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr + 1 * stride_c, c1); @@ -318,17 +569,10 @@ inline void Gemm484(const float *a_ptr, c2 = vld1q_f32(c_ptr + 2 * stride_c); c3 = vld1q_f32(c_ptr + 3 * stride_c); -#if defined(__aarch64__) - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); - MACE_GEMM_PART_CAL(3, 6, 7); -#else - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); - MACE_GEMM_PART_CAL(3, 6, 7); -#endif + MACE_GEMM_PART_CAL_8(0, 0, 1); + MACE_GEMM_PART_CAL_8(1, 2, 3); + MACE_GEMM_PART_CAL_8(2, 4, 5); + MACE_GEMM_PART_CAL_8(3, 6, 7); vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr + 1 * stride_c, c1); @@ -376,19 +620,11 @@ inline void Gemm584(const float *a_ptr, c3 = vld1q_f32(c_ptr + 3 * stride_c); c4 = vld1q_f32(c_ptr + 4 * stride_c); -#if defined(__aarch64__) - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); - MACE_GEMM_PART_CAL(3, 6, 7); - MACE_GEMM_PART_CAL(4, 8, 9); -#else - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); - MACE_GEMM_PART_CAL(3, 6, 7); - MACE_GEMM_PART_CAL(4, 8, 9); -#endif + MACE_GEMM_PART_CAL_8(0, 0, 1); + MACE_GEMM_PART_CAL_8(1, 2, 3); + MACE_GEMM_PART_CAL_8(2, 4, 5); + MACE_GEMM_PART_CAL_8(3, 6, 7); + MACE_GEMM_PART_CAL_8(4, 8, 9); vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr + 1 * stride_c, c1); @@ -440,21 +676,12 @@ inline void Gemm684(const float *a_ptr, c4 = vld1q_f32(c_ptr + 4 * stride_c); c5 = vld1q_f32(c_ptr + 5 * stride_c); -#if defined(__aarch64__) - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); - MACE_GEMM_PART_CAL(3, 6, 7); - MACE_GEMM_PART_CAL(4, 8, 9); - MACE_GEMM_PART_CAL(5, 10, 11); -#else - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - 
MACE_GEMM_PART_CAL(2, 4, 5); - MACE_GEMM_PART_CAL(3, 6, 7); - MACE_GEMM_PART_CAL(4, 8, 9); - MACE_GEMM_PART_CAL(5, 10, 11); -#endif + MACE_GEMM_PART_CAL_8(0, 0, 1); + MACE_GEMM_PART_CAL_8(1, 2, 3); + MACE_GEMM_PART_CAL_8(2, 4, 5); + MACE_GEMM_PART_CAL_8(3, 6, 7); + MACE_GEMM_PART_CAL_8(4, 8, 9); + MACE_GEMM_PART_CAL_8(5, 10, 11); vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr + 1 * stride_c, c1); @@ -511,23 +738,13 @@ inline void Gemm784(const float *a_ptr, c5 = vld1q_f32(c_ptr + 5 * stride_c); c6 = vld1q_f32(c_ptr + 6 * stride_c); -#if defined(__aarch64__) - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); - MACE_GEMM_PART_CAL(3, 6, 7); - MACE_GEMM_PART_CAL(4, 8, 9); - MACE_GEMM_PART_CAL(5, 10, 11); - MACE_GEMM_PART_CAL(6, 12, 13); -#else - MACE_GEMM_PART_CAL(0, 0, 1); - MACE_GEMM_PART_CAL(1, 2, 3); - MACE_GEMM_PART_CAL(2, 4, 5); - MACE_GEMM_PART_CAL(3, 6, 7); - MACE_GEMM_PART_CAL(4, 8, 9); - MACE_GEMM_PART_CAL(5, 10, 11); - MACE_GEMM_PART_CAL(6, 12, 13); -#endif + MACE_GEMM_PART_CAL_8(0, 0, 1); + MACE_GEMM_PART_CAL_8(1, 2, 3); + MACE_GEMM_PART_CAL_8(2, 4, 5); + MACE_GEMM_PART_CAL_8(3, 6, 7); + MACE_GEMM_PART_CAL_8(4, 8, 9); + MACE_GEMM_PART_CAL_8(5, 10, 11); + MACE_GEMM_PART_CAL_8(6, 12, 13); vst1q_f32(c_ptr, c0); vst1q_f32(c_ptr + 1 * stride_c, c1); @@ -589,9 +806,19 @@ inline void GemmTile(const float *A, const index_t stride_c, float *C) { #if defined(MACE_ENABLE_NEON) - index_t h, w, k; - for (h = 0; h < height - 7; h += 8) { - for (k = 0; k < K - 7; k += 8) { + index_t h = 0; + index_t w = 0; + index_t k = 0; +#if defined(__aarch64__) + int reg_height_tile = 8; + int reg_K_tile = 8; +#else + int reg_height_tile = 6; + int reg_K_tile = 4; +#endif + + for (h = 0; h < height - reg_height_tile + 1; h += reg_height_tile) { + for (k = 0; k < K - reg_K_tile + 1; k += reg_K_tile) { const float *a_ptr = A + (h * stride_a + k); #if defined(__aarch64__) && defined(__clang__) int nw = width >> 2; @@ -833,43 +1060,180 @@ inline void 
GemmTile(const float *A, w = (width >> 2) << 2; } -#else // gcc || armv7a +#elif defined(__aarch64__) // gcc for (w = 0; w + 3 < width; w += 4) { const float *b_ptr = B + (k * stride_b + w); float *c_ptr = C + (h * stride_c + w); Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr); } -#endif // clang && armv8a +#else // armv7 + int nw = width >> 2; + if (nw > 0) { + float32x4_t a0, a1, a2, a3, a4, a5; + + a0 = vld1q_f32(a_ptr); + a1 = vld1q_f32(a_ptr + 1 * stride_a); + a2 = vld1q_f32(a_ptr + 2 * stride_a); + a3 = vld1q_f32(a_ptr + 3 * stride_a); + a4 = vld1q_f32(a_ptr + 4 * stride_a); + a5 = vld1q_f32(a_ptr + 5 * stride_a); + + const float *b_ptr0 = B + k * stride_b; + const float *b_ptr1 = B + (k + 1) * stride_b; + const float *b_ptr2 = B + (k + 2) * stride_b; + const float *b_ptr3 = B + (k + 3) * stride_b; + + float *c_ptr0 = C + h * stride_c; + float *c_ptr1 = C + (h + 1) * stride_c; + float *c_ptr2 = C + (h + 2) * stride_c; + float *c_ptr3 = C + (h + 3) * stride_c; + float *c_ptr4 = C + (h + 4) * stride_c; + float *c_ptr5 = C + (h + 5) * stride_c; + + asm volatile( + "pld [%7, #128] \n" + "vld1.f32 {d12-d13}, [%7]! \n" + "pld [%1, #128] \n" + "vld1.f32 {d16-d17}, [%1] \n" + "pld [%2, #128] \n" + "vld1.f32 {d18-d19}, [%2] \n" + + "0: \n" + + "pld [%3, #128] \n" + "vld1.f32 {d20-d21}, [%3] \n" + "pld [%4, #128] \n" + "vld1.f32 {d22-d23}, [%4] \n" + "pld [%5, #128] \n" + "vld1.f32 {d24-d25}, [%5] \n" + "pld [%6, #128] \n" + "vld1.f32 {d26-d27}, [%6] \n" + + "pld [%8, #128] \n" + "vld1.f32 {d14-d15}, [%8]! \n" + + "vmla.f32 q8, q6, %e22[0] \n" + "vmla.f32 q9, q6, %e23[0] \n" + "vmla.f32 q10, q6, %e24[0] \n" + "vmla.f32 q11, q6, %e25[0] \n" + "vmla.f32 q12, q6, %e26[0] \n" + "vmla.f32 q13, q6, %e27[0] \n" + + "pld [%9, #128] \n" + "vld1.f32 {d12-d13}, [%9]! 
\n" + + "vmla.f32 q8, q7, %e22[1] \n" + "vmla.f32 q9, q7, %e23[1] \n" + "vmla.f32 q10, q7, %e24[1] \n" + "vmla.f32 q11, q7, %e25[1] \n" + "vmla.f32 q12, q7, %e26[1] \n" + "vmla.f32 q13, q7, %e27[1] \n" + + "pld [%10, #128] \n" + "vld1.f32 {d14-d15}, [%10]! \n" + + "vmla.f32 q8, q6, %f22[0] \n" + "vmla.f32 q9, q6, %f23[0] \n" + "vmla.f32 q10, q6, %f24[0] \n" + "vmla.f32 q11, q6, %f25[0] \n" + "vmla.f32 q12, q6, %f26[0] \n" + "vmla.f32 q13, q6, %f27[0] \n" + + "vmla.f32 q8, q7, %f22[1] \n" + "vmla.f32 q9, q7, %f23[1] \n" + "vmla.f32 q10, q7, %f24[1] \n" + "vmla.f32 q11, q7, %f25[1] \n" + "vmla.f32 q12, q7, %f26[1] \n" + "vmla.f32 q13, q7, %f27[1] \n" + + "vst1.f32 {d16-d17}, [%1]! \n" + "vst1.f32 {d18-d19}, [%2]! \n" + + "pld [%7, #128] \n" + "vld1.f32 {d12-d13}, [%7]! \n" + + "vst1.f32 {d20-d21}, [%3]! \n" + "vst1.f32 {d22-d23}, [%4]! \n" + + "pld [%1, #128] \n" + "vld1.f32 {d16-d17}, [%1] \n" + + "vst1.f32 {d24-d25}, [%5]! \n" + "vst1.f32 {d26-d27}, [%6]! \n" + + "pld [%2, #128] \n" + "vld1.f32 {d18-d19}, [%2] \n" + + "subs %0, #1 \n" + "bne 0b \n" + : "=r"(nw), // 0 + "=r"(c_ptr0), // 1 + "=r"(c_ptr1), // 2 + "=r"(c_ptr2), // 3 + "=r"(c_ptr3), // 4 + "=r"(c_ptr4), // 5 + "=r"(c_ptr5), // 6 + "=r"(b_ptr0), // 7 + "=r"(b_ptr1), // 8 + "=r"(b_ptr2), // 9 + "=r"(b_ptr3) // 10 + : "0"(nw), // 11 + "1"(c_ptr0), // 12 + "2"(c_ptr1), // 13 + "3"(c_ptr2), // 14 + "4"(c_ptr3), // 15 + "5"(c_ptr4), // 16 + "6"(c_ptr5), // 17 + "7"(b_ptr0), // 18 + "8"(b_ptr1), // 19 + "9"(b_ptr2), // 20 + "10"(b_ptr3), // 21 + "w"(a0), // 22 + "w"(a1), // 23 + "w"(a2), // 24 + "w"(a3), // 25 + "w"(a4), // 26 + "w"(a5) // 27 + : "cc", "memory", "q6", "q7", "q8", "q9", "q10", "q11", "q12", + "q13", "q14", "q15"); + + w = (width >> 2) << 2; + } +#endif if (w < width) { const float *b_ptr = B + (k * stride_b + w); float *c_ptr = C + (h * stride_c + w); - GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_a, stride_b, stride_c, - c_ptr); + GemmBlock(a_ptr, b_ptr, reg_height_tile, reg_K_tile, width - 
w, + stride_a, stride_b, stride_c, c_ptr); } } if (k < K) { const float *a_ptr = A + (h * stride_a + k); const float *b_ptr = B + k * stride_b; float *c_ptr = C + h * stride_c; - GemmBlock(a_ptr, b_ptr, 8, K - k, width, stride_a, stride_b, stride_c, - c_ptr); + GemmBlock(a_ptr, b_ptr, reg_height_tile, K - k, width, stride_a, stride_b, + stride_c, c_ptr); } } if (h < height) { index_t remain_h = height - h; - for (k = 0; k < K - 7; k += 8) { + for (k = 0; k < K - reg_K_tile; k += reg_K_tile) { const float *a_ptr = A + (h * stride_a + k); index_t w; for (w = 0; w + 3 < width; w += 4) { const float *b_ptr = B + (k * stride_b + w); float *c_ptr = C + (h * stride_c + w); +#if defined(__aarch64__) GemmX84(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr, remain_h); +#else + GemmX44(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr, remain_h); +#endif } if (w < width) { const float *b_ptr = B + (k * stride_b + w); float *c_ptr = C + (h * stride_c + w); - GemmBlock(a_ptr, b_ptr, remain_h, 8, width - w, stride_a, stride_b, - stride_c, c_ptr); + GemmBlock(a_ptr, b_ptr, remain_h, reg_K_tile, width - w, stride_a, + stride_b, stride_c, c_ptr); } } if (k < K) { diff --git a/mace/ops/shape_test.cc b/mace/ops/shape_test.cc index 5798be7f8309970445cb3c8bf10e6327c2f52144..08ccb88b86958bb4fdbd3a1677fe1b728355f5fe 100644 --- a/mace/ops/shape_test.cc +++ b/mace/ops/shape_test.cc @@ -38,7 +38,9 @@ void TestShapeOp(const std::vector &input_shape) { std::vector expected_input_shape(input_shape.begin(), input_shape.end()); if (!expected_input_shape.empty()) { - net.AddInputFromArray("ExpectedOutput", {input_shape.size()}, + net.AddInputFromArray("ExpectedOutput", + {static_cast( + input_shape.size())}, expected_input_shape); } else { net.AddInputFromArray("ExpectedOutput", {}, {0}); diff --git a/mace/ops/strided_slice_test.cc b/mace/ops/strided_slice_test.cc index 2aa4af2820488a7ea7fb0a293f05e7b7ad1802bf..6cd46f4e110e0cb001932a18db6db2a5c69d866b 100644 --- a/mace/ops/strided_slice_test.cc 
+++ b/mace/ops/strided_slice_test.cc @@ -37,11 +37,18 @@ void TestSlice(const std::vector &input_shape, const std::vector &output) { OpsTestNet net; net.AddInputFromArray("Input", input_shape, input); - net.AddInputFromArray("BeginIndices", {input_shape.size()}, + net.AddInputFromArray("BeginIndices", + {static_cast( + input_shape.size())}, begin_indices); - net.AddInputFromArray("EndIndices", {input_shape.size()}, + net.AddInputFromArray("EndIndices", + {static_cast( + input_shape.size())}, end_indices); - net.AddInputFromArray("Strides", {input_shape.size()}, strides); + net.AddInputFromArray("Strides", + {static_cast( + input_shape.size())}, + strides); OpDefBuilder("StridedSlice", "StridedSliceOpTest") .Input("Input") diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index 1d81fe2f58cd1fa7c67f8ddd1a4fa6dfe2fdfd5d..20ce14b48cd40a7352807dd44158f2e41d7fdb32 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -164,6 +164,7 @@ class TransformerRule(Enum): TRANSFORM_BUFFER_IMAGE = 17 ADD_DEVICE_AND_DATA_TYPE = 18 SORT_BY_EXECUTION = 19 + ADD_IN_OUT_TENSOR_INFO = 20 class ConverterInterface(object): @@ -210,6 +211,7 @@ class ConverterOption(object): self._device = DeviceType.CPU.value self._winograd_enabled = False self._transformer_option = [ + TransformerRule.ADD_IN_OUT_TENSOR_INFO, TransformerRule.REMOVE_USELESS_RESHAPE_OP, TransformerRule.REMOVE_IDENTITY_OP, TransformerRule.TRANSFORM_GLOBAL_POOLING, diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index 0eea81c079cb26b9e1e84d09877f2912436057ba..b5f644e4ee6ae3410fbf1e998a8426ba07e1e512 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -166,6 +166,8 @@ class TensorflowConverter(base_converter.ConverterInterface): 
self._option = option self._mace_net_def = mace_pb2.NetDef() ConverterUtil.set_filter_format(self._mace_net_def, FilterFormat.HWIO) + + # import tensorflow graph tf_graph_def = tf.GraphDef() with tf.gfile.Open(src_model_file, 'rb') as f: tf_graph_def.ParseFromString(f.read()) diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 3116cdc161be6a7b288080cae6ab2ba2d15a7f7d..8f4417d2ecf095891234e7e5b2a5541064d5cc01 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -55,6 +55,7 @@ class Transformer(base_converter.ConverterInterface): def __init__(self, option, model): # DO NOT reorder the following transformers' order self._registered_transformers_order = [ + TransformerRule.ADD_IN_OUT_TENSOR_INFO, TransformerRule.REMOVE_USELESS_RESHAPE_OP, TransformerRule.REMOVE_IDENTITY_OP, TransformerRule.TRANSFORM_GLOBAL_POOLING, @@ -78,6 +79,8 @@ class Transformer(base_converter.ConverterInterface): TransformerRule.SORT_BY_EXECUTION, ] self._registered_transformers = { + TransformerRule.ADD_IN_OUT_TENSOR_INFO: + self.add_in_out_tensor_info, TransformerRule.REMOVE_USELESS_RESHAPE_OP: self.remove_useless_reshape_op, TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op, @@ -271,6 +274,21 @@ class Transformer(base_converter.ConverterInterface): self._model.op.remove(op) + def add_in_out_tensor_info(self): + net = self._model + for input_node in self._option.input_nodes.values(): + input_info = net.input_info.add() + input_info.name = input_node.name + input_info.dims.extend(input_node.shape) + + for output_node in self._option.output_nodes.values(): + output_info = net.output_info.add() + output_info.name = output_node.name + output_info.dims.extend( + self._producer[output_node.name].output_shape[0].dims) + + return False + def remove_useless_reshape_op(self): net = self._model for op in net.op: diff --git a/mace/test/BUILD b/mace/test/BUILD index 
afc2738aa49aa05335d333ea04b8ab379c221f48..c1f14bb772c711fb5b62066e7cfb5dd031a797ac 100644 --- a/mace/test/BUILD +++ b/mace/test/BUILD @@ -50,3 +50,24 @@ cc_test( "@gtest//:gtest_main", ], ) + +cc_test( + name = "mace_api_exception_test", + testonly = 1, + srcs = ["mace_api_exception_test.cc"], + copts = ["-Werror", "-Wextra", "-Wno-missing-field-initializers"] + + if_openmp_enabled(["-fopenmp"]) + + if_neon_enabled(["-DMACE_ENABLE_NEON"]) + + if_android_armv7(["-mfpu=neon"]) + + if_android_armv7(["-mfloat-abi=softfp"]) + + if_android(["-DMACE_ENABLE_OPENCL"]) + + if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]), + linkopts = ["-fopenmp"], + linkstatic = 1, + deps = [ + "//mace/ops:test", + "//mace/kernels:kernels", + "//mace/ops:ops", + "@gtest//:gtest_main", + ], +) diff --git a/mace/test/mace_api_exception_test.cc b/mace/test/mace_api_exception_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..1eaad03726165987ce00c6df70d0b23f438a2231 --- /dev/null +++ b/mace/test/mace_api_exception_test.cc @@ -0,0 +1,40 @@ +// Copyright 2018 Xiaomi, Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace test { + +TEST(MaceAPIExceptionTest, WrongInputTest) { + std::vector input_names; + std::vector output_names; + input_names.push_back(MakeString("input", 0)); + output_names.push_back(MakeString("output", 0)); + + const DeviceType device = DeviceType::GPU; + + std::shared_ptr net_def(new NetDef()); + for (size_t i = 0; i < input_names.size(); ++i) { + InputInfo *info = net_def->add_input_info(); + info->set_name(input_names[i]); + } + + MaceEngine engine(device); + ASSERT_DEATH(engine.Init(net_def.get(), {"input"}, output_names, nullptr), + ""); +} + +} // namespace test +} // namespace mace diff --git a/mace/test/mace_api_mt_test.cc b/mace/test/mace_api_mt_test.cc index ab4317d45fb9f5ae46c27c47de6d05b5680c8be6..f19a67658eff99195085f69cfaff556cc89fa246 100644 --- a/mace/test/mace_api_mt_test.cc +++ b/mace/test/mace_api_mt_test.cc @@ -298,6 +298,8 @@ void MaceRunFunc(const int in_out_size) { {mem_map[input_names[i]]}, device, net_def.get()); + InputInfo *info = net_def->add_input_info(); + info->set_name(input_names[i]); } BufferToImage(filter_tensor_name, filter_tensor_img_name, mace::kernels::CONV2D_FILTER, {}, device, @@ -315,6 +317,8 @@ void MaceRunFunc(const int in_out_size) { mace::kernels::IN_OUT_CHANNEL, device, net_def.get()); + OutputInfo *info = net_def->add_output_info(); + info->set_name(output_names[i]); } const std::string file_path ="/data/local/tmp/mace"; diff --git a/mace/test/mace_api_test.cc b/mace/test/mace_api_test.cc index f061ecc3f0864019ec1ea8efb4569aebe3ed49e0..874e221ed08675d5d32af235aadc3fc1ded34e0e 100644 --- a/mace/test/mace_api_test.cc +++ b/mace/test/mace_api_test.cc @@ -308,6 +308,8 @@ void MaceRun(const int in_out_size, {mem_map[input_names[i]]}, device, net_def.get()); + InputInfo *info = net_def->add_input_info(); + info->set_name(input_names[i]); } BufferToImage(filter_tensor_name, filter_tensor_img_name, mace::kernels::CONV2D_FILTER, {}, device, @@ -324,6 
+326,8 @@ void MaceRun(const int in_out_size, mace::kernels::IN_OUT_CHANNEL, device, net_def.get()); + OutputInfo *info = net_def->add_output_info(); + info->set_name(output_names[i]); } MaceEngine engine(device); @@ -376,5 +380,6 @@ TEST_F(MaceAPITest, GPUVariableInputShape) { {{1, 16, 32, 16}, {1, 32, 64, 16}}, {16, 16, 3, 3}); } + } // namespace test } // namespace mace diff --git a/mace/utils/utils.h b/mace/utils/utils.h index b722176869ca08ff0576a3d02163a492ce51d70f..f8a8b9fe1b0155c1d5c0191c2e31f2d0ff2cfae6 100644 --- a/mace/utils/utils.h +++ b/mace/utils/utils.h @@ -16,6 +16,7 @@ #define MACE_UTILS_UTILS_H_ #include +#include #include #include #include @@ -152,5 +153,14 @@ inline bool ReadBinaryFile(std::vector *data, return true; } +template +std::vector MapKeys(const std::map &data) { + std::vector keys; + for (auto &kv : data) { + keys.push_back(kv.first); + } + return keys; +} + } // namespace mace #endif // MACE_UTILS_UTILS_H_