提交 eea54944 编写于 作者: F frankwhzhang

fix model.py

上级 2178f662
......@@ -139,23 +139,31 @@ class Model(object):
def net(self, is_infer=False):
return None
def train_net(self):
input_data = self.input_data(is_infer=False)
self._data_var = input_data
# if use dataset _data_loader not used
def _construct_reader(self, is_infer=False):
if is_infer:
self._infer_data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._infer_data_var,
capacity=64,
use_double_buffer=False,
iterable=False)
else:
dataset_class = envs.get_global_env("dataset_class", None,
"train.reader")
if dataset_class == "DataLoader":
self._data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._data_var,
capacity=64,
use_double_buffer=False,
iterable=False)
def train_net(self):
input_data = self.input_data(is_infer=False)
self._data_var = input_data
self._construct_reader(is_infer=False)
self.net(input_data, is_infer=False)
def infer_net(self):
input_data = self.input_data(is_infer=True)
self._infer_data_var = input_data
self._infer_data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._infer_data_var,
capacity=64,
use_double_buffer=False,
iterable=False)
self._construct_reader(is_infer=True)
self.net(input_data, is_infer=True)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册