From 5c470a5c6db228de8f1b50e164dbd5f54b81841e Mon Sep 17 00:00:00 2001 From: MRXLT Date: Mon, 27 Apr 2020 20:02:47 +0800 Subject: [PATCH] bug fix --- core/general-client/src/general_model.cpp | 24 +++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/core/general-client/src/general_model.cpp b/core/general-client/src/general_model.cpp index b78df167..5d4f732f 100644 --- a/core/general-client/src/general_model.cpp +++ b/core/general-client/src/general_model.cpp @@ -395,6 +395,19 @@ int PredictorClient::numpy_predict( tensor->set_elem_type(1); const int float_shape_size = float_shape[vec_idx].size(); switch (float_shape_size) { + case 4: { + auto float_array = float_feed[vec_idx].unchecked<4>(); + for (ssize_t i = 0; i < float_array.shape(0); i++) { + for (ssize_t j = 0; j < float_array.shape(1); j++) { + for (ssize_t k = 0; k < float_array.shape(2); k++) { + for (ssize_t l = 0; l < float_array.shape(3); l++) { + tensor->add_float_data(float_array(i, j, k, l)); + } + } + } + } + break; + } case 3: { auto float_array = float_feed[vec_idx].unchecked<3>(); for (ssize_t i = 0; i < float_array.shape(0); i++) { @@ -415,6 +428,13 @@ int PredictorClient::numpy_predict( } break; } + case 1: { + auto float_array = float_feed[vec_idx].unchecked<1>(); + for (ssize_t i = 0; i < float_array.shape(0); i++) { + tensor->add_float_data(float_array(i)); + } + break; + } } vec_idx++; } @@ -436,7 +456,7 @@ int PredictorClient::numpy_predict( const int int_shape_size = int_shape[vec_idx].size(); switch (int_shape_size) { case 4: { - auto int_array = int_feed[vec_idx].unchecked<3>(); + auto int_array = int_feed[vec_idx].unchecked<4>(); for (ssize_t i = 0; i < int_array.shape(0); i++) { for (ssize_t j = 0; j < int_array.shape(1); j++) { for (ssize_t k = 0; k < int_array.shape(2); k++) { @@ -469,7 +489,7 @@ int PredictorClient::numpy_predict( break; } case 1: { - auto int_array = int_feed[vec_idx].unchecked<2>(); + auto int_array = int_feed[vec_idx].unchecked<1>(); for (ssize_t i = 0; i < int_array.shape(0); i++) { tensor->add_float_data(int_array(i)); } -- GitLab