提交 eabfd85d 编写于 作者: X xjqbest

fix

上级 cf7f8499
...@@ -87,9 +87,13 @@ class Model(object): ...@@ -87,9 +87,13 @@ class Model(object):
if dataset_class == "DataLoader": if dataset_class == "DataLoader":
self._init_dataloader() self._init_dataloader()
def _init_dataloader(self): def _init_dataloader(self, is_infer=False):
if is_infer:
data = self._infer_data_var
else:
data = self._data_var
self._data_loader = fluid.io.DataLoader.from_generator( self._data_loader = fluid.io.DataLoader.from_generator(
feed_list=self._data_var, feed_list=data,
capacity=64, capacity=64,
use_double_buffer=False, use_double_buffer=False,
iterable=False) iterable=False)
......
...@@ -93,7 +93,10 @@ class SingleTrainer(TranspileTrainer): ...@@ -93,7 +93,10 @@ class SingleTrainer(TranspileTrainer):
for model_dict in self._env["executor"]: for model_dict in self._env["executor"]:
if model_dict["dataset_name"] == dataset_name: if model_dict["dataset_name"] == dataset_name:
model = self._model[model_dict["name"]][3] model = self._model[model_dict["name"]][3]
inputs = model.get_inputs() if model_dict["is_infer"]:
inputs = model._infer_data_var
else:
inputs = model._data_var
dataset.set_use_var(inputs) dataset.set_use_var(inputs)
break break
return dataset return dataset
...@@ -158,24 +161,33 @@ class SingleTrainer(TranspileTrainer): ...@@ -158,24 +161,33 @@ class SingleTrainer(TranspileTrainer):
envs.path_adapter(self._env["workspace"])) envs.path_adapter(self._env["workspace"]))
model = envs.lazy_instance_by_fliename( model = envs.lazy_instance_by_fliename(
model_path, "Model")(self._env) model_path, "Model")(self._env)
is_infer = model_dict.get("is_infer", False)
if is_infer:
model._infer_data_var = model.input_data(
dataset_name=model_dict["dataset_name"])
else:
model._data_var = model.input_data( model._data_var = model.input_data(
dataset_name=model_dict["dataset_name"]) dataset_name=model_dict["dataset_name"])
if envs.get_global_env("dataset." + dataset_name + if envs.get_global_env("dataset." + dataset_name +
".type") == "DataLoader": ".type") == "DataLoader":
model._init_dataloader() model._init_dataloader(is_infer=is_infer)
self._get_dataloader(dataset_name, self._get_dataloader(dataset_name,
model._data_loader) model._data_loader)
model.net(model._data_var, if is_infer:
is_infer=model_dict.get("is_infer", False)) model.net(model._infer_data_var, True)
else:
model.net(model._data_var, False)
optimizer = model._build_optimizer(opt_name, opt_lr, optimizer = model._build_optimizer(opt_name, opt_lr,
opt_strategy) opt_strategy)
optimizer.minimize(model._cost) optimizer.minimize(model._cost)
model_dict["is_infer"] = is_infer
self._model[model_dict["name"]][0] = train_program self._model[model_dict["name"]][0] = train_program
self._model[model_dict["name"]][1] = startup_program self._model[model_dict["name"]][1] = startup_program
self._model[model_dict["name"]][2] = scope self._model[model_dict["name"]][2] = scope
self._model[model_dict["name"]][3] = model self._model[model_dict["name"]][3] = model
self._model[model_dict["name"]][4] = train_program.clone() self._model[model_dict["name"]][4] = train_program.clone()
for dataset in self._env["dataset"]: for dataset in self._env["dataset"]:
if dataset["type"] != "DataLoader": if dataset["type"] != "DataLoader":
self._dataset[dataset["name"]] = self._create_dataset(dataset[ self._dataset[dataset["name"]] = self._create_dataset(dataset[
...@@ -223,6 +235,9 @@ class SingleTrainer(TranspileTrainer): ...@@ -223,6 +235,9 @@ class SingleTrainer(TranspileTrainer):
fetch_vars = [] fetch_vars = []
fetch_alias = [] fetch_alias = []
fetch_period = 20 fetch_period = 20
if model_dict["is_infer"]:
metrics = model_class.get_infer_results()
else:
metrics = model_class.get_metrics() metrics = model_class.get_metrics()
if metrics: if metrics:
fetch_vars = metrics.values() fetch_vars = metrics.values()
...@@ -231,6 +246,14 @@ class SingleTrainer(TranspileTrainer): ...@@ -231,6 +246,14 @@ class SingleTrainer(TranspileTrainer):
program = self._model[model_name][0] program = self._model[model_name][0]
reader = self._dataset[reader_name] reader = self._dataset[reader_name]
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
if model_dict["is_infer"]:
self._exe.infer_from_dataset(
program=program,
dataset=reader,
fetch_list=fetch_vars,
fetch_info=fetch_alias,
print_period=fetch_period)
else:
self._exe.train_from_dataset( self._exe.train_from_dataset(
program=program, program=program,
dataset=reader, dataset=reader,
...@@ -243,11 +266,15 @@ class SingleTrainer(TranspileTrainer): ...@@ -243,11 +266,15 @@ class SingleTrainer(TranspileTrainer):
model_name = model_dict["name"] model_name = model_dict["name"]
model_class = self._model[model_name][3] model_class = self._model[model_name][3]
program = self._model[model_name][0].clone() program = self._model[model_name][0].clone()
if not model_dict["is_infer"]:
program = fluid.compiler.CompiledProgram( program = fluid.compiler.CompiledProgram(
program).with_data_parallel(loss_name=model_class.get_avg_cost().name) program).with_data_parallel(loss_name=model_class.get_avg_cost().name)
fetch_vars = [] fetch_vars = []
fetch_alias = [] fetch_alias = []
fetch_period = 20 fetch_period = 20
if model_dict["is_infer"]:
metrics = model_class.get_infer_results()
else:
metrics = model_class.get_metrics() metrics = model_class.get_metrics()
if metrics: if metrics:
fetch_vars = metrics.values() fetch_vars = metrics.values()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册