Commit 05f1b65d authored by Tao Luo

simplify prepare_input in analyzer_test

test=develop
Parent: dc8eca82
@@ -113,6 +113,16 @@ static void TensorAssignData(PaddleTensor *tensor,
   }
 }
+template <typename T>
+static void TensorAssignData(PaddleTensor *tensor,
+                             const std::vector<std::vector<T>> &data,
+                             const std::vector<size_t> &lod) {
+  int size = lod[lod.size() - 1];
+  tensor->shape.assign({size, 1});
+  tensor->lod.assign({lod});
+  TensorAssignData(tensor, data);
+}
+
 template <typename T>
 static int ZeroCopyTensorAssignData(ZeroCopyTensor *tensor,
                                     const std::vector<std::vector<T>> &data) {
......
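For context, the new three-argument TensorAssignData overload takes the batch LoD (sequence-offset) vector, uses its last entry as the token batch size for the tensor shape, records the LoD, and then delegates the data copy to the existing two-argument overload. Below is a minimal, self-contained sketch of that behaviour against a hypothetical FakeTensor stand-in; the field layout and the flattening loop are assumptions for illustration only, not the real PaddleTensor API.

// Minimal sketch: FakeTensor is a hypothetical stand-in with only the fields
// the helper touches; it is not the real PaddleTensor type. The data copy is
// simplified to the int64 case used by these analyzer tests.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

struct FakeTensor {
  std::vector<int> shape;
  std::vector<std::vector<size_t>> lod;
  std::vector<int64_t> data;
};

// Mirrors the idea of the new overload: the last LoD entry is the token
// batch size, which becomes the first dimension of the shape; the
// per-sentence vectors are then flattened into the tensor's data buffer.
void AssignDataWithLod(FakeTensor *tensor,
                       const std::vector<std::vector<int64_t>> &data,
                       const std::vector<size_t> &lod) {
  int size = static_cast<int>(lod[lod.size() - 1]);  // token batch size
  tensor->shape.assign({size, 1});
  tensor->lod.assign({lod});
  tensor->data.clear();
  for (const auto &sentence : data) {
    tensor->data.insert(tensor->data.end(), sentence.begin(), sentence.end());
  }
}

int main() {
  // Two "sentences" of 3 and 2 tokens; lod = {0, 3, 5}, so the shape is {5, 1}.
  std::vector<std::vector<int64_t>> data = {{1, 2, 3}, {4, 5}};
  std::vector<size_t> lod = {0, 3, 5};
  FakeTensor tensor;
  AssignDataWithLod(&tensor, data, lod);
  assert(tensor.shape[0] == 5 && tensor.shape[1] == 1);
  assert(tensor.data.size() == 5);
  std::cout << "shape = {" << tensor.shape[0] << ", " << tensor.shape[1] << "}\n";
  return 0;
}

With a helper of this shape, each call site only needs to pass the per-sentence data and the batch LoD, which is the simplification applied in the hunks below.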
@@ -98,10 +98,8 @@ void GetOneBatch(std::vector<PaddleTensor> *input_slots, DataRecord *data,
   auto one_batch = data->NextBatch();
   PaddleTensor input_tensor;
   input_tensor.name = "word";
-  input_tensor.shape.assign({static_cast<int>(one_batch.data.size()), 1});
-  input_tensor.lod.assign({one_batch.lod});
   input_tensor.dtype = PaddleDType::INT64;
-  TensorAssignData<int64_t>(&input_tensor, {one_batch.data});
+  TensorAssignData<int64_t>(&input_tensor, {one_batch.data}, one_batch.lod);
   PADDLE_ENFORCE_EQ(batch_size, static_cast<int>(one_batch.lod.size() - 1));
   input_slots->assign({input_tensor});
 }
......
@@ -80,15 +80,11 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
   lod_query_tensor.name = "left";
   lod_title_tensor.name = "right";
   auto one_batch = data->NextBatch();
-  int size1 = one_batch.lod1[one_batch.lod1.size() - 1];  // token batch size
-  int size2 = one_batch.lod2[one_batch.lod2.size() - 1];  // token batch size
-  lod_query_tensor.shape.assign({size1, 1});
-  lod_query_tensor.lod.assign({one_batch.lod1});
-  lod_title_tensor.shape.assign({size2, 1});
-  lod_title_tensor.lod.assign({one_batch.lod2});
   // assign data
-  TensorAssignData<int64_t>(&lod_query_tensor, one_batch.query_data_all);
-  TensorAssignData<int64_t>(&lod_title_tensor, one_batch.title_data_all);
+  TensorAssignData<int64_t>(&lod_query_tensor, one_batch.query_data_all,
+                            one_batch.lod1);
+  TensorAssignData<int64_t>(&lod_title_tensor, one_batch.title_data_all,
+                            one_batch.lod2);
   // Set inputs.
   input_slots->assign({lod_query_tensor, lod_title_tensor});
   for (auto &tensor : *input_slots) {
......
@@ -78,14 +78,11 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
   lod_word_tensor.name = "word";
   lod_mention_tensor.name = "mention";
   auto one_batch = data->NextBatch();
-  int size = one_batch.lod[one_batch.lod.size() - 1];  // token batch size
-  lod_word_tensor.shape.assign({size, 1});
-  lod_word_tensor.lod.assign({one_batch.lod});
-  lod_mention_tensor.shape.assign({size, 1});
-  lod_mention_tensor.lod.assign({one_batch.lod});
   // assign data
-  TensorAssignData<int64_t>(&lod_word_tensor, one_batch.word_data_all);
-  TensorAssignData<int64_t>(&lod_mention_tensor, one_batch.mention_data_all);
+  TensorAssignData<int64_t>(&lod_word_tensor, one_batch.word_data_all,
+                            one_batch.lod);
+  TensorAssignData<int64_t>(&lod_mention_tensor, one_batch.mention_data_all,
+                            one_batch.lod);
   // Set inputs.
   input_slots->assign({lod_word_tensor, lod_mention_tensor});
   for (auto &tensor : *input_slots) {
......
@@ -109,24 +109,14 @@ void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
   title3_tensor.name = "title3";
   l1_tensor.name = "l1";
   auto one_batch = data->NextBatch();
-  int title1_size = one_batch.title1_lod[one_batch.title1_lod.size() - 1];
-  title1_tensor.shape.assign({title1_size, 1});
-  title1_tensor.lod.assign({one_batch.title1_lod});
-  int title2_size = one_batch.title2_lod[one_batch.title2_lod.size() - 1];
-  title2_tensor.shape.assign({title2_size, 1});
-  title2_tensor.lod.assign({one_batch.title2_lod});
-  int title3_size = one_batch.title3_lod[one_batch.title3_lod.size() - 1];
-  title3_tensor.shape.assign({title3_size, 1});
-  title3_tensor.lod.assign({one_batch.title3_lod});
-  int l1_size = one_batch.l1_lod[one_batch.l1_lod.size() - 1];
-  l1_tensor.shape.assign({l1_size, 1});
-  l1_tensor.lod.assign({one_batch.l1_lod});
   // assign data
-  TensorAssignData<int64_t>(&title1_tensor, one_batch.title1);
-  TensorAssignData<int64_t>(&title2_tensor, one_batch.title2);
-  TensorAssignData<int64_t>(&title3_tensor, one_batch.title3);
-  TensorAssignData<int64_t>(&l1_tensor, one_batch.l1);
+  TensorAssignData<int64_t>(&title1_tensor, one_batch.title1,
+                            one_batch.title1_lod);
+  TensorAssignData<int64_t>(&title2_tensor, one_batch.title2,
+                            one_batch.title2_lod);
+  TensorAssignData<int64_t>(&title3_tensor, one_batch.title3,
+                            one_batch.title3_lod);
+  TensorAssignData<int64_t>(&l1_tensor, one_batch.l1, one_batch.l1_lod);
   // Set inputs.
   input_slots->assign({title1_tensor, title2_tensor, title3_tensor, l1_tensor});
   for (auto &tensor : *input_slots) {
......