提交 b6592c45 编写于 作者: M MRXLT

bug fix

上级 adb10d1a
......@@ -395,6 +395,19 @@ int PredictorClient::numpy_predict(
tensor->set_elem_type(1);
const int float_shape_size = float_shape[vec_idx].size();
switch (float_shape_size) {
case 4: {
auto float_array = float_feed[vec_idx].unchecked<4>();
for (ssize_t i = 0; i < float_array.shape(0); i++) {
for (ssize_t j = 0; j < float_array.shape(1); j++) {
for (ssize_t k = 0; k < float_array.shape(2); k++) {
for (ssize_t l = 0; l < float_array.shape(3); l++) {
tensor->add_float_data(float_array(i, j, k, l));
}
}
}
}
break;
}
case 3: {
auto float_array = float_feed[vec_idx].unchecked<3>();
for (ssize_t i = 0; i < float_array.shape(0); i++) {
......@@ -415,6 +428,13 @@ int PredictorClient::numpy_predict(
}
break;
}
case 1: {
auto float_array = float_feed[vec_idx].unchecked<1>();
for (ssize_t i = 0; i < float_array.shape(0); i++) {
tensor->add_float_data(float_array(i));
}
break;
}
}
vec_idx++;
}
......@@ -436,7 +456,7 @@ int PredictorClient::numpy_predict(
const int int_shape_size = int_shape[vec_idx].size();
switch (int_shape_size) {
case 4: {
auto int_array = int_feed[vec_idx].unchecked<3>();
auto int_array = int_feed[vec_idx].unchecked<4>();
for (ssize_t i = 0; i < int_array.shape(0); i++) {
for (ssize_t j = 0; j < int_array.shape(1); j++) {
for (ssize_t k = 0; k < int_array.shape(2); k++) {
......@@ -469,7 +489,7 @@ int PredictorClient::numpy_predict(
break;
}
case 1: {
auto int_array = int_feed[vec_idx].unchecked<2>();
auto int_array = int_feed[vec_idx].unchecked<1>();
for (ssize_t i = 0; i < int_array.shape(0); i++) {
tensor->add_float_data(int_array(i));
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册