提交 dd378956 编写于 作者: M malin10

add infer

上级 2fed2cdf
......@@ -16,7 +16,10 @@ class Model(object):
self._cost = None
self._metrics = {}
self._data_var = []
self._infer_data_var = []
self._infer_results = {}
self._data_loader = None
self._infer_data_loader = None
self._fetch_interval = 20
self._namespace = "train.model"
self._platform = envs.get_platform()
......@@ -24,6 +27,12 @@ class Model(object):
def get_inputs(self):
return self._data_var
def get_infer_inputs(self):
return self._infer_data_var
def get_infer_results(self):
return self._infer_results
def get_cost_op(self):
"""R
"""
......
......@@ -59,7 +59,7 @@ class SingleTrainer(TranspileTrainer):
def dataloader_train(self, context):
self._exe.run(fluid.default_startup_program())
reader = self._get_dataloader()
reader = self._get_dataloader("TRAIN")
epochs = envs.get_global_env("train.epochs")
program = fluid.compiler.CompiledProgram(
......@@ -95,13 +95,14 @@ class SingleTrainer(TranspileTrainer):
batch_id += 1
except fluid.core.EOFException:
reader.reset()
self.save(epoch, "train", is_fleet=False)
context['status'] = 'infer_pass'
def dataset_train(self, context):
# run startup program at once
self._exe.run(fluid.default_startup_program())
dataset = self._get_dataset()
dataset = self._get_dataset("TRAIN")
epochs = envs.get_global_env("train.epochs")
for i in range(epochs):
......@@ -109,11 +110,54 @@ class SingleTrainer(TranspileTrainer):
dataset=dataset,
fetch_list=self.fetch_vars,
fetch_info=self.fetch_alias,
print_period=self.fetch_period)
print_period=1,
debug=True)
self.save(i, "train", is_fleet=False)
context['status'] = 'infer_pass'
def infer(self, context):
infer_program = fluid.Program()
startup_program = fluid.Program()
with fluid.unique_name.guard():
with fluid.program_guard(infer_program, startup_program):
self.model.infer_net()
reader = self._get_dataloader("Evaluate")
metrics_varnames = []
metrics_format = []
metrics_format.append("{}: {{}}".format("epoch"))
metrics_format.append("{}: {{}}".format("batch"))
for name, var in self.model.get_infer_results().items():
metrics_varnames.append(var.name)
metrics_format.append("{}: {{}}".format(name))
metrics_format = ", ".join(metrics_format)
self._exe.run(startup_program)
for (epoch, model_dir) in self.increment_models:
print("Begin to infer epoch {}, model_dir: {}".format(epoch, model_dir))
program = infer_program.clone()
fluid.io.load_persistables(self._exe, model_dir, program)
reader.start()
batch_id = 0
try:
while True:
metrics_rets = self._exe.run(
program=program,
fetch_list=metrics_varnames)
metrics = [epoch, batch_id]
metrics.extend(metrics_rets)
if batch_id % 2 == 0 and batch_id != 0:
print(metrics_format.format(*metrics))
batch_id += 1
except fluid.core.EOFException:
reader.reset()
context['status'] = 'terminal_pass'
def terminal(self, context):
......
......@@ -36,28 +36,37 @@ class TranspileTrainer(Trainer):
def processor_register(self):
print("Need implement by trainer, `self.regist_context_processor('uninit', self.instance)` must be the first")
def _get_dataloader(self):
namespace = "train.reader"
dataloader = self.model._data_loader
def _get_dataloader(self, state):
if state == "TRAIN":
dataloader = self.model._data_loader
namespace = "train.reader"
else:
dataloader = self.model._infer_data_loader
namespace = "evaluate.reader"
batch_size = envs.get_global_env("batch_size", None, namespace)
reader_class = envs.get_global_env("class", None, namespace)
reader = dataloader_instance.dataloader(reader_class, "TRAIN", self._config_yaml)
reader = dataloader_instance.dataloader(reader_class, state, self._config_yaml)
dataloader.set_sample_generator(reader, batch_size)
return dataloader
def _get_dataset(self):
namespace = "train.reader"
def _get_dataset(self, state):
if state == "TRAIN":
inputs = self.model.get_inputs()
namespace = "train.reader"
train_data_path = envs.get_global_env("train_data_path", None, namespace)
else:
inputs = self.model.get_infer_inputs()
namespace = "evaluate.reader"
train_data_path = envs.get_global_env("test_data_path", None, namespace)
inputs = self.model.get_inputs()
threads = int(envs.get_runtime_environ("train.trainer.threads"))
batch_size = envs.get_global_env("batch_size", None, namespace)
reader_class = envs.get_global_env("class", None, namespace)
abs_dir = os.path.dirname(os.path.abspath(__file__))
reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py')
pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config_yaml)
train_data_path = envs.get_global_env("train_data_path", None, namespace)
pipe_cmd = "python {} {} {} {}".format(reader, reader_class, state, self._config_yaml)
if train_data_path.startswith("fleetrec::"):
package_base = envs.get_runtime_environ("PACKAGE_BASE")
......@@ -92,13 +101,13 @@ class TranspileTrainer(Trainer):
if not need_save(epoch_id, save_interval, False):
return
print("save inference model is not supported now.")
return
# print("save inference model is not supported now.")
# return
feed_varnames = envs.get_global_env("save.inference.feed_varnames", None, namespace)
fetch_varnames = envs.get_global_env("save.inference.fetch_varnames", None, namespace)
fetch_vars = [fluid.global_scope().vars[varname] for varname in fetch_varnames]
fetch_vars = [fluid.default_main_program().global_block().vars[varname] for varname in fetch_varnames]
dirname = envs.get_global_env("save.inference.dirname", None, namespace)
assert dirname is not None
......@@ -129,6 +138,7 @@ class TranspileTrainer(Trainer):
save_persistables()
save_inference_model()
def instance(self, context):
models = envs.get_global_env("train.model.models")
......
......@@ -22,13 +22,13 @@ from fleetrec.core.utils.envs import get_runtime_environ
def dataloader(readerclass, train, yaml_file):
namespace = "train.reader"
if train == "TRAIN":
reader_name = "TrainReader"
namespace = "train.reader"
data_path = get_global_env("train_data_path", None, namespace)
else:
reader_name = "EvaluateReader"
namespace = "evaluate.reader"
data_path = get_global_env("test_data_path", None, namespace)
if data_path.startswith("fleetrec::"):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册