提交 eabfd85d 编写于 作者: X xjqbest

fix

上级 cf7f8499
......@@ -87,9 +87,13 @@ class Model(object):
if dataset_class == "DataLoader":
self._init_dataloader()
def _init_dataloader(self):
def _init_dataloader(self, is_infer=False):
if is_infer:
data = self._infer_data_var
else:
data = self._data_var
self._data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._data_var,
feed_list=data,
capacity=64,
use_double_buffer=False,
iterable=False)
......
......@@ -93,7 +93,10 @@ class SingleTrainer(TranspileTrainer):
for model_dict in self._env["executor"]:
if model_dict["dataset_name"] == dataset_name:
model = self._model[model_dict["name"]][3]
inputs = model.get_inputs()
if model_dict["is_infer"]:
inputs = model._infer_data_var
else:
inputs = model._data_var
dataset.set_use_var(inputs)
break
return dataset
......@@ -158,24 +161,33 @@ class SingleTrainer(TranspileTrainer):
envs.path_adapter(self._env["workspace"]))
model = envs.lazy_instance_by_fliename(
model_path, "Model")(self._env)
is_infer = model_dict.get("is_infer", False)
if is_infer:
model._infer_data_var = model.input_data(
dataset_name=model_dict["dataset_name"])
else:
model._data_var = model.input_data(
dataset_name=model_dict["dataset_name"])
if envs.get_global_env("dataset." + dataset_name +
".type") == "DataLoader":
model._init_dataloader()
model._init_dataloader(is_infer=is_infer)
self._get_dataloader(dataset_name,
model._data_loader)
model.net(model._data_var,
is_infer=model_dict.get("is_infer", False))
if is_infer:
model.net(model._infer_data_var, True)
else:
model.net(model._data_var, False)
optimizer = model._build_optimizer(opt_name, opt_lr,
opt_strategy)
optimizer.minimize(model._cost)
model_dict["is_infer"] = is_infer
self._model[model_dict["name"]][0] = train_program
self._model[model_dict["name"]][1] = startup_program
self._model[model_dict["name"]][2] = scope
self._model[model_dict["name"]][3] = model
self._model[model_dict["name"]][4] = train_program.clone()
for dataset in self._env["dataset"]:
if dataset["type"] != "DataLoader":
self._dataset[dataset["name"]] = self._create_dataset(dataset[
......@@ -223,6 +235,9 @@ class SingleTrainer(TranspileTrainer):
fetch_vars = []
fetch_alias = []
fetch_period = 20
if model_dict["is_infer"]:
metrics = model_class.get_infer_results()
else:
metrics = model_class.get_metrics()
if metrics:
fetch_vars = metrics.values()
......@@ -231,6 +246,14 @@ class SingleTrainer(TranspileTrainer):
program = self._model[model_name][0]
reader = self._dataset[reader_name]
with fluid.scope_guard(scope):
if model_dict["is_infer"]:
self._exe.infer_from_dataset(
program=program,
dataset=reader,
fetch_list=fetch_vars,
fetch_info=fetch_alias,
print_period=fetch_period)
else:
self._exe.train_from_dataset(
program=program,
dataset=reader,
......@@ -243,11 +266,15 @@ class SingleTrainer(TranspileTrainer):
model_name = model_dict["name"]
model_class = self._model[model_name][3]
program = self._model[model_name][0].clone()
if not model_dict["is_infer"]:
program = fluid.compiler.CompiledProgram(
program).with_data_parallel(loss_name=model_class.get_avg_cost().name)
fetch_vars = []
fetch_alias = []
fetch_period = 20
if model_dict["is_infer"]:
metrics = model_class.get_infer_results()
else:
metrics = model_class.get_metrics()
if metrics:
fetch_vars = metrics.values()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册