diff --git a/core/model.py b/core/model.py index 9f9e3721185d761c37721079722ab809e7bc8c03..fd0b22fe005a0d001f64687f412f3cc88a35ef72 100755 --- a/core/model.py +++ b/core/model.py @@ -139,23 +139,31 @@ class Model(object): def net(self, is_infer=False): return None + def _construct_reader(self, is_infer=False): + if is_infer: + self._infer_data_loader = fluid.io.DataLoader.from_generator( + feed_list=self._infer_data_var, + capacity=64, + use_double_buffer=False, + iterable=False) + else: + dataset_class = envs.get_global_env("dataset_class", None, + "train.reader") + if dataset_class == "DataLoader": + self._data_loader = fluid.io.DataLoader.from_generator( + feed_list=self._data_var, + capacity=64, + use_double_buffer=False, + iterable=False) + def train_net(self): input_data = self.input_data(is_infer=False) self._data_var = input_data - # if use dataset _data_loader not used - self._data_loader = fluid.io.DataLoader.from_generator( - feed_list=self._data_var, - capacity=64, - use_double_buffer=False, - iterable=False) + self._construct_reader(is_infer=False) self.net(input_data, is_infer=False) def infer_net(self): input_data = self.input_data(is_infer=True) self._infer_data_var = input_data - self._infer_data_loader = fluid.io.DataLoader.from_generator( - feed_list=self._infer_data_var, - capacity=64, - use_double_buffer=False, - iterable=False) + self._construct_reader(is_infer=True) self.net(input_data, is_infer=True)