提交 983b264b 编写于 作者: H HexToString

add 2 lod and fix

上级 47b8040d
...@@ -290,12 +290,13 @@ struct Task { ...@@ -290,12 +290,13 @@ struct Task {
// which means error. // which means error.
if (feedvar_batch_size(feedvar_index) != 1 && total_feed_batch != 1) { if (feedvar_batch_size(feedvar_index) != 1 && total_feed_batch != 1) {
return false; return false;
} else if(feedvar_batch_size(feedvar_index) != 1 && total_feed_batch == 1){ } else if (feedvar_batch_size(feedvar_index) != 1 &&
for(int temp = 0; temp < feedvar_index; ++temp){ total_feed_batch == 1) {
for (int temp = 0; temp < feedvar_index; ++temp) {
set_feed_nobatch_index.insert(temp); set_feed_nobatch_index.insert(temp);
} }
total_feed_batch = feedvar_batch_size(feedvar_index); total_feed_batch = feedvar_batch_size(feedvar_index);
}else{ } else {
set_feed_nobatch_index.insert(feedvar_index); set_feed_nobatch_index.insert(feedvar_index);
} }
} }
...@@ -362,23 +363,23 @@ struct Task { ...@@ -362,23 +363,23 @@ struct Task {
size_t shape0_end = (*inVectorT_ptr)[feedvar_index].lod[0][end_batch]; size_t shape0_end = (*inVectorT_ptr)[feedvar_index].lod[0][end_batch];
feature_vector = {{shape0_start, shape0_end}, feed_lod_vector}; feature_vector = {{shape0_start, shape0_end}, feed_lod_vector};
} else if (lod_size == 2) { } else if (lod_size == 2) {
size_t 2level_lod_start_index = size_t level2_lod_start_index =
(*inVectorT_ptr)[feedvar_index].lod[0][start_batch]; (*inVectorT_ptr)[feedvar_index].lod[0][start_batch];
size_t 2level_lod_end_index = size_t level2_lod_end_index =
(*inVectorT_ptr)[feedvar_index].lod[0][end_batch]; (*inVectorT_ptr)[feedvar_index].lod[0][end_batch];
int 2level_lod_size = 2level_lod_end_index - 2level_lod_start_index; int level2_lod_size = level2_lod_end_index - level2_lod_start_index;
std::vector<size_t> feed_2level_lod_vector(2level_lod_size); std::vector<size_t> feed_2level_lod_vector(level2_lod_size);
for (size_t 2lod_index = 2level_lod_start_index + 1, vector_index = 0; for (size_t lod2_index = level2_lod_start_index + 1, vector_index = 0;
2lod_index < 2level_lod_end_index + 1; lod2_index < level2_lod_end_index + 1;
++vector_index, ++2lod_index) { ++vector_index, ++lod2_index) {
feed_2level_lod_vector[vector_index] = feed_2level_lod_vector[vector_index] =
(*inVectorT_ptr)[feedvar_index].lod[1][2lod_index] - (*inVectorT_ptr)[feedvar_index].lod[1][lod2_index] -
(*inVectorT_ptr)[feedvar_index].lod[1][2level_lod_start_index]; (*inVectorT_ptr)[feedvar_index].lod[1][level2_lod_start_index];
} }
size_t shape0_start = size_t shape0_start =
(*inVectorT_ptr)[feedvar_index].lod[1][2level_lod_start_index]; (*inVectorT_ptr)[feedvar_index].lod[1][level2_lod_start_index];
size_t shape0_end = size_t shape0_end =
(*inVectorT_ptr)[feedvar_index].lod[1][2level_lod_end_index]; (*inVectorT_ptr)[feedvar_index].lod[1][level2_lod_end_index];
feature_vector = {{shape0_start, shape0_end}, feature_vector = {{shape0_start, shape0_end},
feed_lod_vector, feed_lod_vector,
{}, {},
...@@ -399,7 +400,7 @@ struct Task { ...@@ -399,7 +400,7 @@ struct Task {
for (size_t index = 0; index < vector_fetch_lod_index.size(); ++index) { for (size_t index = 0; index < vector_fetch_lod_index.size(); ++index) {
size_t data_length = 0; size_t data_length = 0;
size_t lod_length = 0; size_t lod_length = 0;
size_t 2lod_length = 0; size_t lod2_length = 0;
size_t total_shape0 = 0; size_t total_shape0 = 0;
size_t once_lod0_length = 0; size_t once_lod0_length = 0;
int lod_size = 1; int lod_size = 1;
...@@ -410,13 +411,13 @@ struct Task { ...@@ -410,13 +411,13 @@ struct Task {
lod_size = outLodTensorVector[taskmeta_index][index].lod.size(); lod_size = outLodTensorVector[taskmeta_index][index].lod.size();
data_length += data_length +=
outLodTensorVector[taskmeta_index][index].data.length(); outLodTensorVector[taskmeta_index][index].data.length();
once_lod0_length = outLodTensorVector[taskmeta_index][index].lod[0].size(); once_lod0_length =
outLodTensorVector[taskmeta_index][index].lod[0].size();
lod_length += once_lod0_length; lod_length += once_lod0_length;
total_shape0 += outLodTensorVector[taskmeta_index][index].shape[0]; total_shape0 += outLodTensorVector[taskmeta_index][index].shape[0];
if (lod_size == 2) { if (lod_size == 2) {
2lod_length += lod2_length += outLodTensorVector[taskmeta_index][index]
outLodTensorVector[taskmeta_index][index].lod[0] .lod[0][once_lod0_length - 1];
[once_lod0_length - 1];
} }
} }
// 一次性扩容PaddleTensor中的data和lod // 一次性扩容PaddleTensor中的data和lod
...@@ -441,13 +442,13 @@ struct Task { ...@@ -441,13 +442,13 @@ struct Task {
} else if (fetchVarTensor.lod[1].size() <= 0) { } else if (fetchVarTensor.lod[1].size() <= 0) {
fetchVarTensor.lod[1].push_back(0); fetchVarTensor.lod[1].push_back(0);
} }
fetchVarTensor.lod[1].resize(2lod_length + 1, 0); fetchVarTensor.lod[1].resize(lod2_length + 1, 0);
} }
// //
size_t data_length_offset = 0; size_t data_length_offset = 0;
size_t lod_length_offset = 0; size_t lod_length_offset = 0;
size_t 2lod_length_offset = 0; size_t lod2_length_offset = 0;
size_t once_data_length = 0; size_t once_data_length = 0;
size_t once_lod_length = 0; size_t once_lod_length = 0;
size_t once_2lod_length = 0; size_t once_2lod_length = 0;
...@@ -473,15 +474,15 @@ struct Task { ...@@ -473,15 +474,15 @@ struct Task {
lod_length_offset++; lod_length_offset++;
} }
if (lod_size == 2) { if (lod_size == 2) {
size_t last_2lod_value = fetchVarTensor.lod[1][2lod_length_offset]; size_t last_2lod_value = fetchVarTensor.lod[1][lod2_length_offset];
once_2lod_length = once_2lod_length =
outLodTensorVector[taskmeta_index][index].lod[1].size(); outLodTensorVector[taskmeta_index][index].lod[1].size();
for (size_t once_index = 0; once_index < once_2lod_length; for (size_t once_index = 0; once_index < once_2lod_length;
++once_index) { ++once_index) {
fetchVarTensor.lod[1][2lod_length_offset + 1] = fetchVarTensor.lod[1][lod2_length_offset + 1] =
last_2lod_value + last_2lod_value +
outLodTensorVector[taskmeta_index][index].lod[1][once_index]; outLodTensorVector[taskmeta_index][index].lod[1][once_index];
2lod_length_offset++; lod2_length_offset ++;
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册