提交 b6592c45 编写于 作者: M MRXLT

bug fix

上级 adb10d1a
...@@ -395,6 +395,19 @@ int PredictorClient::numpy_predict( ...@@ -395,6 +395,19 @@ int PredictorClient::numpy_predict(
tensor->set_elem_type(1); tensor->set_elem_type(1);
const int float_shape_size = float_shape[vec_idx].size(); const int float_shape_size = float_shape[vec_idx].size();
switch (float_shape_size) { switch (float_shape_size) {
case 4: {
auto float_array = float_feed[vec_idx].unchecked<4>();
for (ssize_t i = 0; i < float_array.shape(0); i++) {
for (ssize_t j = 0; j < float_array.shape(1); j++) {
for (ssize_t k = 0; k < float_array.shape(2); k++) {
for (ssize_t l = 0; l < float_array.shape(3); l++) {
tensor->add_float_data(float_array(i, j, k, l));
}
}
}
}
break;
}
case 3: { case 3: {
auto float_array = float_feed[vec_idx].unchecked<3>(); auto float_array = float_feed[vec_idx].unchecked<3>();
for (ssize_t i = 0; i < float_array.shape(0); i++) { for (ssize_t i = 0; i < float_array.shape(0); i++) {
...@@ -415,6 +428,13 @@ int PredictorClient::numpy_predict( ...@@ -415,6 +428,13 @@ int PredictorClient::numpy_predict(
} }
break; break;
} }
case 1: {
auto float_array = float_feed[vec_idx].unchecked<1>();
for (ssize_t i = 0; i < float_array.shape(0); i++) {
tensor->add_float_data(float_array(i));
}
break;
}
} }
vec_idx++; vec_idx++;
} }
...@@ -436,7 +456,7 @@ int PredictorClient::numpy_predict( ...@@ -436,7 +456,7 @@ int PredictorClient::numpy_predict(
const int int_shape_size = int_shape[vec_idx].size(); const int int_shape_size = int_shape[vec_idx].size();
switch (int_shape_size) { switch (int_shape_size) {
case 4: { case 4: {
auto int_array = int_feed[vec_idx].unchecked<3>(); auto int_array = int_feed[vec_idx].unchecked<4>();
for (ssize_t i = 0; i < int_array.shape(0); i++) { for (ssize_t i = 0; i < int_array.shape(0); i++) {
for (ssize_t j = 0; j < int_array.shape(1); j++) { for (ssize_t j = 0; j < int_array.shape(1); j++) {
for (ssize_t k = 0; k < int_array.shape(2); k++) { for (ssize_t k = 0; k < int_array.shape(2); k++) {
...@@ -469,7 +489,7 @@ int PredictorClient::numpy_predict( ...@@ -469,7 +489,7 @@ int PredictorClient::numpy_predict(
break; break;
} }
case 1: { case 1: {
auto int_array = int_feed[vec_idx].unchecked<2>(); auto int_array = int_feed[vec_idx].unchecked<1>();
for (ssize_t i = 0; i < int_array.shape(0); i++) { for (ssize_t i = 0; i < int_array.shape(0); i++) {
tensor->add_float_data(int_array(i)); tensor->add_float_data(int_array(i));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册