Commit 986c7679 authored by malin10

Merge branch 'yaml1' of https://github.com/xjqbest/PaddleRec into modify_yaml

@@ -26,6 +26,7 @@ trainers = {}
 def trainer_registry():
     trainers["SingleTrainer"] = os.path.join(trainer_abs, "single_trainer.py")
+    trainers["SingleInfer"] = os.path.join(trainer_abs, "single_infer.py")
     trainers["ClusterTrainer"] = os.path.join(trainer_abs,
                                               "cluster_trainer.py")
     trainers["CtrCodingTrainer"] = os.path.join(trainer_abs,
......
@@ -47,6 +47,7 @@ class SingleTrainer(TranspileTrainer):
         self._dataset = {}
         envs.set_global_envs(self._config)
         envs.update_workspace()
+        self._runner_name = envs.get_global_env("mode")

     def processor_register(self):
         self.regist_context_processor('uninit', self.instance)
@@ -90,12 +91,9 @@ class SingleTrainer(TranspileTrainer):
                 for x in os.listdir(train_data_path)
             ]
         dataset.set_filelist(file_list)
-        for model_dict in self._env["executor"]:
+        for model_dict in self._env["phase"]:
             if model_dict["dataset_name"] == dataset_name:
                 model = self._model[model_dict["name"]][3]
-                if model_dict["is_infer"]:
-                    inputs = model._infer_data_var
-                else:
-                    inputs = model._data_var
+                inputs = model._data_var
                 dataset.set_use_var(inputs)
                 break
@@ -110,11 +108,14 @@ class SingleTrainer(TranspileTrainer):
         reader_class = envs.get_global_env(name + "data_converter")
         abs_dir = os.path.dirname(os.path.abspath(__file__))
         if sparse_slots is None and dense_slots is None:
-            reader = dataloader_instance.dataloader_by_name(reader_class, dataset_name, self._config_yaml)
-            reader_class = envs.lazy_instance_by_fliename(reader_class, "TrainReader")
+            reader = dataloader_instance.dataloader_by_name(
+                reader_class, dataset_name, self._config_yaml)
+            reader_class = envs.lazy_instance_by_fliename(reader_class,
+                                                          "TrainReader")
             reader_ins = reader_class(self._config_yaml)
         else:
-            reader = dataloader_instance.slotdataloader_by_name("", dataset_name, self._config_yaml)
+            reader = dataloader_instance.slotdataloader_by_name(
+                "", dataset_name, self._config_yaml)
             reader_ins = SlotReader(self._config_yaml)
         if hasattr(reader_ins, 'generate_batch_from_trainfiles'):
             dataloader.set_sample_list_generator(reader)
@@ -122,7 +123,6 @@ class SingleTrainer(TranspileTrainer):
             dataloader.set_sample_generator(reader, batch_size)
         return dataloader
-
     def _create_dataset(self, dataset_name):
         name = "dataset." + dataset_name + "."
         sparse_slots = envs.get_global_env(name + "sparse_slots")
@@ -131,7 +131,8 @@ class SingleTrainer(TranspileTrainer):
         batch_size = envs.get_global_env(name + "batch_size")
         type_name = envs.get_global_env(name + "type")
         if envs.get_platform() != "LINUX":
-            print("platform ", envs.get_platform(), " change reader to DataLoader")
+            print("platform ", envs.get_platform(),
+                  " change reader to DataLoader")
             type_name = "DataLoader"
             padding = 0
@@ -140,9 +141,8 @@ class SingleTrainer(TranspileTrainer):
         else:
             return self._get_dataset(dataset_name)

     def init(self, context):
-        for model_dict in self._env["executor"]:
+        for model_dict in self._env["phase"]:
             self._model[model_dict["name"]] = [None] * 5
             train_program = fluid.Program()
             startup_program = fluid.Program()
@@ -161,33 +161,23 @@ class SingleTrainer(TranspileTrainer):
                 envs.path_adapter(self._env["workspace"]))
             model = envs.lazy_instance_by_fliename(
                 model_path, "Model")(self._env)
-            is_infer = model_dict.get("is_infer", False)
-            if is_infer:
-                model._infer_data_var = model.input_data(
-                    dataset_name=model_dict["dataset_name"])
-            else:
-                model._data_var = model.input_data(
-                    dataset_name=model_dict["dataset_name"])
+            model._data_var = model.input_data(
+                dataset_name=model_dict["dataset_name"])
             if envs.get_global_env("dataset." + dataset_name +
                                    ".type") == "DataLoader":
-                model._init_dataloader(is_infer=is_infer)
+                model._init_dataloader(is_infer=False)
                 self._get_dataloader(dataset_name,
                                      model._data_loader)
-            if is_infer:
-                model.net(model._infer_data_var, True)
-            else:
-                model.net(model._data_var, False)
+            model.net(model._data_var, False)
             optimizer = model._build_optimizer(opt_name, opt_lr,
                                                opt_strategy)
             optimizer.minimize(model._cost)
-            model_dict["is_infer"] = is_infer
             self._model[model_dict["name"]][0] = train_program
             self._model[model_dict["name"]][1] = startup_program
             self._model[model_dict["name"]][2] = scope
             self._model[model_dict["name"]][3] = model
             self._model[model_dict["name"]][4] = train_program.clone()

         for dataset in self._env["dataset"]:
             if dataset["type"] != "DataLoader":
                 self._dataset[dataset["name"]] = self._create_dataset(dataset[
@@ -196,7 +186,7 @@ class SingleTrainer(TranspileTrainer):
         context['status'] = 'startup_pass'

     def startup(self, context):
-        for model_dict in self._env["executor"]:
+        for model_dict in self._env["phase"]:
             with fluid.scope_guard(self._model[model_dict["name"]][2]):
                 self._exe.run(self._model[model_dict["name"]][1])
         context['status'] = 'train_pass'
@@ -204,13 +194,13 @@ class SingleTrainer(TranspileTrainer):
     def executor_train(self, context):
         epochs = int(self._env["epochs"])
         for j in range(epochs):
-            for model_dict in self._env["executor"]:
+            for model_dict in self._env["phase"]:
                 if j == 0:
                     with fluid.scope_guard(self._model[model_dict["name"]][2]):
                         train_prog = self._model[model_dict["name"]][0]
                         startup_prog = self._model[model_dict["name"]][1]
                         with fluid.program_guard(train_prog, startup_prog):
-                            self.load(j)
+                            self.load()
                 reader_name = model_dict["dataset_name"]
                 name = "dataset." + reader_name + "."
                 begin_time = time.time()
@@ -235,9 +225,6 @@ class SingleTrainer(TranspileTrainer):
         fetch_vars = []
         fetch_alias = []
         fetch_period = 20
-        if model_dict["is_infer"]:
-            metrics = model_class.get_infer_results()
-        else:
-            metrics = model_class.get_metrics()
+        metrics = model_class.get_metrics()
         if metrics:
             fetch_vars = metrics.values()
@@ -246,14 +233,6 @@ class SingleTrainer(TranspileTrainer):
         program = self._model[model_name][0]
         reader = self._dataset[reader_name]
         with fluid.scope_guard(scope):
-            if model_dict["is_infer"]:
-                self._exe.infer_from_dataset(
-                    program=program,
-                    dataset=reader,
-                    fetch_list=fetch_vars,
-                    fetch_info=fetch_alias,
-                    print_period=fetch_period)
-            else:
-                self._exe.train_from_dataset(
-                    program=program,
-                    dataset=reader,
+            self._exe.train_from_dataset(
+                program=program,
+                dataset=reader,
@@ -266,15 +245,11 @@ class SingleTrainer(TranspileTrainer):
         model_name = model_dict["name"]
         model_class = self._model[model_name][3]
         program = self._model[model_name][0].clone()
-        if not model_dict["is_infer"]:
-            program = fluid.compiler.CompiledProgram(
-                program).with_data_parallel(loss_name=model_class.get_avg_cost().name)
+        program = fluid.compiler.CompiledProgram(program).with_data_parallel(
+            loss_name=model_class.get_avg_cost().name)
         fetch_vars = []
         fetch_alias = []
         fetch_period = 20
-        if model_dict["is_infer"]:
-            metrics = model_class.get_infer_results()
-        else:
-            metrics = model_class.get_metrics()
+        metrics = model_class.get_metrics()
         if metrics:
             fetch_vars = metrics.values()
@@ -283,7 +258,7 @@ class SingleTrainer(TranspileTrainer):
         metrics_format = []
         fetch_period = 20
         metrics_format.append("{}: {{}}".format("batch"))
-        for name, var in model_class.get_metrics().items():
+        for name, var in metrics.items():
             metrics_varnames.append(var.name)
             metrics_format.append("{}: {{}}".format(name))
         metrics_format = ", ".join(metrics_format)
@@ -310,9 +285,11 @@ class SingleTrainer(TranspileTrainer):
         context['is_exit'] = True

     def load(self, is_fleet=False):
-        dirname = envs.get_global_env("epoch.init_model_path", None)
+        dirname = envs.get_global_env(
+            "runner." + self._runner_name + ".init_model_path", None)
         if dirname is None:
             return
+        print("going to load ", dirname)
         if is_fleet:
             fleet.load_persistables(self._exe, dirname)
         else:
@@ -328,19 +305,22 @@ class SingleTrainer(TranspileTrainer):
             return epoch_id % epoch_interval == 0

         def save_inference_model():
+            name = "runner." + self._runner_name + "."
             save_interval = int(
-                envs.get_global_env("epoch.save_inference_interval", -1))
+                envs.get_global_env(name + "save_inference_interval", -1))
             if not need_save(epoch_id, save_interval, False):
                 return
-            feed_varnames = envs.get_global_env("epoch.save_inference_feed_varnames", None)
-            fetch_varnames = envs.get_global_env("epoch.save_inference_fetch_varnames", None)
-            if feed_varnames is None or fetch_varnames is None:
+            feed_varnames = envs.get_global_env(
+                name + "save_inference_feed_varnames", None)
+            fetch_varnames = envs.get_global_env(
+                name + "save_inference_fetch_varnames", None)
+            if feed_varnames is None or fetch_varnames is None or feed_varnames == "":
                 return
             fetch_vars = [
                 fluid.default_main_program().global_block().vars[varname]
                 for varname in fetch_varnames
             ]
-            dirname = envs.get_global_env("epoch.save_inference_path", None)
+            dirname = envs.get_global_env(name + "save_inference_path", None)
             assert dirname is not None
             dirname = os.path.join(dirname, str(epoch_id))
@@ -353,12 +333,14 @@ class SingleTrainer(TranspileTrainer):
                 fetch_vars, self._exe)

         def save_persistables():
+            name = "runner." + self._runner_name + "."
             save_interval = int(
-                envs.get_global_env("epoch.save_checkpoint_interval", -1))
+                envs.get_global_env(name + "save_checkpoint_interval", -1))
             if not need_save(epoch_id, save_interval, False):
                 return
-            dirname = envs.get_global_env("epoch.save_checkpoint_path", None)
-            assert dirname is not None
+            dirname = envs.get_global_env(name + "save_checkpoint_path", None)
+            if dirname is None or dirname == "":
+                return
             dirname = os.path.join(dirname, str(epoch_id))
             if is_fleet:
                 fleet.save_persistables(self._exe, dirname)
......
@@ -68,7 +68,8 @@ def set_global_envs(envs):
             nests = copy.deepcopy(namespace_nests)
             nests.append(k)
             fatten_env_namespace(nests, v)
-        elif (k == "dataset" or k == "executor") and isinstance(v, list):
+        elif (k == "dataset" or k == "phase" or
+              k == "runner") and isinstance(v, list):
             for i in v:
                 if i.get("name") is None:
                     raise ValueError("name must be in dataset list ", v)
......
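
To make the new flattening rule concrete, here is a minimal standalone sketch (illustration only, not the PaddleRec source) of how set_global_envs now namespaces the list-typed "dataset", "phase" and "runner" sections by each entry's mandatory "name" field:

    config = {
        "mode": "runner1",
        "runner": [
            {"name": "runner1", "class": "single_train",
             "save_checkpoint_interval": 2},
            {"name": "runner2", "class": "single_infer",
             "init_model_path": "increment/0"},
        ],
    }

    flattened = {}

    def flatten(nests, node):
        # dicts recurse by key; the special list sections recurse by entry name
        for k, v in node.items():
            if isinstance(v, dict):
                flatten(nests + [k], v)
            elif k in ("dataset", "phase", "runner") and isinstance(v, list):
                for item in v:
                    flatten(nests + [k, item["name"]], item)
            else:
                flattened[".".join(nests + [k])] = v

    flatten([], config)
    print(flattened["runner.runner1.save_checkpoint_interval"])  # 2
    print(flattened["runner.runner2.init_model_path"])           # increment/0

This is why every runner-scoped lookup above uses keys of the form runner.<name>.<option>.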
@@ -21,12 +21,18 @@ workspace: "paddlerec.models.rank.dnn"

 # dataset list
 dataset:
-- name: dataset_2 # name, used to distinguish datasets
+- name: dataset_train # name, used to distinguish datasets
   batch_size: 2
   type: DataLoader # or QueueDataset
   data_path: "{workspace}/data/sample_data/train" # data path
   sparse_slots: "click 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26"
   dense_slots: "dense_var:13"
+- name: dataset_infer # name, used to distinguish datasets
+  batch_size: 2
+  type: DataLoader # or QueueDataset
+  data_path: "{workspace}/data/sample_data/test" # data path
+  sparse_slots: "click 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26"
+  dense_slots: "dense_var:13"

 # hyper parameters
 hyper_parameters:
@@ -42,10 +48,11 @@ hyper_parameters:
   dense_input_dim: 13
   fc_sizes: [512, 256, 128, 32]

-# executor config
-epoch:
-  name:
-  trainer_class: single
+mode: runner1
+# runner config
+runner:
+- name: runner1
+  class: single_train
   save_checkpoint_interval: 2 # checkpoint save interval
   save_inference_interval: 4 # inference-model save interval
   save_checkpoint_path: "increment" # checkpoint save path
@@ -53,11 +60,17 @@ epoch:
   #save_inference_feed_varnames: [] # inference-model feed vars
   #save_inference_fetch_varnames: [] # inference-model fetch vars
   #init_model_path: "xxxx" # model to load
+- name: runner2
+  class: single_infer
+  init_model_path: "increment/0" # model to load

-# executor: all models to run in each epoch
-executor:
-- name: train
+# executor: all phases to run in each epoch
+phase:
+- name: phase1
   model: "{workspace}/model.py" # model path
-  dataset_name: dataset_2 # name, used to distinguish phases
+  dataset_name: dataset_train # name, used to distinguish phases
   thread_num: 1 # number of threads
-  is_infer: False # whether this is an infer phase
+# - name: phase2
+#   model: "{workspace}/model.py" # model path
+#   dataset_name: dataset_infer # name, used to distinguish phases
+#   thread_num: 1 # number of threads
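
For orientation (not part of the diff): the new top-level mode key names the active runner, and the trainer reads runner options through the flattened runner.<name>.* keys via envs.get_global_env. A toy lookup against the flattened form of the config above:

    flattened = {
        "mode": "runner1",
        "runner.runner1.class": "single_train",
        "runner.runner1.save_checkpoint_path": "increment",
        "runner.runner2.class": "single_infer",
        "runner.runner2.init_model_path": "increment/0",
    }

    runner_name = flattened["mode"]         # what SingleTrainer caches in __init__
    prefix = "runner." + runner_name + "."
    print(flattened.get(prefix + "class"))            # single_train
    print(flattened.get(prefix + "init_model_path"))  # None: the train runner sets none

Switching mode to runner2 (and uncommenting phase2) reruns the same config as an inference job that warm-starts from increment/0.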
@@ -77,17 +77,21 @@ class Model(ModelBase):
         self.predict = predict

+        cost = fluid.layers.cross_entropy(
+            input=self.predict, label=self.label_input)
+        avg_cost = fluid.layers.reduce_mean(cost)
+        self._cost = avg_cost
+
         auc, batch_auc, _ = fluid.layers.auc(input=self.predict,
                                              label=self.label_input,
                                              num_thresholds=2**12,
                                              slide_steps=20)
+        if is_infer:
+            self._infer_results["AUC"] = auc
+            self._infer_results["BATCH_AUC"] = batch_auc
+            return

         self._metrics["AUC"] = auc
         self._metrics["BATCH_AUC"] = batch_auc
-        cost = fluid.layers.cross_entropy(
-            input=self.predict, label=self.label_input)
-        avg_cost = fluid.layers.reduce_mean(cost)
-        self._cost = avg_cost

     def optimizer(self):
         optimizer = fluid.optimizer.Adam(self.learning_rate, lazy_mode=True)
......
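
One behavioral point in the net() reordering above is easy to miss: the loss is now built before the is_infer early return, and AUC is routed to _infer_results instead of _metrics when inferring. A runnable schematic of that control flow (stand-in arithmetic replaces the real fluid layers):

    class SketchModel:
        """Illustration only; mirrors the branch structure of Model.net."""
        def __init__(self):
            self._metrics, self._infer_results, self._cost = {}, {}, None

        def net(self, inputs, is_infer=False):
            predict = sum(inputs) * 0.1     # stand-in for the forward pass
            self._cost = 1.0 - predict      # loss built unconditionally now
            auc = 0.5                       # stand-in for fluid.layers.auc
            if is_infer:
                self._infer_results["AUC"] = auc   # read by SingleInfer
                return
            self._metrics["AUC"] = auc             # read by SingleTrainer

    m = SketchModel()
    m.net([1, 2], is_infer=True)
    print(m._infer_results, m._metrics)  # {'AUC': 0.5} {}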
@@ -18,7 +18,7 @@ import subprocess
 import argparse
 import tempfile
 import yaml
+import copy

 from paddlerec.core.factory import TrainerFactory
 from paddlerec.core.utils import envs
 from paddlerec.core.utils import util
@@ -27,8 +27,8 @@ engines = {}
 device = ["CPU", "GPU"]
 clusters = ["SINGLE", "LOCAL_CLUSTER", "CLUSTER"]
 engine_choices = [
-    "SINGLE", "LOCAL_CLUSTER", "CLUSTER", "TDM_SINGLE", "TDM_LOCAL_CLUSTER",
-    "TDM_CLUSTER"
+    "SINGLE_TRAIN", "LOCAL_CLUSTER", "CLUSTER", "TDM_SINGLE",
+    "TDM_LOCAL_CLUSTER", "TDM_CLUSTER", "SINGLE_INFER"
 ]
 custom_model = ['TDM']
 model_name = ""
@@ -38,7 +38,8 @@ def engine_registry():
     engines["TRANSPILER"] = {}
     engines["PSLIB"] = {}

-    engines["TRANSPILER"]["SINGLE"] = single_engine
+    engines["TRANSPILER"]["SINGLE_TRAIN"] = single_train_engine
+    engines["TRANSPILER"]["SINGLE_INFER"] = single_infer_engine
     engines["TRANSPILER"]["LOCAL_CLUSTER"] = local_cluster_engine
     engines["TRANSPILER"]["CLUSTER"] = cluster_engine
     engines["PSLIB"]["SINGLE"] = local_mpi_engine
@@ -51,7 +52,6 @@ def get_inters_from_yaml(file, filters):
         _envs = yaml.load(rb.read(), Loader=yaml.FullLoader)
     flattens = envs.flatten_environs(_envs)
     inters = {}
-
     for k, v in flattens.items():
         for f in filters:
@@ -60,15 +60,50 @@ def get_inters_from_yaml(file, filters):
     return inters


+def get_all_inters_from_yaml(file, filters):
+    with open(file, 'r') as rb:
+        _envs = yaml.load(rb.read(), Loader=yaml.FullLoader)
+    all_flattens = {}
+
+    def fatten_env_namespace(namespace_nests, local_envs):
+        for k, v in local_envs.items():
+            if isinstance(v, dict):
+                nests = copy.deepcopy(namespace_nests)
+                nests.append(k)
+                fatten_env_namespace(nests, v)
+            elif (k == "dataset" or k == "phase" or
+                  k == "runner") and isinstance(v, list):
+                for i in v:
+                    if i.get("name") is None:
+                        raise ValueError("name must be in dataset list ", v)
+                    nests = copy.deepcopy(namespace_nests)
+                    nests.append(k)
+                    nests.append(i["name"])
+                    fatten_env_namespace(nests, i)
+            else:
+                global_k = ".".join(namespace_nests + [k])
+                all_flattens[global_k] = v
+
+    fatten_env_namespace([], _envs)
+    ret = {}
+    for k, v in all_flattens.items():
+        for f in filters:
+            if k.startswith(f):
+                ret[k] = v
+    return ret
+
+
 def get_engine(args):
     transpiler = get_transpiler()
-    run_extras = get_inters_from_yaml(args.model, ["train.", "epoch."])
+    with open(args.model, 'r') as rb:
+        envs = yaml.load(rb.read(), Loader=yaml.FullLoader)
+    run_extras = get_all_inters_from_yaml(args.model, ["train.", "runner."])

     engine = run_extras.get("train.engine", None)
     if engine is None:
-        engine = run_extras.get("epoch.trainer_class", None)
+        engine = run_extras.get("runner." + envs["mode"] + ".class", None)
     if engine is None:
-        engine = "single"
+        engine = "single_train"
     engine = engine.upper()

     if engine not in engine_choices:
         raise ValueError("train.engin can not be chosen in {}".format(
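
A hypothetical call (assuming a config.yaml shaped like the one in this commit) shows what get_engine now works with:

    run_extras = get_all_inters_from_yaml("config.yaml", ["train.", "runner."])
    # -> {"runner.runner1.class": "single_train",
    #     "runner.runner1.save_checkpoint_interval": 2,
    #     "runner.runner2.class": "single_infer", ...}

    # get_engine then resolves the engine from the active runner:
    # run_extras.get("runner." + envs["mode"] + ".class") -> "single_train",
    # which upper-cases to "SINGLE_TRAIN" and must appear in engine_choices.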
@@ -120,15 +155,27 @@ def get_trainer_prefix(args):
     return ""

-def single_engine(args):
+def single_train_engine(args):
     trainer = get_trainer_prefix(args) + "SingleTrainer"
     single_envs = {}
     single_envs["train.trainer.trainer"] = trainer
     single_envs["train.trainer.threads"] = "2"
-    single_envs["train.trainer.engine"] = "single"
+    single_envs["train.trainer.engine"] = "single_train"
     single_envs["train.trainer.platform"] = envs.get_platform()
     print("use {} engine to run model: {}".format(trainer, args.model))
+    set_runtime_envs(single_envs, args.model)
+    trainer = TrainerFactory.create(args.model)
+    return trainer
+
+
+def single_infer_engine(args):
+    trainer = get_trainer_prefix(args) + "SingleInfer"
+    single_envs = {}
+    single_envs["train.trainer.trainer"] = trainer
+    single_envs["train.trainer.threads"] = "2"
+    single_envs["train.trainer.engine"] = "single_infer"
+    single_envs["train.trainer.platform"] = envs.get_platform()
+    print("use {} engine to run model: {}".format(trainer, args.model))
     set_runtime_envs(single_envs, args.model)
     trainer = TrainerFactory.create(args.model)
     return trainer
......
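
Putting the registry changes together: runner.<mode>.class selects one of the two new engine functions. A self-contained sketch of the dispatch (function bodies stubbed; names taken from the diff):

    def single_train_engine(args):
        print("would build SingleTrainer from", args)

    def single_infer_engine(args):
        print("would build SingleInfer from", args)

    engines = {"TRANSPILER": {
        "SINGLE_TRAIN": single_train_engine,
        "SINGLE_INFER": single_infer_engine,
    }}

    engine = "single_train".upper()               # from runner.runner1.class
    engines["TRANSPILER"][engine]("config.yaml")  # would build SingleTrainer

In the real flow, the selected engine function sets the train.trainer.* runtime envs and hands off to TrainerFactory.create(args.model), which instantiates SingleTrainer or the new SingleInfer.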