Commit a376a1b8 authored by Y yejianwu

Merge branch 'master' of v9.git.n.xiaomi.com:deep-computing/mace into add_shared_lib

......@@ -106,6 +106,8 @@ class MaceEngine::Impl {
DeviceType device_type_;
std::unique_ptr<Workspace> ws_;
std::unique_ptr<NetBase> net_;
std::map<std::string, mace::InputInfo> input_info_map_;
std::map<std::string, mace::OutputInfo> output_info_map_;
#ifdef MACE_ENABLE_HEXAGON
std::unique_ptr<HexagonControlWrapper> hexagon_controller_;
#endif
......@@ -131,12 +133,29 @@ MaceStatus MaceEngine::Impl::Init(
const std::vector<std::string> &output_nodes,
const unsigned char *model_data) {
LOG(INFO) << "Initializing MaceEngine";
// Get input and output information.
for (auto &input_info : net_def->input_info()) {
input_info_map_[input_info.name()] = input_info;
}
for (auto &output_info : net_def->output_info()) {
output_info_map_[output_info.name()] = output_info;
}
// Validate the requested nodes and create internal input/output tensors
for (auto input_name : input_nodes) {
if (input_info_map_.find(input_name) == input_info_map_.end()) {
LOG(FATAL) << "'" << input_name
<< "' is not belong to model's inputs: "
<< MakeString(MapKeys(input_info_map_));
}
ws_->CreateTensor(MakeString("mace_input_node_", input_name),
GetDeviceAllocator(device_type_), DT_FLOAT);
}
for (auto output_name : output_nodes) {
if (output_info_map_.find(output_name) == output_info_map_.end()) {
LOG(FATAL) << "'" << output_name
<< "' is not belong to model's outputs "
<< MakeString(MapKeys(output_info_map_));
}
ws_->CreateTensor(MakeString("mace_output_node_", output_name),
GetDeviceAllocator(device_type_), DT_FLOAT);
}
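The checks added above make Init() fail fast with a descriptive message instead of crashing later on a missing tensor. The pattern is a plain map lookup plus a fatal log; below is a minimal standalone sketch of the same idea, where FatalLog and the key-join loop stand in for MACE's LOG(FATAL) and MakeString(MapKeys(...)).

#include <cstdlib>
#include <iostream>
#include <map>
#include <string>

struct InputInfo {};  // placeholder for mace::InputInfo

// Stand-in for LOG(FATAL): print the message and abort.
static void FatalLog(const std::string &msg) {
  std::cerr << msg << std::endl;
  std::abort();
}

// Fail fast if 'name' is not a declared model input, listing valid names.
void CheckInputName(const std::string &name,
                    const std::map<std::string, InputInfo> &inputs) {
  if (inputs.find(name) == inputs.end()) {
    std::string keys;
    for (const auto &kv : inputs) keys += kv.first + " ";
    FatalLog("'" + name + "' does not belong to the model's inputs: " + keys);
  }
}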
......@@ -193,6 +212,11 @@ MaceStatus MaceEngine::Impl::Run(
std::vector<Tensor *> input_tensors;
std::vector<Tensor *> output_tensors;
for (auto &input : inputs) {
if (input_info_map_.find(input.first) == input_info_map_.end()) {
LOG(FATAL) << "'" << input.first
<< "' is not belong to model's inputs: "
<< MakeString(MapKeys(input_info_map_));
}
MACE_CHECK(input.second.shape().size() == 4,
"The input's shape must be 4-dimensional in NHWC format;"
" please use 1 to fill missing dimensions");
......@@ -208,6 +232,11 @@ MaceStatus MaceEngine::Impl::Run(
input_tensors.push_back(input_tensor);
}
for (auto &output : *outputs) {
if (output_info_map_.find(output.first) == output_info_map_.end()) {
LOG(FATAL) << "'" << output.first
<< "' is not belong to model's outputs: "
<< MakeString(MapKeys(output_info_map_));
}
if (device_type_ == DeviceType::GPU) {
MACE_CHECK(output.second.shape().size() == 4,
"The output's shape must be 4-dimensional in NHWC format,"
......@@ -245,7 +274,7 @@ MaceStatus MaceEngine::Impl::Run(
std::multiplies<int64_t>());
MACE_CHECK(!shape.empty()) << "Output's shape must not be empty";
MACE_CHECK(shape == output.second.shape())
<< "Output shape mispatch: "
<< "Output shape mismatch: "
<< MakeString<int64_t>(output.second.shape())
<< " != " << MakeString<int64_t>(shape);
std::memcpy(output.second.data().get(), output_tensor->data<float>(),
......
......@@ -281,7 +281,9 @@ bool OpenCLLibraryImpl::Load() {
}
if (handle_ == nullptr) {
LOG(ERROR) << "Failed to load OpenCL library";
LOG(ERROR) << "Failed to load OpenCL library, "
"please make sure there exist OpenCL library on your device, "
"and your APP have right to access the library.";
return false;
}
......
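For background, MACE loads OpenCL at runtime with dlopen rather than linking against it, so a missing vendor library is a legitimate runtime condition that deserves an actionable message. A hedged sketch of such a probe follows; the candidate paths are illustrative assumptions, not MACE's actual search list.

#include <dlfcn.h>
#include <cstdio>

// Try a few candidate locations for the OpenCL library and report a
// user-actionable error when none can be opened.
void *TryLoadOpenCL() {
  const char *candidates[] = {
      "libOpenCL.so",                       // let the dynamic linker search
      "/system/vendor/lib64/libOpenCL.so",  // assumed vendor path
      "/system/lib64/libOpenCL.so",         // assumed system path
  };
  for (const char *path : candidates) {
    void *handle = dlopen(path, RTLD_LAZY | RTLD_LOCAL);
    if (handle != nullptr) return handle;
  }
  std::fprintf(stderr,
               "Failed to load OpenCL library; make sure an OpenCL library "
               "exists on the device and the app may access it.\n");
  return nullptr;
}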
......@@ -50,7 +50,7 @@ inline void GemmBlock(const float *A,
#if defined(MACE_ENABLE_NEON)
#if defined(__aarch64__)
#define MACE_GEMM_PART_CAL(RC, RA, RAN) \
#define MACE_GEMM_PART_CAL_8(RC, RA, RAN) \
c##RC = vfmaq_laneq_f32(c##RC, b0, a##RA, 0); \
c##RC = vfmaq_laneq_f32(c##RC, b1, a##RA, 1); \
c##RC = vfmaq_laneq_f32(c##RC, b2, a##RA, 2); \
......@@ -60,7 +60,7 @@ inline void GemmBlock(const float *A,
c##RC = vfmaq_laneq_f32(c##RC, b6, a##RAN, 2); \
c##RC = vfmaq_laneq_f32(c##RC, b7, a##RAN, 3);
#else
#define MACE_GEMM_PART_CAL(RC, RA, RAN) \
#define MACE_GEMM_PART_CAL_8(RC, RA, RAN) \
c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RA), 0); \
c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RA), 1); \
c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RA), 0); \
......@@ -72,6 +72,283 @@ inline void GemmBlock(const float *A,
#endif
#endif
#if defined(MACE_ENABLE_NEON)
#if defined(__aarch64__)
#define MACE_GEMM_PART_CAL_4(RC) \
c##RC = vfmaq_laneq_f32(c##RC, b0, a##RC, 0); \
c##RC = vfmaq_laneq_f32(c##RC, b1, a##RC, 1); \
c##RC = vfmaq_laneq_f32(c##RC, b2, a##RC, 2); \
c##RC = vfmaq_laneq_f32(c##RC, b3, a##RC, 3);
#else
#define MACE_GEMM_PART_CAL_4(RC) \
c##RC = vmlaq_lane_f32(c##RC, b0, vget_low_f32(a##RC), 0); \
c##RC = vmlaq_lane_f32(c##RC, b1, vget_low_f32(a##RC), 1); \
c##RC = vmlaq_lane_f32(c##RC, b2, vget_high_f32(a##RC), 0); \
c##RC = vmlaq_lane_f32(c##RC, b3, vget_high_f32(a##RC), 1);
#endif
#endif
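The new MACE_GEMM_PART_CAL_4 macro accumulates one 1x4 output row from a 1x4 slice of A and a 4x4 block of B: lane l of a##RC scales row vector b##l. In scalar form, one expansion is equivalent to the following sketch (names are illustrative):

// Scalar equivalent of MACE_GEMM_PART_CAL_4(RC):
// c##RC[j] += a##RC[l] * b##l[j] for l, j in [0, 4).
void GemmPartCal4Scalar(const float a[4],     // lanes of a##RC
                        const float b[4][4],  // row vectors b0..b3
                        float c[4]) {         // lanes of c##RC
  for (int l = 0; l < 4; ++l) {
    for (int j = 0; j < 4; ++j) {
      c[j] += a[l] * b[l][j];
    }
  }
}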
inline void Gemm144(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
MACE_UNUSED(stride_a);
MACE_UNUSED(stride_c);
float32x4_t a0;
float32x4_t b0, b1, b2, b3;
float32x4_t c0;
a0 = vld1q_f32(a_ptr);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
c0 = vld1q_f32(c_ptr);
MACE_GEMM_PART_CAL_4(0);
vst1q_f32(c_ptr, c0);
#else
GemmBlock(a_ptr, b_ptr, 1, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm244(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1;
float32x4_t b0, b1, b2, b3;
float32x4_t c0, c1;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 1 * stride_a);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
MACE_GEMM_PART_CAL_4(0);
MACE_GEMM_PART_CAL_4(1);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
#else
GemmBlock(a_ptr, b_ptr, 2, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm344(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2;
float32x4_t b0, b1, b2, b3;
float32x4_t c0, c1, c2;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 1 * stride_a);
a2 = vld1q_f32(a_ptr + 2 * stride_a);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
MACE_GEMM_PART_CAL_4(0);
MACE_GEMM_PART_CAL_4(1);
MACE_GEMM_PART_CAL_4(2);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
#else
GemmBlock(a_ptr, b_ptr, 3, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm444(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3;
float32x4_t b0, b1, b2, b3;
float32x4_t c0, c1, c2, c3;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 1 * stride_a);
a2 = vld1q_f32(a_ptr + 2 * stride_a);
a3 = vld1q_f32(a_ptr + 3 * stride_a);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c);
MACE_GEMM_PART_CAL_4(0);
MACE_GEMM_PART_CAL_4(1);
MACE_GEMM_PART_CAL_4(2);
MACE_GEMM_PART_CAL_4(3);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
vst1q_f32(c_ptr + 3 * stride_c, c3);
#else
GemmBlock(a_ptr, b_ptr, 4, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm544(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4;
float32x4_t b0, b1, b2, b3;
float32x4_t c0, c1, c2, c3, c4;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 1 * stride_a);
a2 = vld1q_f32(a_ptr + 2 * stride_a);
a3 = vld1q_f32(a_ptr + 3 * stride_a);
a4 = vld1q_f32(a_ptr + 4 * stride_a);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c);
c4 = vld1q_f32(c_ptr + 4 * stride_c);
MACE_GEMM_PART_CAL_4(0);
MACE_GEMM_PART_CAL_4(1);
MACE_GEMM_PART_CAL_4(2);
MACE_GEMM_PART_CAL_4(3);
MACE_GEMM_PART_CAL_4(4);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
vst1q_f32(c_ptr + 3 * stride_c, c3);
vst1q_f32(c_ptr + 4 * stride_c, c4);
#else
GemmBlock(a_ptr, b_ptr, 5, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void Gemm644(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr) {
#if defined(MACE_ENABLE_NEON)
float32x4_t a0, a1, a2, a3, a4, a5;
float32x4_t b0, b1, b2, b3;
float32x4_t c0, c1, c2, c3, c4, c5;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 1 * stride_a);
a2 = vld1q_f32(a_ptr + 2 * stride_a);
a3 = vld1q_f32(a_ptr + 3 * stride_a);
a4 = vld1q_f32(a_ptr + 4 * stride_a);
a5 = vld1q_f32(a_ptr + 5 * stride_a);
b0 = vld1q_f32(b_ptr);
b1 = vld1q_f32(b_ptr + 1 * stride_b);
b2 = vld1q_f32(b_ptr + 2 * stride_b);
b3 = vld1q_f32(b_ptr + 3 * stride_b);
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c);
c4 = vld1q_f32(c_ptr + 4 * stride_c);
c5 = vld1q_f32(c_ptr + 5 * stride_c);
MACE_GEMM_PART_CAL_4(0);
MACE_GEMM_PART_CAL_4(1);
MACE_GEMM_PART_CAL_4(2);
MACE_GEMM_PART_CAL_4(3);
MACE_GEMM_PART_CAL_4(4);
MACE_GEMM_PART_CAL_4(5);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
vst1q_f32(c_ptr + 2 * stride_c, c2);
vst1q_f32(c_ptr + 3 * stride_c, c3);
vst1q_f32(c_ptr + 4 * stride_c, c4);
vst1q_f32(c_ptr + 5 * stride_c, c5);
#else
GemmBlock(a_ptr, b_ptr, 6, 4, 4, stride_a, stride_b, stride_c, c_ptr);
#endif
}
inline void GemmX44(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
const index_t stride_b,
const index_t stride_c,
float *c_ptr,
int row) {
switch (row) {
case 1:
Gemm144(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
break;
case 2:
Gemm244(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
break;
case 3:
Gemm344(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
break;
case 4:
Gemm444(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
break;
case 5:
Gemm544(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
break;
case 6:
Gemm644(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
break;
default:
MACE_NOT_IMPLEMENTED;
}
}
inline void Gemm884(const float *a_ptr,
const float *b_ptr,
const index_t stride_a,
......@@ -119,25 +396,14 @@ inline void Gemm884(const float *a_ptr,
c6 = vld1q_f32(c_ptr + 6 * stride_c);
c7 = vld1q_f32(c_ptr + 7 * stride_c);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
MACE_GEMM_PART_CAL(4, 8, 9);
MACE_GEMM_PART_CAL(5, 10, 11);
MACE_GEMM_PART_CAL(6, 12, 13);
MACE_GEMM_PART_CAL(7, 14, 15);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
MACE_GEMM_PART_CAL(4, 8, 9);
MACE_GEMM_PART_CAL(5, 10, 11);
MACE_GEMM_PART_CAL(6, 12, 13);
MACE_GEMM_PART_CAL(7, 14, 15);
#endif
MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL_8(1, 2, 3);
MACE_GEMM_PART_CAL_8(2, 4, 5);
MACE_GEMM_PART_CAL_8(3, 6, 7);
MACE_GEMM_PART_CAL_8(4, 8, 9);
MACE_GEMM_PART_CAL_8(5, 10, 11);
MACE_GEMM_PART_CAL_8(6, 12, 13);
MACE_GEMM_PART_CAL_8(7, 14, 15);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
......@@ -180,11 +446,7 @@ inline void Gemm184(const float *a_ptr,
c0 = vld1q_f32(c_ptr);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
#endif
MACE_GEMM_PART_CAL_8(0, 0, 1);
vst1q_f32(c_ptr, c0);
#else
......@@ -220,13 +482,8 @@ inline void Gemm284(const float *a_ptr,
c0 = vld1q_f32(c_ptr);
c1 = vld1q_f32(c_ptr + 1 * stride_c);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
#endif
MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL_8(1, 2, 3);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
......@@ -266,15 +523,9 @@ inline void Gemm384(const float *a_ptr,
c1 = vld1q_f32(c_ptr + 1 * stride_c);
c2 = vld1q_f32(c_ptr + 2 * stride_c);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
#endif
MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL_8(1, 2, 3);
MACE_GEMM_PART_CAL_8(2, 4, 5);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
......@@ -318,17 +569,10 @@ inline void Gemm484(const float *a_ptr,
c2 = vld1q_f32(c_ptr + 2 * stride_c);
c3 = vld1q_f32(c_ptr + 3 * stride_c);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
#endif
MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL_8(1, 2, 3);
MACE_GEMM_PART_CAL_8(2, 4, 5);
MACE_GEMM_PART_CAL_8(3, 6, 7);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
......@@ -376,19 +620,11 @@ inline void Gemm584(const float *a_ptr,
c3 = vld1q_f32(c_ptr + 3 * stride_c);
c4 = vld1q_f32(c_ptr + 4 * stride_c);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
MACE_GEMM_PART_CAL(4, 8, 9);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
MACE_GEMM_PART_CAL(4, 8, 9);
#endif
MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL_8(1, 2, 3);
MACE_GEMM_PART_CAL_8(2, 4, 5);
MACE_GEMM_PART_CAL_8(3, 6, 7);
MACE_GEMM_PART_CAL_8(4, 8, 9);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
......@@ -440,21 +676,12 @@ inline void Gemm684(const float *a_ptr,
c4 = vld1q_f32(c_ptr + 4 * stride_c);
c5 = vld1q_f32(c_ptr + 5 * stride_c);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
MACE_GEMM_PART_CAL(4, 8, 9);
MACE_GEMM_PART_CAL(5, 10, 11);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
MACE_GEMM_PART_CAL(4, 8, 9);
MACE_GEMM_PART_CAL(5, 10, 11);
#endif
MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL_8(1, 2, 3);
MACE_GEMM_PART_CAL_8(2, 4, 5);
MACE_GEMM_PART_CAL_8(3, 6, 7);
MACE_GEMM_PART_CAL_8(4, 8, 9);
MACE_GEMM_PART_CAL_8(5, 10, 11);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
......@@ -511,23 +738,13 @@ inline void Gemm784(const float *a_ptr,
c5 = vld1q_f32(c_ptr + 5 * stride_c);
c6 = vld1q_f32(c_ptr + 6 * stride_c);
#if defined(__aarch64__)
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
MACE_GEMM_PART_CAL(4, 8, 9);
MACE_GEMM_PART_CAL(5, 10, 11);
MACE_GEMM_PART_CAL(6, 12, 13);
#else
MACE_GEMM_PART_CAL(0, 0, 1);
MACE_GEMM_PART_CAL(1, 2, 3);
MACE_GEMM_PART_CAL(2, 4, 5);
MACE_GEMM_PART_CAL(3, 6, 7);
MACE_GEMM_PART_CAL(4, 8, 9);
MACE_GEMM_PART_CAL(5, 10, 11);
MACE_GEMM_PART_CAL(6, 12, 13);
#endif
MACE_GEMM_PART_CAL_8(0, 0, 1);
MACE_GEMM_PART_CAL_8(1, 2, 3);
MACE_GEMM_PART_CAL_8(2, 4, 5);
MACE_GEMM_PART_CAL_8(3, 6, 7);
MACE_GEMM_PART_CAL_8(4, 8, 9);
MACE_GEMM_PART_CAL_8(5, 10, 11);
MACE_GEMM_PART_CAL_8(6, 12, 13);
vst1q_f32(c_ptr, c0);
vst1q_f32(c_ptr + 1 * stride_c, c1);
......@@ -589,9 +806,19 @@ inline void GemmTile(const float *A,
const index_t stride_c,
float *C) {
#if defined(MACE_ENABLE_NEON)
index_t h, w, k;
for (h = 0; h < height - 7; h += 8) {
for (k = 0; k < K - 7; k += 8) {
index_t h = 0;
index_t w = 0;
index_t k = 0;
#if defined(__aarch64__)
int reg_height_tile = 8;
int reg_K_tile = 8;
#else
int reg_height_tile = 6;
int reg_K_tile = 4;
#endif
for (h = 0; h < height - reg_height_tile + 1; h += reg_height_tile) {
for (k = 0; k < K - reg_K_tile + 1; k += reg_K_tile) {
const float *a_ptr = A + (h * stride_a + k);
#if defined(__aarch64__) && defined(__clang__)
int nw = width >> 2;
......@@ -833,43 +1060,180 @@ inline void GemmTile(const float *A,
w = (width >> 2) << 2;
}
#else // gcc || armv7a
#elif defined(__aarch64__) // gcc
for (w = 0; w + 3 < width; w += 4) {
const float *b_ptr = B + (k * stride_b + w);
float *c_ptr = C + (h * stride_c + w);
Gemm884(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr);
}
#endif // clang && armv8a
#else // armv7
int nw = width >> 2;
if (nw > 0) {
float32x4_t a0, a1, a2, a3, a4, a5;
a0 = vld1q_f32(a_ptr);
a1 = vld1q_f32(a_ptr + 1 * stride_a);
a2 = vld1q_f32(a_ptr + 2 * stride_a);
a3 = vld1q_f32(a_ptr + 3 * stride_a);
a4 = vld1q_f32(a_ptr + 4 * stride_a);
a5 = vld1q_f32(a_ptr + 5 * stride_a);
const float *b_ptr0 = B + k * stride_b;
const float *b_ptr1 = B + (k + 1) * stride_b;
const float *b_ptr2 = B + (k + 2) * stride_b;
const float *b_ptr3 = B + (k + 3) * stride_b;
float *c_ptr0 = C + h * stride_c;
float *c_ptr1 = C + (h + 1) * stride_c;
float *c_ptr2 = C + (h + 2) * stride_c;
float *c_ptr3 = C + (h + 3) * stride_c;
float *c_ptr4 = C + (h + 4) * stride_c;
float *c_ptr5 = C + (h + 5) * stride_c;
asm volatile(
"pld [%7, #128] \n"
"vld1.f32 {d12-d13}, [%7]! \n"
"pld [%1, #128] \n"
"vld1.f32 {d16-d17}, [%1] \n"
"pld [%2, #128] \n"
"vld1.f32 {d18-d19}, [%2] \n"
"0: \n"
"pld [%3, #128] \n"
"vld1.f32 {d20-d21}, [%3] \n"
"pld [%4, #128] \n"
"vld1.f32 {d22-d23}, [%4] \n"
"pld [%5, #128] \n"
"vld1.f32 {d24-d25}, [%5] \n"
"pld [%6, #128] \n"
"vld1.f32 {d26-d27}, [%6] \n"
"pld [%8, #128] \n"
"vld1.f32 {d14-d15}, [%8]! \n"
"vmla.f32 q8, q6, %e22[0] \n"
"vmla.f32 q9, q6, %e23[0] \n"
"vmla.f32 q10, q6, %e24[0] \n"
"vmla.f32 q11, q6, %e25[0] \n"
"vmla.f32 q12, q6, %e26[0] \n"
"vmla.f32 q13, q6, %e27[0] \n"
"pld [%9, #128] \n"
"vld1.f32 {d12-d13}, [%9]! \n"
"vmla.f32 q8, q7, %e22[1] \n"
"vmla.f32 q9, q7, %e23[1] \n"
"vmla.f32 q10, q7, %e24[1] \n"
"vmla.f32 q11, q7, %e25[1] \n"
"vmla.f32 q12, q7, %e26[1] \n"
"vmla.f32 q13, q7, %e27[1] \n"
"pld [%10, #128] \n"
"vld1.f32 {d14-d15}, [%10]! \n"
"vmla.f32 q8, q6, %f22[0] \n"
"vmla.f32 q9, q6, %f23[0] \n"
"vmla.f32 q10, q6, %f24[0] \n"
"vmla.f32 q11, q6, %f25[0] \n"
"vmla.f32 q12, q6, %f26[0] \n"
"vmla.f32 q13, q6, %f27[0] \n"
"vmla.f32 q8, q7, %f22[1] \n"
"vmla.f32 q9, q7, %f23[1] \n"
"vmla.f32 q10, q7, %f24[1] \n"
"vmla.f32 q11, q7, %f25[1] \n"
"vmla.f32 q12, q7, %f26[1] \n"
"vmla.f32 q13, q7, %f27[1] \n"
"vst1.f32 {d16-d17}, [%1]! \n"
"vst1.f32 {d18-d19}, [%2]! \n"
"pld [%7, #128] \n"
"vld1.f32 {d12-d13}, [%7]! \n"
"vst1.f32 {d20-d21}, [%3]! \n"
"vst1.f32 {d22-d23}, [%4]! \n"
"pld [%1, #128] \n"
"vld1.f32 {d16-d17}, [%1] \n"
"vst1.f32 {d24-d25}, [%5]! \n"
"vst1.f32 {d26-d27}, [%6]! \n"
"pld [%2, #128] \n"
"vld1.f32 {d18-d19}, [%2] \n"
"subs %0, #1 \n"
"bne 0b \n"
: "=r"(nw), // 0
"=r"(c_ptr0), // 1
"=r"(c_ptr1), // 2
"=r"(c_ptr2), // 3
"=r"(c_ptr3), // 4
"=r"(c_ptr4), // 5
"=r"(c_ptr5), // 6
"=r"(b_ptr0), // 7
"=r"(b_ptr1), // 8
"=r"(b_ptr2), // 9
"=r"(b_ptr3) // 10
: "0"(nw), // 11
"1"(c_ptr0), // 12
"2"(c_ptr1), // 13
"3"(c_ptr2), // 14
"4"(c_ptr3), // 15
"5"(c_ptr4), // 16
"6"(c_ptr5), // 17
"7"(b_ptr0), // 18
"8"(b_ptr1), // 19
"9"(b_ptr2), // 20
"10"(b_ptr3), // 21
"w"(a0), // 22
"w"(a1), // 23
"w"(a2), // 24
"w"(a3), // 25
"w"(a4), // 26
"w"(a5) // 27
: "cc", "memory", "q6", "q7", "q8", "q9", "q10", "q11", "q12",
"q13", "q14", "q15");
w = (width >> 2) << 2;
}
#endif
if (w < width) {
const float *b_ptr = B + (k * stride_b + w);
float *c_ptr = C + (h * stride_c + w);
GemmBlock(a_ptr, b_ptr, 8, 8, width - w, stride_a, stride_b, stride_c,
c_ptr);
GemmBlock(a_ptr, b_ptr, reg_height_tile, reg_K_tile, width - w,
stride_a, stride_b, stride_c, c_ptr);
}
}
if (k < K) {
const float *a_ptr = A + (h * stride_a + k);
const float *b_ptr = B + k * stride_b;
float *c_ptr = C + h * stride_c;
GemmBlock(a_ptr, b_ptr, 8, K - k, width, stride_a, stride_b, stride_c,
c_ptr);
GemmBlock(a_ptr, b_ptr, reg_height_tile, K - k, width, stride_a, stride_b,
stride_c, c_ptr);
}
}
if (h < height) {
index_t remain_h = height - h;
for (k = 0; k < K - 7; k += 8) {
for (k = 0; k < K - reg_K_tile; k += reg_K_tile) {
const float *a_ptr = A + (h * stride_a + k);
index_t w;
for (w = 0; w + 3 < width; w += 4) {
const float *b_ptr = B + (k * stride_b + w);
float *c_ptr = C + (h * stride_c + w);
#if defined(__aarch64__)
GemmX84(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr, remain_h);
#else
GemmX44(a_ptr, b_ptr, stride_a, stride_b, stride_c, c_ptr, remain_h);
#endif
}
if (w < width) {
const float *b_ptr = B + (k * stride_b + w);
float *c_ptr = C + (h * stride_c + w);
GemmBlock(a_ptr, b_ptr, remain_h, 8, width - w, stride_a, stride_b,
stride_c, c_ptr);
GemmBlock(a_ptr, b_ptr, remain_h, reg_K_tile, width - w, stride_a,
stride_b, stride_c, c_ptr);
}
}
if (k < K) {
......
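The restructured GemmTile now derives its register tile from the target: 8 rows by 8 depth on aarch64 (32 NEON registers) versus 6 rows by 4 depth on armv7 (16 registers), with GemmBlock absorbing every remainder. A runnable scalar sketch of that control flow, with ScalarBlock standing in for both GemmBlock and the NEON micro-kernels:

// Scalar sketch of the new GemmTile structure: tile height and depth by
// architecture-dependent register tiles; remainders fall through to a
// generic block routine.
static void ScalarBlock(const float *A, const float *B, float *C,
                        int rows, int depth, int cols,
                        int lda, int ldb, int ldc) {
  for (int i = 0; i < rows; ++i)
    for (int l = 0; l < depth; ++l)
      for (int j = 0; j < cols; ++j)
        C[i * ldc + j] += A[i * lda + l] * B[l * ldb + j];
}

void GemmTileSketch(const float *A, const float *B, float *C,
                    int height, int K, int width,
                    int lda, int ldb, int ldc) {
#if defined(__aarch64__)
  const int rh = 8, rk = 8;  // 8x8 register tile on armv8
#else
  const int rh = 6, rk = 4;  // 6x4 register tile on armv7
#endif
  int h = 0;
  for (; h + rh <= height; h += rh) {
    int k = 0;
    for (; k + rk <= K; k += rk)  // full tiles: NEON kernels in the real code
      ScalarBlock(A + h * lda + k, B + k * ldb, C + h * ldc,
                  rh, rk, width, lda, ldb, ldc);
    if (k < K)                    // depth remainder
      ScalarBlock(A + h * lda + k, B + k * ldb, C + h * ldc,
                  rh, K - k, width, lda, ldb, ldc);
  }
  if (h < height)                 // row remainder (GemmX84 / GemmX44)
    ScalarBlock(A + h * lda, B, C + h * ldc,
                height - h, K, width, lda, ldb, ldc);
}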
......@@ -38,7 +38,9 @@ void TestShapeOp(const std::vector<index_t> &input_shape) {
std::vector<int32_t> expected_input_shape(input_shape.begin(),
input_shape.end());
if (!expected_input_shape.empty()) {
net.AddInputFromArray<CPU, int32_t>("ExpectedOutput", {input_shape.size()},
net.AddInputFromArray<CPU, int32_t>("ExpectedOutput",
{static_cast<int32_t>(
input_shape.size())},
expected_input_shape);
} else {
net.AddInputFromArray<CPU, int32_t>("ExpectedOutput", {}, {0});
......
......@@ -37,11 +37,18 @@ void TestSlice(const std::vector<index_t> &input_shape,
const std::vector<float> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, float>("Input", input_shape, input);
net.AddInputFromArray<CPU, int32_t>("BeginIndices", {input_shape.size()},
net.AddInputFromArray<CPU, int32_t>("BeginIndices",
{static_cast<int32_t>(
input_shape.size())},
begin_indices);
net.AddInputFromArray<CPU, int32_t>("EndIndices", {input_shape.size()},
net.AddInputFromArray<CPU, int32_t>("EndIndices",
{static_cast<int32_t>(
input_shape.size())},
end_indices);
net.AddInputFromArray<CPU, int32_t>("Strides", {input_shape.size()}, strides);
net.AddInputFromArray<CPU, int32_t>("Strides",
{static_cast<int32_t>(
input_shape.size())},
strides);
OpDefBuilder("StridedSlice", "StridedSliceOpTest")
.Input("Input")
......
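The casts added in the two test hunks above fix the same build error: std::vector::size() returns size_t, and placing it inside a braced initializer list that builds a vector of a narrower signed element type is a narrowing conversion, ill-formed since C++11 and fatal under -Werror. A minimal reproduction (the element type is an assumption; the shape type used by OpsTestNet is not shown in this diff):

#include <cstdint>
#include <vector>

void Demo(const std::vector<int64_t> &input_shape) {
  // Ill-formed: size_t -> int32_t narrowing inside a braced init list.
  // std::vector<int32_t> dims{input_shape.size()};
  // OK: the conversion is made explicit, as in the diff above.
  std::vector<int32_t> dims{static_cast<int32_t>(input_shape.size())};
  (void)dims;
}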
......@@ -164,6 +164,7 @@ class TransformerRule(Enum):
TRANSFORM_BUFFER_IMAGE = 17
ADD_DEVICE_AND_DATA_TYPE = 18
SORT_BY_EXECUTION = 19
ADD_IN_OUT_TENSOR_INFO = 20
class ConverterInterface(object):
......@@ -210,6 +211,7 @@ class ConverterOption(object):
self._device = DeviceType.CPU.value
self._winograd_enabled = False
self._transformer_option = [
TransformerRule.ADD_IN_OUT_TENSOR_INFO,
TransformerRule.REMOVE_USELESS_RESHAPE_OP,
TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING,
......
......@@ -166,6 +166,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
self._option = option
self._mace_net_def = mace_pb2.NetDef()
ConverterUtil.set_filter_format(self._mace_net_def, FilterFormat.HWIO)
# import tensorflow graph
tf_graph_def = tf.GraphDef()
with tf.gfile.Open(src_model_file, 'rb') as f:
tf_graph_def.ParseFromString(f.read())
......
......@@ -55,6 +55,7 @@ class Transformer(base_converter.ConverterInterface):
def __init__(self, option, model):
# DO NOT reorder the following transformers' order
self._registered_transformers_order = [
TransformerRule.ADD_IN_OUT_TENSOR_INFO,
TransformerRule.REMOVE_USELESS_RESHAPE_OP,
TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING,
......@@ -78,6 +79,8 @@ class Transformer(base_converter.ConverterInterface):
TransformerRule.SORT_BY_EXECUTION,
]
self._registered_transformers = {
TransformerRule.ADD_IN_OUT_TENSOR_INFO:
self.add_in_out_tensor_info,
TransformerRule.REMOVE_USELESS_RESHAPE_OP:
self.remove_useless_reshape_op,
TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op,
......@@ -271,6 +274,21 @@ class Transformer(base_converter.ConverterInterface):
self._model.op.remove(op)
def add_in_out_tensor_info(self):
net = self._model
for input_node in self._option.input_nodes.values():
input_info = net.input_info.add()
input_info.name = input_node.name
input_info.dims.extend(input_node.shape)
for output_node in self._option.output_nodes.values():
output_info = net.output_info.add()
output_info.name = output_node.name
output_info.dims.extend(
self._producer[output_node.name].output_shape[0].dims)
return False
def remove_useless_reshape_op(self):
net = self._model
for op in net.op:
......
......@@ -50,3 +50,24 @@ cc_test(
"@gtest//:gtest_main",
],
)
cc_test(
name = "mace_api_exception_test",
testonly = 1,
srcs = ["mace_api_exception_test.cc"],
copts = ["-Werror", "-Wextra", "-Wno-missing-field-initializers"] +
if_openmp_enabled(["-fopenmp"]) +
if_neon_enabled(["-DMACE_ENABLE_NEON"]) +
if_android_armv7(["-mfpu=neon"]) +
if_android_armv7(["-mfloat-abi=softfp"]) +
if_android(["-DMACE_ENABLE_OPENCL"]) +
if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]),
linkopts = ["-fopenmp"],
linkstatic = 1,
deps = [
"//mace/ops:test",
"//mace/kernels:kernels",
"//mace/ops:ops",
"@gtest//:gtest_main",
],
)
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace test {
TEST(MaceAPIExceptionTest, WrongInputTest) {
std::vector<std::string> input_names;
std::vector<std::string> output_names;
input_names.push_back(MakeString("input", 0));
output_names.push_back(MakeString("output", 0));
const DeviceType device = DeviceType::GPU;
std::shared_ptr<NetDef> net_def(new NetDef());
for (size_t i = 0; i < input_names.size(); ++i) {
InputInfo *info = net_def->add_input_info();
info->set_name(input_names[i]);
}
MaceEngine engine(device);
ASSERT_DEATH(engine.Init(net_def.get(), {"input"}, output_names, nullptr),
"");
}
} // namespace test
} // namespace mace
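A note on the mechanism: LOG(FATAL) terminates the process, and ASSERT_DEATH runs its statement in a forked child, passing only if the child dies with stderr matching the given pattern (the empty string matches anything). A minimal self-contained death test:

#include <cstdlib>
#include "gtest/gtest.h"

// The statement must terminate the child process for the assertion to
// pass; "" is a regex that matches any output.
TEST(DeathTestSketch, AbortIsCaught) {
  ASSERT_DEATH(std::abort(), "");
}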
......@@ -298,6 +298,8 @@ void MaceRunFunc(const int in_out_size) {
{mem_map[input_names[i]]},
device,
net_def.get());
InputInfo *info = net_def->add_input_info();
info->set_name(input_names[i]);
}
BufferToImage<half>(filter_tensor_name, filter_tensor_img_name,
mace::kernels::CONV2D_FILTER, {}, device,
......@@ -315,6 +317,8 @@ void MaceRunFunc(const int in_out_size) {
mace::kernels::IN_OUT_CHANNEL,
device,
net_def.get());
OutputInfo *info = net_def->add_output_info();
info->set_name(output_names[i]);
}
const std::string file_path = "/data/local/tmp/mace";
......
......@@ -308,6 +308,8 @@ void MaceRun(const int in_out_size,
{mem_map[input_names[i]]},
device,
net_def.get());
InputInfo *info = net_def->add_input_info();
info->set_name(input_names[i]);
}
BufferToImage<half>(filter_tensor_name, filter_tensor_img_name,
mace::kernels::CONV2D_FILTER, {}, device,
......@@ -324,6 +326,8 @@ void MaceRun(const int in_out_size,
mace::kernels::IN_OUT_CHANNEL,
device,
net_def.get());
OutputInfo *info = net_def->add_output_info();
info->set_name(output_names[i]);
}
MaceEngine engine(device);
......@@ -376,5 +380,6 @@ TEST_F(MaceAPITest, GPUVariableInputShape) {
{{1, 16, 32, 16}, {1, 32, 64, 16}},
{16, 16, 3, 3});
}
} // namespace test
} // namespace mace
......@@ -16,6 +16,7 @@
#define MACE_UTILS_UTILS_H_
#include <fstream>
#include <map>
#include <sstream>
#include <string>
#include <utility>
......@@ -152,5 +153,14 @@ inline bool ReadBinaryFile(std::vector<unsigned char> *data,
return true;
}
template <typename T>
std::vector<std::string> MapKeys(const std::map<std::string, T> &data) {
std::vector<std::string> keys;
for (auto &kv : data) {
keys.push_back(kv.first);
}
return keys;
}
} // namespace mace
#endif // MACE_UTILS_UTILS_H_
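The new MapKeys helper exists so the fatal messages earlier in this commit can list every valid node name. A standalone usage sketch (the template is copied from the hunk above so the snippet compiles on its own):

#include <iostream>
#include <map>
#include <string>
#include <vector>

// Copy of the MapKeys template from the hunk above.
template <typename T>
std::vector<std::string> MapKeys(const std::map<std::string, T> &data) {
  std::vector<std::string> keys;
  for (auto &kv : data) {
    keys.push_back(kv.first);
  }
  return keys;
}

int main() {
  std::map<std::string, int> info_map{{"input0", 0}, {"input1", 1}};
  for (const std::string &key : MapKeys(info_map)) {
    std::cout << key << " ";  // prints: input0 input1
  }
  std::cout << std::endl;
  return 0;
}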