提交 9c1a964a 编写于 作者: B barrierye

fix bug that when the number of files is less than the number of threads, it will fetch nan

上级 f11b05ff
......@@ -195,8 +195,11 @@ void ExecutorThreadWorker::TrainFiles() {
thread_scope_->DropKids();
}
for (int i = 0; i < fetch_var_num; ++i) {
fetch_values_[i] = fetch_values_[i] / batch_cnt;
if (batch_cnt) {
// when the number of files is less than the number of threads
for (int i = 0; i < fetch_var_num; ++i) {
fetch_values_[i] = fetch_values_[i] / batch_cnt;
}
}
}
......
......@@ -42,9 +42,6 @@ class DataFeedDesc(object):
def set_batch_size(self, batch_size):
self.proto_desc.batch = batch_size
def get_slot(self, name):
return self.proto_desc.multi_slot_desc.slots[self.__name_to_index[name]]
def set_dense_slots(self, dense_slots_name):
for name in dense_slots_name:
self.proto_desc.multi_slot_desc.slots[self.__name_to_index[name]].dense = True
......@@ -156,6 +153,6 @@ class AsyncExecutor(object):
fetch = [fetch]
fetch_var_names = [var.name for var in fetch]
evaluation = self.executor.run_from_files(program_desc, data_feed, filelist, thread_num, fetch_var_names)
evaluation = self.executor.run_from_files(program_desc, data_feed.desc(), filelist, thread_num, fetch_var_names)
return evaluation
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册