Commit e808b498 authored by MRXLT

lod input support numpy array

Parent 80e4ff42
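For context: this commit lets a variable-length (LoD) feed variable be passed to the serving client directly as a numpy array instead of a nested Python list. A minimal usage sketch, assuming the standard paddle_serving_client API; the feed name "words", the fetch name "prediction", the config path, and the endpoint are illustrative only:

    import numpy as np
    from paddle_serving_client import Client

    client = Client()
    client.load_client_config("serving_client_conf.prototxt")  # illustrative path
    client.connect(["127.0.0.1:9292"])                         # illustrative endpoint

    # "words" stands in for a LoD (variable-length) feed var; before this commit an
    # ndarray here was rejected with "LodTensor var can not be ndarray type."
    words = np.array([8, 233, 52, 601], dtype="int64")
    fetch_map = client.predict(feed={"words": words}, fetch=["prediction"])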
@@ -131,7 +131,7 @@ int GeneralReaderOp::inference() {
       lod_tensor.dtype = paddle::PaddleDType::FLOAT32;
     }
-    if (req->insts(0).tensor_array(i).shape(0) == -1) {
+    if (model_config->_is_lod_feed[i]) {
       lod_tensor.lod.resize(1);
       lod_tensor.lod[0].push_back(0);
       VLOG(2) << "var[" << i << "] is lod_tensor";
@@ -153,6 +153,7 @@ int GeneralReaderOp::inference() {
   // specify the memory needed for output tensor_vector
   for (int i = 0; i < var_num; ++i) {
     if (out->at(i).lod.size() == 1) {
+      int tensor_size = 0;
       for (int j = 0; j < batch_size; ++j) {
         const Tensor &tensor = req->insts(j).tensor_array(i);
         int data_len = 0;
@@ -162,15 +163,19 @@ int GeneralReaderOp::inference() {
           data_len = tensor.float_data_size();
         }
         VLOG(2) << "tensor size for var[" << i << "]: " << data_len;
+        tensor_size += data_len;
         int cur_len = out->at(i).lod[0].back();
         VLOG(2) << "current len: " << cur_len;
-        out->at(i).lod[0].push_back(cur_len + data_len);
-        VLOG(2) << "new len: " << cur_len + data_len;
+        out->at(i).lod[0].push_back(cur_len + tensor.shape(0));
+        VLOG(2) << "new len: " << cur_len + tensor.shape(0);
+      }
+      out->at(i).data.Resize(tensor_size * elem_size[i]);
+      out->at(i).shape = {out->at(i).lod[0].back()};
+      for (int j = 1; j < req->insts(0).tensor_array(i).shape_size(); ++j) {
+        out->at(i).shape.push_back(req->insts(0).tensor_array(i).shape(j));
       }
-      out->at(i).data.Resize(out->at(i).lod[0].back() * elem_size[i]);
-      out->at(i).shape = {out->at(i).lod[0].back(), 1};
       VLOG(2) << "var[" << i
               << "] is lod_tensor and len=" << out->at(i).lod[0].back();
     } else {
...
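The server-side change above builds the LoD offsets from each sample's leading dimension (tensor.shape(0)), sizes the output buffer from the summed element counts, and keeps the trailing dimensions of the request shape instead of forcing a {len, 1} shape. A small Python sketch of that bookkeeping, using made-up per-sample shapes:

    # hypothetical per-sample shapes of one LoD feed var in a batch of three
    sample_shapes = [(3, 5), (7, 5), (2, 5)]

    lod = [0]        # offsets over the concatenated leading dimension
    tensor_size = 0  # total element count, drives data.Resize(tensor_size * elem_size)
    for shape in sample_shapes:
        data_len = 1
        for d in shape:
            data_len *= d               # element count of this sample
        tensor_size += data_len
        lod.append(lod[-1] + shape[0])  # offsets grow by the leading dim, not by data_len

    # final shape: concatenated leading dim plus the trailing dims of the request shape
    out_shape = [lod[-1]] + list(sample_shapes[0][1:])
    print(lod, tensor_size, out_shape)  # [0, 3, 10, 12] 60 [12, 5]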
@@ -254,23 +254,16 @@ class Client(object):
             for key in feed_i:
                 if key not in self.feed_names_:
                     raise ValueError("Wrong feed name: {}.".format(key))
-                if not isinstance(feed_i[key], np.ndarray):
-                    self.shape_check(feed_i, key)
+                #if not isinstance(feed_i[key], np.ndarray):
+                self.shape_check(feed_i, key)
                 if self.feed_types_[key] == int_type:
                     if i == 0:
                         int_feed_names.append(key)
                         if isinstance(feed_i[key], np.ndarray):
-                            if key in self.lod_tensor_set:
-                                raise ValueError(
-                                    "LodTensor var can not be ndarray type.")
                             int_shape.append(list(feed_i[key].shape))
                         else:
                             int_shape.append(self.feed_shapes_[key])
                     if isinstance(feed_i[key], np.ndarray):
-                        if key in self.lod_tensor_set:
-                            raise ValueError(
-                                "LodTensor var can not be ndarray type.")
-                        #int_slot.append(np.reshape(feed_i[key], (-1)).tolist())
                         int_slot.append(feed_i[key])
                         self.has_numpy_input = True
                     else:
@@ -280,17 +273,10 @@ class Client(object):
                     if i == 0:
                         float_feed_names.append(key)
                         if isinstance(feed_i[key], np.ndarray):
-                            if key in self.lod_tensor_set:
-                                raise ValueError(
-                                    "LodTensor var can not be ndarray type.")
                             float_shape.append(list(feed_i[key].shape))
                         else:
                             float_shape.append(self.feed_shapes_[key])
                     if isinstance(feed_i[key], np.ndarray):
-                        if key in self.lod_tensor_set:
-                            raise ValueError(
-                                "LodTensor var can not be ndarray type.")
-                        #float_slot.append(np.reshape(feed_i[key], (-1)).tolist())
                         float_slot.append(feed_i[key])
                         self.has_numpy_input = True
                     else:
...
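On the client side, the LodTensor-vs-ndarray restriction is dropped and the shape recorded for an ndarray feed now comes from the array itself, falling back to the configured feed shape only for plain lists. A rough sketch of that branch, with a hypothetical helper standing in for the client's internal bookkeeping:

    import numpy as np

    def collect_shape_and_slot(value, configured_shape):
        # ndarray inputs carry their own shape and are appended to the slot as-is
        if isinstance(value, np.ndarray):
            return list(value.shape), value
        # plain-list inputs keep the shape declared in the client config
        return list(configured_shape), value

    shape, slot = collect_shape_and_slot(np.ones((4, 2), dtype="float32"), [2])
    print(shape)  # [4, 2]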