提交 47b8040d 编写于 作者: H HexToString

add 2 lod fix

上级 2c6b21d5
......@@ -70,7 +70,6 @@ bool Task<InItemT, OutItemT>::task_fetch_create(BatchTasks<TaskT>& batchTask) {
// 每个lod型的fetchvar拷贝到对应的临时空间中
// 最后再计算临时空间的总量,合并fetchvar和lod
fetchvar_batch = 0;
} else {
// 普通fetchvar情况,此时该Task总的fetchvar_batch =
// 输入的总的batch_size()
......
......@@ -290,12 +290,13 @@ struct Task {
// which means error.
if (feedvar_batch_size(feedvar_index) != 1 && total_feed_batch != 1) {
return false;
} else {
// which means feedvar shape[0] = 1.
// shape[0] does not change with batch
} else if(feedvar_batch_size(feedvar_index) != 1 && total_feed_batch == 1){
for(int temp = 0; temp < feedvar_index; ++temp){
set_feed_nobatch_index.insert(temp);
}
total_feed_batch = feedvar_batch_size(feedvar_index);
}else{
set_feed_nobatch_index.insert(feedvar_index);
total_feed_batch =
std::max(feedvar_batch_size(feedvar_index), total_feed_batch);
}
}
// 将lod feedvar index加入到vector中。
......@@ -400,6 +401,7 @@ struct Task {
size_t lod_length = 0;
size_t 2lod_length = 0;
size_t total_shape0 = 0;
size_t once_lod0_length = 0;
int lod_size = 1;
size_t feedvar_index = vector_fetch_lod_index[index];
// 由于PaddleTensor的resize实现,是每次都会清空,所以必须先统计总长度。
......@@ -408,12 +410,13 @@ struct Task {
lod_size = outLodTensorVector[taskmeta_index][index].lod.size();
data_length +=
outLodTensorVector[taskmeta_index][index].data.length();
lod_length += outLodTensorVector[taskmeta_index][index].lod[0].size();
once_lod0_length = outLodTensorVector[taskmeta_index][index].lod[0].size();
lod_length += once_lod0_length;
total_shape0 += outLodTensorVector[taskmeta_index][index].shape[0];
if (lod_size == 2) {
2lod_length +=
outLodTensorVector[taskmeta_index][index].lod[0]
[lod_length - 1];
[once_lod0_length - 1];
}
}
// 一次性扩容PaddleTensor中的data和lod
......@@ -1162,18 +1165,15 @@ class BatchTasks {
// which means fetchvar shape[0] = 1.
// shape[0] does not change with batch
set_fetch_nobatch_index.insert(fetchvar_index);
_total_fetch_batch =
std::max(fetchvar_batch_size(fetchvar_index), _total_fetch_batch);
} else if (_total_fetch_batch == 1) {
// 这时意味着,之前的fetchvar shape[0] 全部都= 1
// 当前的fetchvar shape[0] > 1
// 所以,之前的都是no_batch
for (size_t temp_index = fetchvar_index - 1; temp_index >= 0;
for (size_t temp_index = 0; temp_index < fetchvar_index;
--temp_index) {
set_fetch_nobatch_index.insert(fetchvar_index);
}
_total_fetch_batch =
std::max(fetchvar_batch_size(fetchvar_index), _total_fetch_batch);
_total_fetch_batch = fetchvar_batch_size(fetchvar_index);
} else {
// which means error.
return false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册