diff --git a/core/general-client/src/general_model.cpp b/core/general-client/src/general_model.cpp index b78df167bdc061e7d706284bf19b8b2d06bad1fa..5d4f732fc19b605cb2e130c61a2e3cc0b2edc13a 100644 --- a/core/general-client/src/general_model.cpp +++ b/core/general-client/src/general_model.cpp @@ -395,6 +395,19 @@ int PredictorClient::numpy_predict( tensor->set_elem_type(1); const int float_shape_size = float_shape[vec_idx].size(); switch (float_shape_size) { + case 4: { + auto float_array = float_feed[vec_idx].unchecked<4>(); + for (ssize_t i = 0; i < float_array.shape(0); i++) { + for (ssize_t j = 0; j < float_array.shape(1); j++) { + for (ssize_t k = 0; k < float_array.shape(2); k++) { + for (ssize_t l = 0; l < float_array.shape(3); l++) { + tensor->add_float_data(float_array(i, j, k, l)); + } + } + } + } + break; + } case 3: { auto float_array = float_feed[vec_idx].unchecked<3>(); for (ssize_t i = 0; i < float_array.shape(0); i++) { @@ -415,6 +428,13 @@ int PredictorClient::numpy_predict( } break; } + case 1: { + auto float_array = float_feed[vec_idx].unchecked<1>(); + for (ssize_t i = 0; i < float_array.shape(0); i++) { + tensor->add_float_data(float_array(i)); + } + break; + } } vec_idx++; } @@ -436,7 +456,7 @@ int PredictorClient::numpy_predict( const int int_shape_size = int_shape[vec_idx].size(); switch (int_shape_size) { case 4: { - auto int_array = int_feed[vec_idx].unchecked<3>(); + auto int_array = int_feed[vec_idx].unchecked<4>(); for (ssize_t i = 0; i < int_array.shape(0); i++) { for (ssize_t j = 0; j < int_array.shape(1); j++) { for (ssize_t k = 0; k < int_array.shape(2); k++) { @@ -469,7 +489,7 @@ int PredictorClient::numpy_predict( break; } case 1: { - auto int_array = int_feed[vec_idx].unchecked<2>(); + auto int_array = int_feed[vec_idx].unchecked<1>(); for (ssize_t i = 0; i < int_array.shape(0); i++) { tensor->add_float_data(int_array(i)); }