提交 eea54944 编写于 作者: F frankwhzhang

fix model.py

上级 2178f662
...@@ -139,23 +139,31 @@ class Model(object): ...@@ -139,23 +139,31 @@ class Model(object):
def net(self, is_infer=False): def net(self, is_infer=False):
return None return None
def train_net(self): def _construct_reader(self, is_infer=False):
input_data = self.input_data(is_infer=False) if is_infer:
self._data_var = input_data self._infer_data_loader = fluid.io.DataLoader.from_generator(
# if use dataset _data_loader not used feed_list=self._infer_data_var,
capacity=64,
use_double_buffer=False,
iterable=False)
else:
dataset_class = envs.get_global_env("dataset_class", None,
"train.reader")
if dataset_class == "DataLoader":
self._data_loader = fluid.io.DataLoader.from_generator( self._data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._data_var, feed_list=self._data_var,
capacity=64, capacity=64,
use_double_buffer=False, use_double_buffer=False,
iterable=False) iterable=False)
def train_net(self):
input_data = self.input_data(is_infer=False)
self._data_var = input_data
self._construct_reader(is_infer=False)
self.net(input_data, is_infer=False) self.net(input_data, is_infer=False)
def infer_net(self): def infer_net(self):
input_data = self.input_data(is_infer=True) input_data = self.input_data(is_infer=True)
self._infer_data_var = input_data self._infer_data_var = input_data
self._infer_data_loader = fluid.io.DataLoader.from_generator( self._construct_reader(is_infer=True)
feed_list=self._infer_data_var,
capacity=64,
use_double_buffer=False,
iterable=False)
self.net(input_data, is_infer=True) self.net(input_data, is_infer=True)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册