提交 95f2364b 编写于 作者: X xjqbest

fix

上级 fc94d505
......@@ -70,7 +70,6 @@ class SingleTrainer(TranspileTrainer):
if sparse_slots is None and dense_slots is None:
pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "fake", self._config_yaml)
else:
if sparse_slots is None:
sparse_slots = "#"
......@@ -98,7 +97,7 @@ class SingleTrainer(TranspileTrainer):
break
return dataset
def _get_dataloader(self, dataset_name):
def _get_dataloader(self, dataset_name, dataloader):
name = "dataset." + dataset_name + "."
sparse_slots = envs.get_global_env(name + "sparse_slots")
dense_slots = envs.get_global_env(name + "dense_slots")
......@@ -106,9 +105,7 @@ class SingleTrainer(TranspileTrainer):
batch_size = envs.get_global_env(name + "batch_size")
reader_class = envs.get_global_env("data_convertor")
abs_dir = os.path.dirname(os.path.abspath(__file__))
#reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py')
if sparse_slots is None and dense_slots is None:
#reader_class = envs.get_global_env("class")
reader = dataloader_instance.dataloader_by_name(reader_class, dataset_name, self._config_yaml)
reader_class = envs.lazy_instance_by_fliename(reader_class, "TrainReader")
reader_ins = reader_class(self._config_yaml)
......@@ -181,10 +178,10 @@ class SingleTrainer(TranspileTrainer):
model_path = model_dict["model"].replace("{workspace}", envs.path_adapter(self._env["workspace"]))
model = envs.lazy_instance_by_fliename(model_path, "Model")(self._env)
model._data_var = model.input_data(dataset_name=model_dict["dataset_name"])
#model._init_slots(name=model_dict["name"])
if envs.get_global_env("dataset." + dataset_name + ".type") == "DataLoader":
model._init_dataloader()
model.net(model._data_var)
self._get_dataloader(dataset_name, model._data_loader)
model.net(model._data_var, is_infer=model_dict["is_infer"])
optimizer = model._build_optimizer(opt_name, opt_lr, opt_strategy)
optimizer.minimize(model._cost)
self._model[model_dict["name"]][0] = train_program
......@@ -215,6 +212,8 @@ class SingleTrainer(TranspileTrainer):
self._executor_dataloader_train(model_dict)
else:
self._executor_dataset_train(model_dict)
with fluid.scope_guard(self._model[model_name][2]):
self.save(self, j)
end_time = time.time()
seconds = end_time - begin_time
print("epoch {} done, time elasped: {}".format(j, seconds))
......@@ -270,7 +269,6 @@ class SingleTrainer(TranspileTrainer):
batch_id = 0
scope = self._model[model_name][2]
program = self._model[model_name][0]
#print(metrics_varnames)
with fluid.scope_guard(scope):
try:
while True:
......@@ -287,3 +285,53 @@ class SingleTrainer(TranspileTrainer):
def terminal(self, context):
context['is_exit'] = True
def save(self, epoch_id, is_fleet=False):
def need_save(epoch_id, epoch_interval, is_last=False):
if is_last:
return True
if epoch_id == -1:
return False
return epoch_id % epoch_interval == 0
def save_inference_model():
save_interval = envs.get_global_env("epoch.save_inference_interval", -1)
if not need_save(epoch_id, save_interval, False):
return
feed_varnames = envs.get_global_env("epoch.save_inference_feed_varnames", None)
fetch_varnames = envs.get_global_env("epoch.save_inference_fetch_varnames", None)
if feed_varnames is None or fetch_varnames is None:
return
fetch_vars = [
fluid.default_main_program().global_block().vars[varname]
for varname in fetch_varnames
]
dirname = envs.get_global_env("epoch.save_inference_path", None)
assert dirname is not None
dirname = os.path.join(dirname, str(epoch_id))
if is_fleet:
fleet.save_inference_model(self._exe, dirname, feed_varnames,
fetch_vars)
else:
fluid.io.save_inference_model(dirname, feed_varnames,
fetch_vars, self._exe)
self.inference_models.append((epoch_id, dirname))
def save_persistables():
save_interval = envs.get_global_env("epoch.save_checkpoint_interval", -1)
if not need_save(epoch_id, save_interval, False):
return
dirname = envs.get_global_env("epoch.save_checkpoint_path", None)
assert dirname is not None
dirname = os.path.join(dirname, str(epoch_id))
if is_fleet:
fleet.save_persistables(self._exe, dirname)
else:
fluid.io.save_persistables(self._exe, dirname)
self.increment_models.append((epoch_id, dirname))
save_persistables()
save_inference_model()
......@@ -23,10 +23,6 @@ def dataloader_by_name(readerclass, dataset_name, yaml_file):
reader_class = lazy_instance_by_fliename(readerclass, "TrainReader")
name = "dataset." + dataset_name + "."
data_path = get_global_env(name + "data_path")
#else:
# reader_name = "SlotReader"
# namespace = "evaluate.reader"
# data_path = get_global_env("test_data_path", None, namespace)
if data_path.startswith("paddlerec::"):
package_base = get_runtime_environ("PACKAGE_BASE")
......
......@@ -51,3 +51,4 @@ executor:
model: "{workspace}/model.py"
dataset_name: dataset_2
thread_num: 1
is_infer: False
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册