Commit cd94df86 authored by tensor-tang

fix load and refine

Parent 8e271896
...
@@ -251,7 +251,7 @@ bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
     input.set_lod(lod);
     int idx = -1;
     if (config_.specify_input_name) {
-      idx = feed_names_[inputs[i].name];
+      idx = feed_names_.at(inputs[i].name);
     } else {
       idx = boost::get<int>(feeds_[i]->GetAttr("col"));
     }
...
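The first hunk swaps map `operator[]` for `.at()` when resolving a feed index by input name. Assuming `feed_names_` is a `std::map`/`std::unordered_map` from name to column index (the container type is not visible in this hunk), the difference matters: `operator[]` default-inserts a value for a missing key, so an unregistered or misspelled input name would silently resolve to index 0, while `.at()` throws `std::out_of_range` and surfaces the bad name. A minimal standalone sketch of that behavior:

```cpp
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

int main() {
  // Hypothetical stand-in for feed_names_: input name -> feed column index.
  std::map<std::string, int> feed_names{{"word", 0}, {"mention", 1}};

  // operator[] default-inserts the missing key and returns 0, so a typo in
  // the input name is silently mapped to column 0.
  int silent = feed_names["wrod"];
  std::cout << "operator[]: idx=" << silent
            << ", map size grew to " << feed_names.size() << "\n";

  // at() throws std::out_of_range instead, which makes the bad name visible.
  try {
    int checked = feed_names.at("mentoin");
    std::cout << "at(): idx=" << checked << "\n";
  } catch (const std::out_of_range &) {
    std::cout << "at(): unknown input name\n";
  }
  return 0;
}
```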
...
@@ -60,8 +60,7 @@ struct DataRecord {
   }
 };
-void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
-                   int batch_size) {
+void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data) {
   PaddleTensor lod_word_tensor, lod_mention_tensor;
   lod_word_tensor.name = "word";
   lod_mention_tensor.name = "mention";
...
@@ -100,7 +99,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
   int epoch = FLAGS_test_all_data ? data.num_samples / FLAGS_batch_size : 1;
   LOG(INFO) << "number of samples: " << epoch * FLAGS_batch_size;
   for (int bid = 0; bid < epoch; ++bid) {
-    PrepareInputs(&input_slots, &data, FLAGS_batch_size);
+    PrepareInputs(&input_slots, &data);
     (*inputs).emplace_back(input_slots);
   }
 }
...
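The remaining hunks drop the `batch_size` argument from `PrepareInputs` and its call site in `SetInput`. The batching loop already derives everything from `FLAGS_batch_size`, so the extra parameter only duplicated state the helper can obtain on its own (from the flag or the `DataRecord`). A small sketch of the refactoring pattern, using simplified stand-in types rather than the actual test fixtures:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Stand-ins for the test fixtures; names and shapes are illustrative only.
struct Record { std::vector<int64_t> ids; };
static int g_batch_size = 4;  // plays the role of FLAGS_batch_size here

// After the refactor the helper reads the single source of truth itself,
// so callers cannot pass a batch size that disagrees with the flag.
void PrepareInputs(std::vector<Record> *slots) {
  slots->clear();
  for (int i = 0; i < g_batch_size; ++i) {
    slots->push_back(Record{{i, i + 1}});
  }
}

int main() {
  std::vector<Record> slots;
  PrepareInputs(&slots);  // call site no longer forwards the batch size
  std::cout << "prepared " << slots.size() << " records\n";
  return 0;
}
```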