提交 ac0a3050 编写于 作者: M MRXLT

lod input support numpy array

上级 237cae71
......@@ -131,7 +131,7 @@ int GeneralReaderOp::inference() {
lod_tensor.dtype = paddle::PaddleDType::FLOAT32;
}
if (req->insts(0).tensor_array(i).shape(0) == -1) {
if (model_config->_is_lod_feed[i]) {
lod_tensor.lod.resize(1);
lod_tensor.lod[0].push_back(0);
VLOG(2) << "var[" << i << "] is lod_tensor";
......@@ -153,6 +153,7 @@ int GeneralReaderOp::inference() {
// specify the memory needed for output tensor_vector
for (int i = 0; i < var_num; ++i) {
if (out->at(i).lod.size() == 1) {
int tensor_size = 0;
for (int j = 0; j < batch_size; ++j) {
const Tensor &tensor = req->insts(j).tensor_array(i);
int data_len = 0;
......@@ -162,15 +163,19 @@ int GeneralReaderOp::inference() {
data_len = tensor.float_data_size();
}
VLOG(2) << "tensor size for var[" << i << "]: " << data_len;
tensor_size += data_len;
int cur_len = out->at(i).lod[0].back();
VLOG(2) << "current len: " << cur_len;
out->at(i).lod[0].push_back(cur_len + data_len);
VLOG(2) << "new len: " << cur_len + data_len;
out->at(i).lod[0].push_back(cur_len + tensor.shape(0));
VLOG(2) << "new len: " << cur_len + tensor.shape(0);
}
out->at(i).data.Resize(tensor_size * elem_size[i]);
out->at(i).shape = {out->at(i).lod[0].back()};
for (int j = 1; j < req->insts(0).tensor_array(i).shape_size(); ++j) {
out->at(i).shape.push_back(req->insts(0).tensor_array(i).shape(j));
}
out->at(i).data.Resize(out->at(i).lod[0].back() * elem_size[i]);
out->at(i).shape = {out->at(i).lod[0].back(), 1};
VLOG(2) << "var[" << i
<< "] is lod_tensor and len=" << out->at(i).lod[0].back();
} else {
......
......@@ -254,23 +254,16 @@ class Client(object):
for key in feed_i:
if key not in self.feed_names_:
raise ValueError("Wrong feed name: {}.".format(key))
if not isinstance(feed_i[key], np.ndarray):
self.shape_check(feed_i, key)
#if not isinstance(feed_i[key], np.ndarray):
self.shape_check(feed_i, key)
if self.feed_types_[key] == int_type:
if i == 0:
int_feed_names.append(key)
if isinstance(feed_i[key], np.ndarray):
if key in self.lod_tensor_set:
raise ValueError(
"LodTensor var can not be ndarray type.")
int_shape.append(list(feed_i[key].shape))
else:
int_shape.append(self.feed_shapes_[key])
if isinstance(feed_i[key], np.ndarray):
if key in self.lod_tensor_set:
raise ValueError(
"LodTensor var can not be ndarray type.")
#int_slot.append(np.reshape(feed_i[key], (-1)).tolist())
int_slot.append(feed_i[key])
self.has_numpy_input = True
else:
......@@ -280,17 +273,10 @@ class Client(object):
if i == 0:
float_feed_names.append(key)
if isinstance(feed_i[key], np.ndarray):
if key in self.lod_tensor_set:
raise ValueError(
"LodTensor var can not be ndarray type.")
float_shape.append(list(feed_i[key].shape))
else:
float_shape.append(self.feed_shapes_[key])
if isinstance(feed_i[key], np.ndarray):
if key in self.lod_tensor_set:
raise ValueError(
"LodTensor var can not be ndarray type.")
#float_slot.append(np.reshape(feed_i[key], (-1)).tolist())
float_slot.append(feed_i[key])
self.has_numpy_input = True
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册