提交 9c1a964a 编写于 作者: B barrierye

fix bug that when the number of files is less than the number of threads, it will fetch nan

上级 f11b05ff
...@@ -195,9 +195,12 @@ void ExecutorThreadWorker::TrainFiles() { ...@@ -195,9 +195,12 @@ void ExecutorThreadWorker::TrainFiles() {
thread_scope_->DropKids(); thread_scope_->DropKids();
} }
if (batch_cnt) {
// when the number of files is less than the number of threads
for (int i = 0; i < fetch_var_num; ++i) { for (int i = 0; i < fetch_var_num; ++i) {
fetch_values_[i] = fetch_values_[i] / batch_cnt; fetch_values_[i] = fetch_values_[i] / batch_cnt;
} }
}
} }
......
...@@ -42,9 +42,6 @@ class DataFeedDesc(object): ...@@ -42,9 +42,6 @@ class DataFeedDesc(object):
def set_batch_size(self, batch_size): def set_batch_size(self, batch_size):
self.proto_desc.batch = batch_size self.proto_desc.batch = batch_size
def get_slot(self, name):
return self.proto_desc.multi_slot_desc.slots[self.__name_to_index[name]]
def set_dense_slots(self, dense_slots_name): def set_dense_slots(self, dense_slots_name):
for name in dense_slots_name: for name in dense_slots_name:
self.proto_desc.multi_slot_desc.slots[self.__name_to_index[name]].dense = True self.proto_desc.multi_slot_desc.slots[self.__name_to_index[name]].dense = True
...@@ -156,6 +153,6 @@ class AsyncExecutor(object): ...@@ -156,6 +153,6 @@ class AsyncExecutor(object):
fetch = [fetch] fetch = [fetch]
fetch_var_names = [var.name for var in fetch] fetch_var_names = [var.name for var in fetch]
evaluation = self.executor.run_from_files(program_desc, data_feed, filelist, thread_num, fetch_var_names) evaluation = self.executor.run_from_files(program_desc, data_feed.desc(), filelist, thread_num, fetch_var_names)
return evaluation return evaluation
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册