提交 421c293f 编写于 作者: X xjqbest

fix

上级 b4253934
......@@ -137,35 +137,9 @@ class SingleTrainer(TranspileTrainer):
return self._get_dataset(dataset_name)
reader = envs.path_adapter("paddlerec.core.utils") + "/dataset_instance.py"
pipe_cmd = "python {} {} {} {} {} {} {} {}".format(
reader, "slot", "slot", self._config_yaml, "fake", \
sparse_slots.replace(" ", "#"), dense_slots.replace(" ", "#"), str(padding))
if type_name == "QueueDataset":
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_batch_size(envs.get_global_env(name + "batch_size"))
dataset.set_pipe_command(pipe_cmd)
train_data_path = envs.get_global_env(name + "data_path")
file_list = [
os.path.join(train_data_path, x)
for x in os.listdir(train_data_path)
]
dataset.set_filelist(file_list)
for model_dict in self._env["executor"]:
if model_dict["dataset_name"] == dataset_name:
model = self._model[model_dict["name"]][3]
inputs = model.get_inputs()
dataset.set_use_var(inputs)
break
else:
pass
return dataset
def init(self, context):
for model_dict in self._env["executor"]:
self._model[model_dict["name"]] = [None] * 4
self._model[model_dict["name"]] = [None] * 5
train_program = fluid.Program()
startup_program = fluid.Program()
scope = fluid.Scope()
......@@ -175,6 +149,7 @@ class SingleTrainer(TranspileTrainer):
opt_strategy = envs.get_global_env("hyper_parameters.optimizer.strategy")
with fluid.program_guard(train_program, startup_program):
with fluid.unique_name.guard():
with fluid.scope_guard(scope):
model_path = model_dict["model"].replace("{workspace}", envs.path_adapter(self._env["workspace"]))
model = envs.lazy_instance_by_fliename(model_path, "Model")(self._env)
model._data_var = model.input_data(dataset_name=model_dict["dataset_name"])
......@@ -188,6 +163,7 @@ class SingleTrainer(TranspileTrainer):
self._model[model_dict["name"]][1] = startup_program
self._model[model_dict["name"]][2] = scope
self._model[model_dict["name"]][3] = model
self._model[model_dict["name"]][4] = train_program.clone()
for dataset in self._env["dataset"]:
if dataset["type"] != "DataLoader":
......@@ -219,7 +195,7 @@ class SingleTrainer(TranspileTrainer):
else:
self._executor_dataset_train(model_dict)
with fluid.scope_guard(self._model[model_dict["name"]][2]):
train_prog = self._model[model_dict["name"]][0]
train_prog = self._model[model_dict["name"]][4]
startup_prog = self._model[model_dict["name"]][1]
with fluid.program_guard(train_prog, startup_prog):
self.save(j)
......@@ -250,13 +226,13 @@ class SingleTrainer(TranspileTrainer):
fetch_info=fetch_alias,
print_period=fetch_period)
def _executor_dataloader_train(self, model_dict):
reader_name = model_dict["dataset_name"]
model_name = model_dict["name"]
model_class = self._model[model_name][3]
self._model[model_name][0] = fluid.compiler.CompiledProgram(
self._model[model_name][0]).with_data_parallel(loss_name=model_class.get_avg_cost().name)
program = self._model[model_name][0].clone()
program = fluid.compiler.CompiledProgram(
program).with_data_parallel(loss_name=model_class.get_avg_cost().name)
fetch_vars = []
fetch_alias = []
fetch_period = 20
......@@ -266,7 +242,8 @@ class SingleTrainer(TranspileTrainer):
fetch_alias = metrics.keys()
metrics_varnames = []
metrics_format = []
metrics_format.append("{}: {{}}".format("epoch"))
fetch_period = 20
#metrics_format.append("{}: {{}}".format("epoch"))
metrics_format.append("{}: {{}}".format("batch"))
for name, var in model_class.get_metrics().items():
metrics_varnames.append(var.name)
......@@ -277,16 +254,15 @@ class SingleTrainer(TranspileTrainer):
reader.start()
batch_id = 0
scope = self._model[model_name][2]
program = self._model[model_name][0]
with fluid.scope_guard(scope):
try:
while True:
metrics_rets = self._exe.run(program=program,
fetch_list=metrics_varnames)
metrics = [epoch, batch_id]
metrics = [batch_id]#[epoch, batch_id]
metrics.extend(metrics_rets)
if batch_id % self.fetch_period == 0 and batch_id != 0:
if batch_id % fetch_period == 0 and batch_id != 0:
print(metrics_format.format(*metrics))
batch_id += 1
except fluid.core.EOFException:
......
......@@ -21,8 +21,8 @@ workspace: "paddlerec.models.rank.dnn"
dataset:
- name: dataset_2
batch_size: 2
type: QueueDataset
#type: DataLoader
#type: QueueDataset
type: DataLoader
data_path: "{workspace}/data/sample_data/train"
sparse_slots: "click 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26"
dense_slots: "dense_var:13"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册