提交 dd378956 编写于 作者: M malin10

add infer

上级 2fed2cdf
...@@ -16,7 +16,10 @@ class Model(object): ...@@ -16,7 +16,10 @@ class Model(object):
self._cost = None self._cost = None
self._metrics = {} self._metrics = {}
self._data_var = [] self._data_var = []
self._infer_data_var = []
self._infer_results = {}
self._data_loader = None self._data_loader = None
self._infer_data_loader = None
self._fetch_interval = 20 self._fetch_interval = 20
self._namespace = "train.model" self._namespace = "train.model"
self._platform = envs.get_platform() self._platform = envs.get_platform()
...@@ -24,6 +27,12 @@ class Model(object): ...@@ -24,6 +27,12 @@ class Model(object):
def get_inputs(self): def get_inputs(self):
return self._data_var return self._data_var
# Accessor added by this commit: returns self._infer_data_var as-is (no copy).
# The attribute is initialised to an empty list in the constructor rows above.
def get_infer_inputs(self):
return self._infer_data_var
# Accessor added by this commit: returns self._infer_results as-is (no copy).
# The attribute is initialised to an empty dict in the constructor rows above;
# SingleTrainer.infer() iterates it with .items() and reads `.name` on each value.
def get_infer_results(self):
return self._infer_results
def get_cost_op(self): def get_cost_op(self):
"""R """R
""" """
......
...@@ -59,7 +59,7 @@ class SingleTrainer(TranspileTrainer): ...@@ -59,7 +59,7 @@ class SingleTrainer(TranspileTrainer):
def dataloader_train(self, context): def dataloader_train(self, context):
self._exe.run(fluid.default_startup_program()) self._exe.run(fluid.default_startup_program())
reader = self._get_dataloader() reader = self._get_dataloader("TRAIN")
epochs = envs.get_global_env("train.epochs") epochs = envs.get_global_env("train.epochs")
program = fluid.compiler.CompiledProgram( program = fluid.compiler.CompiledProgram(
...@@ -95,13 +95,14 @@ class SingleTrainer(TranspileTrainer): ...@@ -95,13 +95,14 @@ class SingleTrainer(TranspileTrainer):
batch_id += 1 batch_id += 1
except fluid.core.EOFException: except fluid.core.EOFException:
reader.reset() reader.reset()
self.save(epoch, "train", is_fleet=False)
context['status'] = 'infer_pass' context['status'] = 'infer_pass'
def dataset_train(self, context): def dataset_train(self, context):
# run startup program at once # run startup program at once
self._exe.run(fluid.default_startup_program()) self._exe.run(fluid.default_startup_program())
dataset = self._get_dataset() dataset = self._get_dataset("TRAIN")
epochs = envs.get_global_env("train.epochs") epochs = envs.get_global_env("train.epochs")
for i in range(epochs): for i in range(epochs):
...@@ -109,11 +110,54 @@ class SingleTrainer(TranspileTrainer): ...@@ -109,11 +110,54 @@ class SingleTrainer(TranspileTrainer):
dataset=dataset, dataset=dataset,
fetch_list=self.fetch_vars, fetch_list=self.fetch_vars,
fetch_info=self.fetch_alias, fetch_info=self.fetch_alias,
print_period=self.fetch_period) print_period=1,
debug=True)
self.save(i, "train", is_fleet=False) self.save(i, "train", is_fleet=False)
context['status'] = 'infer_pass' context['status'] = 'infer_pass'
# Inference stage: builds a dedicated inference program, then for every
# (epoch, model_dir) pair in self.increment_models loads the saved persistables
# and runs the evaluate reader to exhaustion, printing the fetched results.
# `context` is the pipeline state dict; the only key written here is 'status'.
# NOTE(review): this page has lost all indentation, so the nesting described in
# the comments below is inferred from syntax -- confirm against the real file.
# The first and last rows of this block are unchanged diff context, which is
# why the page shows their text twice (old column, then new column).
def infer(self, context): def infer(self, context):
# New Program objects that will hold the inference graph and its startup ops.
infer_program = fluid.Program()
startup_program = fluid.Program()
# Build the model's inference network while these two programs are the active ones.
with fluid.unique_name.guard():
with fluid.program_guard(infer_program, startup_program):
self.model.infer_net()
# "Evaluate" is not "TRAIN", so _get_dataloader takes its else-branch
# (self.model._infer_data_loader, namespace "evaluate.reader").
reader = self._get_dataloader("Evaluate")
# Collect fetch targets (variable names) and build the print template.
# After the join below the template reads "epoch: {}, batch: {}, <name>: {}, ...".
metrics_varnames = []
metrics_format = []
metrics_format.append("{}: {{}}".format("epoch"))
metrics_format.append("{}: {{}}".format("batch"))
for name, var in self.model.get_infer_results().items():
metrics_varnames.append(var.name)
metrics_format.append("{}: {{}}".format(name))
metrics_format = ", ".join(metrics_format)
# Run the inference startup program once, before any checkpoint is loaded.
self._exe.run(startup_program)
# self.increment_models is consumed as (epoch, model_dir) pairs; it is
# populated outside the rows visible on this page.
for (epoch, model_dir) in self.increment_models:
print("Begin to infer epoch {}, model_dir: {}".format(epoch, model_dir))
# Work on a clone of the inference program and load this checkpoint into it.
program = infer_program.clone()
fluid.io.load_persistables(self._exe, model_dir, program)
reader.start()
batch_id = 0
# Loop until the reader is exhausted; exhaustion surfaces as EOFException.
try:
while True:
metrics_rets = self._exe.run(
program=program,
fetch_list=metrics_varnames)
# [epoch, batch_id, fetched values...] -- same order as the template placeholders.
metrics = [epoch, batch_id]
metrics.extend(metrics_rets)
# Prints only for even, non-zero batch ids (2, 4, 6, ...); batch 0 is never printed.
# NOTE(review): the interval 2 is hard-coded here -- confirm this is intended.
if batch_id % 2 == 0 and batch_id != 0:
print(metrics_format.format(*metrics))
batch_id += 1
except fluid.core.EOFException:
# End of data for this checkpoint: reset the reader before the next reader.start().
reader.reset()
# Mark the next stage; presumably dispatches to terminal() -- verify against the trainer's processor registration.
context['status'] = 'terminal_pass' context['status'] = 'terminal_pass'
def terminal(self, context): def terminal(self, context):
......
...@@ -36,28 +36,37 @@ class TranspileTrainer(Trainer): ...@@ -36,28 +36,37 @@ class TranspileTrainer(Trainer):
def processor_register(self): def processor_register(self):
print("Need implement by trainer, `self.regist_context_processor('uninit', self.instance)` must be the first") print("Need implement by trainer, `self.regist_context_processor('uninit', self.instance)` must be the first")
# Side-by-side diff rows: left text is the old `_get_dataloader(self)`, right
# text is the new `_get_dataloader(self, state)`; rows shown once are additions.
# New behaviour (right column):
#   - state == "TRAIN" selects self.model._data_loader and namespace "train.reader";
#     any other value selects self.model._infer_data_loader and "evaluate.reader".
#   - "batch_size" and "class" are then read from that namespace via envs.get_global_env.
#   - dataloader_instance.dataloader(reader_class, state, self._config_yaml) builds the
#     sample generator (the old code always passed the literal "TRAIN"); it is attached
#     with set_sample_generator(reader, batch_size) and the data loader is returned.
# NOTE(review): `state` is a new required parameter with no default. The callers
# visible on this page pass "TRAIN" and "Evaluate"; check any caller not shown here.
def _get_dataloader(self): def _get_dataloader(self, state):
namespace = "train.reader" if state == "TRAIN":
dataloader = self.model._data_loader dataloader = self.model._data_loader
namespace = "train.reader"
else:
dataloader = self.model._infer_data_loader
namespace = "evaluate.reader"
batch_size = envs.get_global_env("batch_size", None, namespace) batch_size = envs.get_global_env("batch_size", None, namespace)
reader_class = envs.get_global_env("class", None, namespace) reader_class = envs.get_global_env("class", None, namespace)
reader = dataloader_instance.dataloader(reader_class, "TRAIN", self._config_yaml) reader = dataloader_instance.dataloader(reader_class, state, self._config_yaml)
dataloader.set_sample_generator(reader, batch_size) dataloader.set_sample_generator(reader, batch_size)
return dataloader return dataloader
def _get_dataset(self): def _get_dataset(self, state):
if state == "TRAIN":
inputs = self.model.get_inputs()
namespace = "train.reader" namespace = "train.reader"
train_data_path = envs.get_global_env("train_data_path", None, namespace)
else:
inputs = self.model.get_infer_inputs()
namespace = "evaluate.reader"
train_data_path = envs.get_global_env("test_data_path", None, namespace)
inputs = self.model.get_inputs()
threads = int(envs.get_runtime_environ("train.trainer.threads")) threads = int(envs.get_runtime_environ("train.trainer.threads"))
batch_size = envs.get_global_env("batch_size", None, namespace) batch_size = envs.get_global_env("batch_size", None, namespace)
reader_class = envs.get_global_env("class", None, namespace) reader_class = envs.get_global_env("class", None, namespace)
abs_dir = os.path.dirname(os.path.abspath(__file__)) abs_dir = os.path.dirname(os.path.abspath(__file__))
reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py') reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py')
pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config_yaml) pipe_cmd = "python {} {} {} {}".format(reader, reader_class, state, self._config_yaml)
train_data_path = envs.get_global_env("train_data_path", None, namespace)
if train_data_path.startswith("fleetrec::"): if train_data_path.startswith("fleetrec::"):
package_base = envs.get_runtime_environ("PACKAGE_BASE") package_base = envs.get_runtime_environ("PACKAGE_BASE")
...@@ -93,12 +102,12 @@ class TranspileTrainer(Trainer): ...@@ -93,12 +102,12 @@ class TranspileTrainer(Trainer):
if not need_save(epoch_id, save_interval, False): if not need_save(epoch_id, save_interval, False):
return return
print("save inference model is not supported now.") # print("save inference model is not supported now.")
return # return
feed_varnames = envs.get_global_env("save.inference.feed_varnames", None, namespace) feed_varnames = envs.get_global_env("save.inference.feed_varnames", None, namespace)
fetch_varnames = envs.get_global_env("save.inference.fetch_varnames", None, namespace) fetch_varnames = envs.get_global_env("save.inference.fetch_varnames", None, namespace)
fetch_vars = [fluid.global_scope().vars[varname] for varname in fetch_varnames] fetch_vars = [fluid.default_main_program().global_block().vars[varname] for varname in fetch_varnames]
dirname = envs.get_global_env("save.inference.dirname", None, namespace) dirname = envs.get_global_env("save.inference.dirname", None, namespace)
assert dirname is not None assert dirname is not None
...@@ -130,6 +139,7 @@ class TranspileTrainer(Trainer): ...@@ -130,6 +139,7 @@ class TranspileTrainer(Trainer):
save_persistables() save_persistables()
save_inference_model() save_inference_model()
def instance(self, context): def instance(self, context):
models = envs.get_global_env("train.model.models") models = envs.get_global_env("train.model.models")
model_class = envs.lazy_instance_by_fliename(models, "Model") model_class = envs.lazy_instance_by_fliename(models, "Model")
......
...@@ -22,13 +22,13 @@ from fleetrec.core.utils.envs import get_runtime_environ ...@@ -22,13 +22,13 @@ from fleetrec.core.utils.envs import get_runtime_environ
def dataloader(readerclass, train, yaml_file): def dataloader(readerclass, train, yaml_file):
namespace = "train.reader"
if train == "TRAIN": if train == "TRAIN":
reader_name = "TrainReader" reader_name = "TrainReader"
namespace = "train.reader"
data_path = get_global_env("train_data_path", None, namespace) data_path = get_global_env("train_data_path", None, namespace)
else: else:
reader_name = "EvaluateReader" reader_name = "EvaluateReader"
namespace = "evaluate.reader"
data_path = get_global_env("test_data_path", None, namespace) data_path = get_global_env("test_data_path", None, namespace)
if data_path.startswith("fleetrec::"): if data_path.startswith("fleetrec::"):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册