Commit 47b8040d authored by HexToString

add 2 lod fix

Parent 2c6b21d5
@@ -70,7 +70,6 @@ bool Task<InItemT, OutItemT>::task_fetch_create(BatchTasks<TaskT>& batchTask) {
       // Copy each lod-type fetchvar into its own temporary buffer;
       // afterwards, compute the total size of the temporary space and
       // merge the fetchvar data together with its lod.
       fetchvar_batch = 0;
     } else {
       // Ordinary fetchvar case: this Task's total fetchvar_batch equals
       // the input's total batch_size().
......
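
For orientation, the strategy those comments describe is a two-pass merge: first sum the sizes of all per-fetchvar temporary buffers, then size the destination once and copy every segment in, because a PaddleTensor resize discards whatever was stored before. A minimal self-contained sketch of that pattern, with plain std::vector standing in for the tensor buffers (all names here are illustrative, not Serving APIs):

#include <cstddef>
#include <cstring>
#include <iostream>
#include <vector>

int main() {
  // Temporary per-segment buffers, as produced by the copy step above.
  std::vector<std::vector<float>> segments = {{1, 2}, {3, 4, 5}, {6}};

  // Pass 1: compute the total length before touching the destination.
  std::size_t total = 0;
  for (const auto& s : segments) total += s.size();

  // Pass 2: allocate once, then append each segment at its offset.
  std::vector<float> merged(total);
  std::size_t offset = 0;
  for (const auto& s : segments) {
    std::memcpy(merged.data() + offset, s.data(), s.size() * sizeof(float));
    offset += s.size();
  }
  std::cout << merged.size() << std::endl;  // 6
  return 0;
}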
@@ -290,12 +290,13 @@ struct Task {
     // which means error.
     if (feedvar_batch_size(feedvar_index) != 1 && total_feed_batch != 1) {
       return false;
-    } else {
-      // which means feedvar shape[0] = 1.
-      // shape[0] does not change with batch
+    } else if (feedvar_batch_size(feedvar_index) != 1 && total_feed_batch == 1) {
+      for (int temp = 0; temp < feedvar_index; ++temp) {
+        set_feed_nobatch_index.insert(temp);
+      }
+      total_feed_batch = feedvar_batch_size(feedvar_index);
+    } else {
       set_feed_nobatch_index.insert(feedvar_index);
-      total_feed_batch =
-          std::max(feedvar_batch_size(feedvar_index), total_feed_batch);
     }
   }
   // Add the lod feedvar index into the vector.
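
Read as a whole, the reworked branches implement this rule: a feedvar with shape[0] == 1 is "no-batch"; the first feedvar with shape[0] > 1 fixes the Task's batch size and back-fills every earlier index into the no-batch set; any later feedvar that also claims a batch dimension is rejected. A self-contained sketch of that rule as I read the diff (check_feed_batch and its driver are invented for illustration):

#include <cstddef>
#include <iostream>
#include <set>
#include <vector>

// Mirrors the branch structure of the hunk above on plain containers.
bool check_feed_batch(const std::vector<std::size_t>& batch_sizes,
                      std::set<std::size_t>* no_batch_index,
                      std::size_t* total_feed_batch) {
  *total_feed_batch = 1;
  for (std::size_t i = 0; i < batch_sizes.size(); ++i) {
    if (batch_sizes[i] != 1 && *total_feed_batch != 1) {
      return false;  // two feedvars both carry a batch dimension: error
    } else if (batch_sizes[i] != 1 && *total_feed_batch == 1) {
      // First batched feedvar: everything before it was no-batch.
      for (std::size_t prev = 0; prev < i; ++prev) {
        no_batch_index->insert(prev);
      }
      *total_feed_batch = batch_sizes[i];
    } else {
      // shape[0] == 1 does not change with batch.
      no_batch_index->insert(i);
    }
  }
  return true;
}

int main() {
  std::set<std::size_t> no_batch;
  std::size_t total = 0;
  // feedvar 1 carries batch 4; feedvars 0 and 2 have shape[0] == 1.
  bool ok = check_feed_batch({1, 4, 1}, &no_batch, &total);
  std::cout << ok << " total=" << total << " no_batch=" << no_batch.size()
            << std::endl;  // 1 total=4 no_batch=2
  return 0;
}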
@@ -400,6 +401,7 @@ struct Task {
       size_t lod_length = 0;
       size_t lod_2_length = 0;
       size_t total_shape0 = 0;
+      size_t once_lod0_length = 0;
       int lod_size = 1;
       size_t feedvar_index = vector_fetch_lod_index[index];
       // PaddleTensor's resize implementation clears old contents on
       // every call, so the total length must be counted up front.
@@ -408,12 +410,13 @@ struct Task {
         lod_size = outLodTensorVector[taskmeta_index][index].lod.size();
         data_length +=
             outLodTensorVector[taskmeta_index][index].data.length();
-        lod_length += outLodTensorVector[taskmeta_index][index].lod[0].size();
+        once_lod0_length = outLodTensorVector[taskmeta_index][index].lod[0].size();
+        lod_length += once_lod0_length;
         total_shape0 += outLodTensorVector[taskmeta_index][index].shape[0];
         if (lod_size == 2) {
           lod_2_length +=
               outLodTensorVector[taskmeta_index][index].lod[0]
-                  [lod_length - 1];
+                  [once_lod0_length - 1];
         }
       }
       // Resize the data and lod inside the PaddleTensor in a single pass.
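
What the switch to [once_lod0_length - 1] fixes, as I read it: lod_length accumulates level-0 lod lengths across all taskmeta segments, so from the second segment onward lod[0][lod_length - 1] indexes past the current segment's own level-0 lod; the last element of the current lod[0] is what counts that segment's level-1 entries. A toy illustration with invented lod values (only the nested-vector layout mirrors a PaddleTensor lod):

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Two taskmeta segments, each with a 2-level lod. For a 2-level lod,
  // lod[0] partitions lod[1], so lod[0].back() is the number of
  // level-1 intervals that segment contributes.
  std::vector<std::vector<std::vector<std::size_t>>> segments = {
      {{0, 2, 5}, {0, 1, 3, 4, 6, 9}},  // lod[0].back() == 5
      {{0, 3}, {0, 2, 4, 7}},           // lod[0].back() == 3
  };

  std::size_t lod_length = 0;    // accumulated level-0 length
  std::size_t lod_2_length = 0;  // accumulated level-1 length
  for (const auto& lod : segments) {
    std::size_t once_lod0_length = lod[0].size();
    lod_length += once_lod0_length;
    // The old lod[0][lod_length - 1] would read lod[0][4] on the second
    // segment here, past the end of a 2-element vector.
    lod_2_length += lod[0][once_lod0_length - 1];
  }
  std::cout << lod_length << " " << lod_2_length << std::endl;  // 5 8
  return 0;
}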
@@ -1162,18 +1165,15 @@ class BatchTasks {
       // which means fetchvar shape[0] = 1.
       // shape[0] does not change with batch
       set_fetch_nobatch_index.insert(fetchvar_index);
-      _total_fetch_batch =
-          std::max(fetchvar_batch_size(fetchvar_index), _total_fetch_batch);
     } else if (_total_fetch_batch == 1) {
       // This means every previous fetchvar had shape[0] == 1
       // while the current fetchvar has shape[0] > 1,
       // so all of the previous ones are no_batch.
-      for (size_t temp_index = fetchvar_index - 1; temp_index >= 0;
+      for (size_t temp_index = 0; temp_index < fetchvar_index;
            --temp_index) {
         set_fetch_nobatch_index.insert(fetchvar_index);
       }
-      _total_fetch_batch =
-          std::max(fetchvar_batch_size(fetchvar_index), _total_fetch_batch);
+      _total_fetch_batch = fetchvar_batch_size(fetchvar_index);
     } else {
       // which means error.
       return false;
......
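
Two closing observations, inferred from the diff rather than stated in it. First, both deleted std::max calls were redundant under their branch invariants: in the no-batch branch fetchvar_batch_size(fetchvar_index) == 1, so the max never changes _total_fetch_batch, and in the backfill branch _total_fetch_batch == 1, so the max equals the direct assignment the commit substitutes. Second, as transcribed, the new fetch-side loop still decrements temp_index and inserts fetchvar_index, whereas the feed-side fix above increments and inserts the loop variable; the fetch side presumably intends the same. A two-assert check of the std::max identities:

#include <algorithm>
#include <cassert>
#include <cstddef>

int main() {
  // no-batch branch: the fetchvar contributes batch 1, so the old
  // std::max(1, total) was a no-op and can simply be dropped.
  std::size_t total_fetch_batch = 7;
  std::size_t nobatch_size = 1;
  assert(std::max(nobatch_size, total_fetch_batch) == total_fetch_batch);

  // backfill branch: the running total is still 1, so the old
  // std::max(batch, 1) equals the new direct assignment.
  std::size_t running_total = 1;
  std::size_t batch_size = 4;
  assert(std::max(batch_size, running_total) == batch_size);
  return 0;
}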