Commit a4dddae2 authored by HexToString

fix little bug

Parent 1a355b65
@@ -31,32 +31,32 @@ namespace im {
 namespace bsf {
 template <typename InItemT, typename OutItemT>
-bool Task<InItemT, OutItemT>::task_fetch_init(BatchTasks<TaskT>& baskTask) {
+bool Task<InItemT, OutItemT>::task_fetch_init(BatchTasks<TaskT>& batchTask) {
   // Double-checked locking, to reduce the granularity of locking.
   if (!fetch_init) {
     if (taskmeta_num > 1) {
       // When the task is split into multiple taskmetas, locking is required.
       AutoMutex lock(task_mut);
-      task_fetch_create(baskTask);
+      task_fetch_create(batchTask);
     } else {
       // When the task has only one taskmeta, no locking is needed.
-      task_fetch_create(baskTask);
+      task_fetch_create(batchTask);
     }
   }
   return true;
 }
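Note: `task_fetch_init`/`task_fetch_create` above follow the classic double-checked locking pattern: the unlocked `fetch_init` test skips the mutex on the already-initialized fast path, and `task_fetch_create` re-checks `fetch_init` under the lock. A minimal standalone sketch of the same idea using standard C++ primitives; the names (`FetchState`, `ensure_init`) are illustrative and not part of the bsf code:

    #include <atomic>
    #include <mutex>

    class FetchState {
     public:
      void ensure_init() {
        // First check, no lock taken: cheap once initialization has happened.
        if (!init_done_.load(std::memory_order_acquire)) {
          std::lock_guard<std::mutex> lock(mu_);
          // Second check, under the lock: exactly one thread initializes.
          if (!init_done_.load(std::memory_order_relaxed)) {
            // ... allocate fetch buffers here ...
            init_done_.store(true, std::memory_order_release);
          }
        }
      }

     private:
      std::atomic<bool> init_done_{false};
      std::mutex mu_;
    };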
 template <typename InItemT, typename OutItemT>
-bool Task<InItemT, OutItemT>::task_fetch_create(BatchTasks<TaskT>& baskTask) {
+bool Task<InItemT, OutItemT>::task_fetch_create(BatchTasks<TaskT>& batchTask) {
   if (!fetch_init) {
-    vector_fetch_lod_index = baskTask.vector_fetch_lod_index;
-    set_fetch_nobatch_index = baskTask.set_fetch_nobatch_index;
+    vector_fetch_lod_index = batchTask.vector_fetch_lod_index;
+    set_fetch_nobatch_index = batchTask.set_fetch_nobatch_index;
     OutVectorT taskMetaOutLodTensor;
-    size_t fetchvar_num = baskTask._batch_out.size();
+    size_t fetchvar_num = batchTask._batch_out.size();
     for (size_t fetchvar_index = 0; fetchvar_index < fetchvar_num;
          ++fetchvar_index) {
       size_t fetchvar_bytesize_index =
-          baskTask.fetchvar_bytesize(fetchvar_index);
+          batchTask.fetchvar_bytesize(fetchvar_index);
       size_t fetchvar_batch = 0;
       // 1. The nobatch fetchvar case.
       if (set_fetch_nobatch_index.size() > 0 &&
@@ -79,14 +79,14 @@ bool Task<InItemT, OutItemT>::task_fetch_create(BatchTasks<TaskT>& baskTask) {
       fetchvar_batch = batch_size();
     }
     paddle::PaddleTensor tensor_out;
-    tensor_out.name = baskTask._batch_out[fetchvar_index].name;
+    tensor_out.name = batchTask._batch_out[fetchvar_index].name;
     tensor_out.dtype =
-        paddle::PaddleDType(baskTask._batch_out[fetchvar_index].dtype);
-    tensor_out.shape = baskTask._batch_out[fetchvar_index].shape;
+        paddle::PaddleDType(batchTask._batch_out[fetchvar_index].dtype);
+    tensor_out.shape = batchTask._batch_out[fetchvar_index].shape;
     tensor_out.shape[0] = fetchvar_batch;
     if (fetchvar_batch != 0) {
       // At this point, lod is empty.
-      tensor_out.lod = baskTask._batch_out[fetchvar_index].lod;
+      tensor_out.lod = batchTask._batch_out[fetchvar_index].lod;
       // resize all batch memory at one time
       size_t databuf_size = fetchvar_batch * fetchvar_bytesize_index;
       tensor_out.data.Resize(databuf_size);
......
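Note on the Resize above: a fetch tensor's buffer size is the batch count times the per-sample byte size, where the per-sample size is the product of the non-batch dims times the dtype width (`fetchvar_bytesize` plays that role in the real code). A hedged sketch of the arithmetic, with assumed names:

    #include <cstddef>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Bytes needed for one sample: product of the dims after shape[0],
    // times the element width of the dtype.
    size_t bytes_per_sample(const std::vector<int>& shape, size_t dtype_size) {
      size_t elems = std::accumulate(shape.begin() + 1, shape.end(),
                                     static_cast<size_t>(1),
                                     std::multiplies<size_t>());
      return elems * dtype_size;
    }

    // e.g. shape {8, 3, 224, 224} with float32 elements:
    //   databuf_size = 8 * bytes_per_sample(shape, 4) = 8 * 3*224*224*4 bytes.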
@@ -319,6 +319,7 @@ struct Task {
     for (size_t index = 0; index < vector_fetch_lod_index.size(); ++index) {
       size_t data_length = 0;
       size_t lod_length = 0;
+      size_t total_shape0 = 0;
       size_t feedvar_index = vector_fetch_lod_index[index];
       // PaddleTensor's resize implementation clears the buffer each time,
       // so the total length must be accumulated first.
       for (size_t taskmeta_index = 0; taskmeta_index < taskmeta_num;
@@ -326,6 +327,7 @@ struct Task {
         data_length +=
             outLodTensorVector[taskmeta_index][index].data.length();
         lod_length += outLodTensorVector[taskmeta_index][index].lod[0].size();
+        total_shape0 += outLodTensorVector[taskmeta_index][index].shape[0];
       }
       // Grow the data and lod of the PaddleTensor in one pass.
       paddle::PaddleTensor& fetchVarTensor = (*outVectorT_ptr)[feedvar_index];
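Note: the added `total_shape0` accumulator follows the same two-pass shape as the existing `data_length`/`lod_length` sums. Because the resize clears the buffer, the combine step must first total up the per-taskmeta pieces, resize once, and then copy each piece at its running offset. A simplified sketch of that merge over raw byte buffers (illustrative, not the real bsf types):

    #include <cstring>
    #include <vector>

    std::vector<char> combine_parts(const std::vector<std::vector<char>>& parts) {
      // Pass 1: total length, so the destination is sized exactly once.
      size_t total = 0;
      for (const auto& p : parts) total += p.size();
      std::vector<char> out(total);
      // Pass 2: copy each part at its running offset.
      size_t offset = 0;
      for (const auto& p : parts) {
        if (!p.empty()) std::memcpy(out.data() + offset, p.data(), p.size());
        offset += p.size();
      }
      return out;
    }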
@@ -368,8 +370,8 @@ struct Task {
     return true;
   }
-  bool task_fetch_init(BatchTasks<TaskT>& baskTask);
-  bool task_fetch_create(BatchTasks<TaskT>& baskTask);
+  bool task_fetch_init(BatchTasks<TaskT>& batchTask);
+  bool task_fetch_create(BatchTasks<TaskT>& batchTask);
 };
 // `Several Task` or `part of batch in Task` can be a TaskMeta.
@@ -788,6 +790,7 @@ class BatchTasks {
           paddle::PaddleTensor& fetchVarTensor =
               task->outLodTensorVector[taskmeta_index][fetch_lod_index];
           size_t length = fetchvar_bytesize_index * shape0_length;
+          fetchVarTensor.shape[0] = shape0_length;
           fetchVarTensor.data.Resize(length);
           void* dst_ptr = fetchVarTensor.data.data();
           void* source_ptr = _batch_out[fetchvar_index].data.data() +
@@ -813,6 +816,7 @@ class BatchTasks {
           paddle::PaddleTensor& fetchVarTensor =
               (*task->outVectorT_ptr)[fetchvar_index];
           size_t length = fetchvar_bytesize_index * shape0_length;
+          fetchVarTensor.shape[0] = shape0_length;
           fetchVarTensor.data.Resize(length);
           void* dst_ptr = fetchVarTensor.data.data();
           void* source_ptr = _batch_out[fetchvar_index].data.data() +
......
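Note: the two added `fetchVarTensor.shape[0] = shape0_length;` lines in `BatchTasks` appear to be the actual bug fix: the per-task tensor's data buffer is resized to this task's slice of the batch, while `shape[0]` was not updated to match, so metadata and storage could disagree. After the fix the invariant below should hold (a hedged consistency check; `bytes_per_sample` is the hypothetical helper sketched earlier):

    // For any fetch tensor handed back to a task:
    //   tensor.data.length() ==
    //       tensor.shape[0] * bytes_per_sample(tensor.shape, dtype_size)
    // i.e. the batch dim now matches the rows actually copied for this task.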