提交 cd94df86 编写于 作者: T tensor-tang

fix load and refine

上级 8e271896
......@@ -251,7 +251,7 @@ bool AnalysisPredictor::SetFeed(const std::vector<PaddleTensor> &inputs,
input.set_lod(lod);
int idx = -1;
if (config_.specify_input_name) {
idx = feed_names_[inputs[i].name];
idx = feed_names_.at(inputs[i].name);
} else {
idx = boost::get<int>(feeds_[i]->GetAttr("col"));
}
......
......@@ -60,8 +60,7 @@ struct DataRecord {
}
};
void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data,
int batch_size) {
void PrepareInputs(std::vector<PaddleTensor> *input_slots, DataRecord *data) {
PaddleTensor lod_word_tensor, lod_mention_tensor;
lod_word_tensor.name = "word";
lod_mention_tensor.name = "mention";
......@@ -100,7 +99,7 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs) {
int epoch = FLAGS_test_all_data ? data.num_samples / FLAGS_batch_size : 1;
LOG(INFO) << "number of samples: " << epoch * FLAGS_batch_size;
for (int bid = 0; bid < epoch; ++bid) {
PrepareInputs(&input_slots, &data, FLAGS_batch_size);
PrepareInputs(&input_slots, &data);
(*inputs).emplace_back(input_slots);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册