提交 ca1c4695 编写于 作者: X xjqbest

fix

上级 07bd7092
...@@ -35,7 +35,6 @@ class Reader(dg.MultiSlotDataGenerator): ...@@ -35,7 +35,6 @@ class Reader(dg.MultiSlotDataGenerator):
else: else:
raise ValueError("reader config only support yaml") raise ValueError("reader config only support yaml")
@abc.abstractmethod @abc.abstractmethod
def init(self): def init(self):
"""init""" """init"""
...@@ -56,8 +55,6 @@ class SlotReader(dg.MultiSlotDataGenerator): ...@@ -56,8 +55,6 @@ class SlotReader(dg.MultiSlotDataGenerator):
_config = yaml.load(rb.read(), Loader=yaml.FullLoader) _config = yaml.load(rb.read(), Loader=yaml.FullLoader)
else: else:
raise ValueError("reader config only support yaml") raise ValueError("reader config only support yaml")
#envs.set_global_envs(_config)
#envs.update_workspace()
def init(self, sparse_slots, dense_slots, padding=0): def init(self, sparse_slots, dense_slots, padding=0):
from operator import mul from operator import mul
......
...@@ -69,13 +69,14 @@ class SingleTrainer(TranspileTrainer): ...@@ -69,13 +69,14 @@ class SingleTrainer(TranspileTrainer):
reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py') reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py')
if sparse_slots is None and dense_slots is None: if sparse_slots is None and dense_slots is None:
pipe_cmd = "python {} {} {} {}".format(reader, reader_class, "TRAIN", self._config_yaml) pipe_cmd = "python {} {} {} {}".format(reader, reader_class,
"TRAIN", self._config_yaml)
else: else:
if sparse_slots is None: if sparse_slots is None:
sparse_slots = "#" sparse_slots = "#"
if dense_slots is None: if dense_slots is None:
dense_slots = "#" dense_slots = "#"
padding = envs.get_global_env(name +"padding", 0) padding = envs.get_global_env(name + "padding", 0)
pipe_cmd = "python {} {} {} {} {} {} {} {}".format( pipe_cmd = "python {} {} {} {} {} {} {} {}".format(
reader, "slot", "slot", self._config_yaml, "fake", \ reader, "slot", "slot", self._config_yaml, "fake", \
sparse_slots.replace(" ", "#"), dense_slots.replace(" ", "#"), str(padding)) sparse_slots.replace(" ", "#"), dense_slots.replace(" ", "#"), str(padding))
...@@ -145,19 +146,29 @@ class SingleTrainer(TranspileTrainer): ...@@ -145,19 +146,29 @@ class SingleTrainer(TranspileTrainer):
scope = fluid.Scope() scope = fluid.Scope()
dataset_name = model_dict["dataset_name"] dataset_name = model_dict["dataset_name"]
opt_name = envs.get_global_env("hyper_parameters.optimizer.class") opt_name = envs.get_global_env("hyper_parameters.optimizer.class")
opt_lr = envs.get_global_env("hyper_parameters.optimizer.learning_rate") opt_lr = envs.get_global_env(
opt_strategy = envs.get_global_env("hyper_parameters.optimizer.strategy") "hyper_parameters.optimizer.learning_rate")
opt_strategy = envs.get_global_env(
"hyper_parameters.optimizer.strategy")
with fluid.program_guard(train_program, startup_program): with fluid.program_guard(train_program, startup_program):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
model_path = model_dict["model"].replace("{workspace}", envs.path_adapter(self._env["workspace"])) model_path = model_dict["model"].replace(
model = envs.lazy_instance_by_fliename(model_path, "Model")(self._env) "{workspace}",
model._data_var = model.input_data(dataset_name=model_dict["dataset_name"]) envs.path_adapter(self._env["workspace"]))
if envs.get_global_env("dataset." + dataset_name + ".type") == "DataLoader": model = envs.lazy_instance_by_fliename(
model_path, "Model")(self._env)
model._data_var = model.input_data(
dataset_name=model_dict["dataset_name"])
if envs.get_global_env("dataset." + dataset_name +
".type") == "DataLoader":
model._init_dataloader() model._init_dataloader()
self._get_dataloader(dataset_name, model._data_loader) self._get_dataloader(dataset_name,
model.net(model._data_var, is_infer=model_dict.get("is_infer", False)) model._data_loader)
optimizer = model._build_optimizer(opt_name, opt_lr, opt_strategy) model.net(model._data_var,
is_infer=model_dict.get("is_infer", False))
optimizer = model._build_optimizer(opt_name, opt_lr,
opt_strategy)
optimizer.minimize(model._cost) optimizer.minimize(model._cost)
self._model[model_dict["name"]][0] = train_program self._model[model_dict["name"]][0] = train_program
self._model[model_dict["name"]][1] = startup_program self._model[model_dict["name"]][1] = startup_program
...@@ -167,7 +178,8 @@ class SingleTrainer(TranspileTrainer): ...@@ -167,7 +178,8 @@ class SingleTrainer(TranspileTrainer):
for dataset in self._env["dataset"]: for dataset in self._env["dataset"]:
if dataset["type"] != "DataLoader": if dataset["type"] != "DataLoader":
self._dataset[dataset["name"]] = self._create_dataset(dataset["name"]) self._dataset[dataset["name"]] = self._create_dataset(dataset[
"name"])
context['status'] = 'startup_pass' context['status'] = 'startup_pass'
...@@ -289,7 +301,8 @@ class SingleTrainer(TranspileTrainer): ...@@ -289,7 +301,8 @@ class SingleTrainer(TranspileTrainer):
return epoch_id % epoch_interval == 0 return epoch_id % epoch_interval == 0
def save_inference_model(): def save_inference_model():
save_interval = envs.get_global_env("epoch.save_inference_interval", -1) save_interval = int(
envs.get_global_env("epoch.save_inference_interval", -1)
if not need_save(epoch_id, save_interval, False): if not need_save(epoch_id, save_interval, False):
return return
feed_varnames = envs.get_global_env("epoch.save_inference_feed_varnames", None) feed_varnames = envs.get_global_env("epoch.save_inference_feed_varnames", None)
...@@ -313,7 +326,8 @@ class SingleTrainer(TranspileTrainer): ...@@ -313,7 +326,8 @@ class SingleTrainer(TranspileTrainer):
fetch_vars, self._exe) fetch_vars, self._exe)
def save_persistables(): def save_persistables():
save_interval = int(envs.get_global_env("epoch.save_checkpoint_interval", -1)) save_interval = int(
envs.get_global_env("epoch.save_checkpoint_interval", -1))
if not need_save(epoch_id, save_interval, False): if not need_save(epoch_id, save_interval, False):
return return
dirname = envs.get_global_env("epoch.save_checkpoint_path", None) dirname = envs.get_global_env("epoch.save_checkpoint_path", None)
......
...@@ -19,6 +19,7 @@ from paddlerec.core.utils.envs import get_global_env ...@@ -19,6 +19,7 @@ from paddlerec.core.utils.envs import get_global_env
from paddlerec.core.utils.envs import get_runtime_environ from paddlerec.core.utils.envs import get_runtime_environ
from paddlerec.core.reader import SlotReader from paddlerec.core.reader import SlotReader
def dataloader_by_name(readerclass, dataset_name, yaml_file): def dataloader_by_name(readerclass, dataset_name, yaml_file):
reader_class = lazy_instance_by_fliename(readerclass, "TrainReader") reader_class = lazy_instance_by_fliename(readerclass, "TrainReader")
name = "dataset." + dataset_name + "." name = "dataset." + dataset_name + "."
...@@ -30,9 +31,9 @@ def dataloader_by_name(readerclass, dataset_name, yaml_file): ...@@ -30,9 +31,9 @@ def dataloader_by_name(readerclass, dataset_name, yaml_file):
data_path = os.path.join(package_base, data_path.split("::")[1]) data_path = os.path.join(package_base, data_path.split("::")[1])
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
reader = reader_class(yaml_file) reader = reader_class(yaml_file)
reader.init() reader.init()
def gen_reader(): def gen_reader():
for file in files: for file in files:
with open(file, 'r') as f: with open(file, 'r') as f:
...@@ -67,7 +68,6 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file): ...@@ -67,7 +68,6 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file):
data_path = os.path.join(package_base, data_path.split("::")[1]) data_path = os.path.join(package_base, data_path.split("::")[1])
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
sparse = get_global_env(name + "sparse_slots") sparse = get_global_env(name + "sparse_slots")
dense = get_global_env(name + "dense_slots") dense = get_global_env(name + "dense_slots")
padding = get_global_env(name + "padding", 0) padding = get_global_env(name + "padding", 0)
...@@ -96,6 +96,7 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file): ...@@ -96,6 +96,7 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file):
return gen_batch_reader() return gen_batch_reader()
return gen_reader return gen_reader
def dataloader(readerclass, train, yaml_file): def dataloader(readerclass, train, yaml_file):
if train == "TRAIN": if train == "TRAIN":
reader_name = "TrainReader" reader_name = "TrainReader"
......
...@@ -20,6 +20,7 @@ import sys ...@@ -20,6 +20,7 @@ import sys
global_envs = {} global_envs = {}
def flatten_environs(envs, separator="."): def flatten_environs(envs, separator="."):
flatten_dict = {} flatten_dict = {}
assert isinstance(envs, dict) assert isinstance(envs, dict)
...@@ -81,6 +82,7 @@ def set_global_envs(envs): ...@@ -81,6 +82,7 @@ def set_global_envs(envs):
fatten_env_namespace([], envs) fatten_env_namespace([], envs)
def get_global_env(env_name, default_value=None, namespace=None): def get_global_env(env_name, default_value=None, namespace=None):
""" """
get os environment value get os environment value
......
...@@ -27,9 +27,12 @@ class Model(ModelBase): ...@@ -27,9 +27,12 @@ class Model(ModelBase):
def _init_hyper_parameters(self): def _init_hyper_parameters(self):
self.is_distributed = True if envs.get_trainer( self.is_distributed = True if envs.get_trainer(
) == "CtrTrainer" else False ) == "CtrTrainer" else False
self.sparse_feature_number = envs.get_global_env("hyper_parameters.sparse_feature_number") self.sparse_feature_number = envs.get_global_env(
self.sparse_feature_dim = envs.get_global_env("hyper_parameters.sparse_feature_dim") "hyper_parameters.sparse_feature_number")
self.learning_rate = envs.get_global_env("hyper_parameters.learning_rate") self.sparse_feature_dim = envs.get_global_env(
"hyper_parameters.sparse_feature_dim")
self.learning_rate = envs.get_global_env(
"hyper_parameters.learning_rate")
def net(self, input, is_infer=False): def net(self, input, is_infer=False):
self.sparse_inputs = self._sparse_data_var[1:] self.sparse_inputs = self._sparse_data_var[1:]
......
...@@ -69,9 +69,7 @@ def get_engine(args): ...@@ -69,9 +69,7 @@ def get_engine(args):
engine = run_extras.get("epoch.trainer_class", None) engine = run_extras.get("epoch.trainer_class", None)
if engine is None: if engine is None:
engine = "single" engine = "single"
engine = engine.upper() engine = engine.upper()
if engine not in engine_choices: if engine not in engine_choices:
raise ValueError("train.engin can not be chosen in {}".format( raise ValueError("train.engin can not be chosen in {}".format(
engine_choices)) engine_choices))
...@@ -135,6 +133,7 @@ def single_engine(args): ...@@ -135,6 +133,7 @@ def single_engine(args):
trainer = TrainerFactory.create(args.model) trainer = TrainerFactory.create(args.model)
return trainer return trainer
def cluster_engine(args): def cluster_engine(args):
def update_workspace(cluster_envs): def update_workspace(cluster_envs):
workspace = cluster_envs.get("engine_workspace", None) workspace = cluster_envs.get("engine_workspace", None)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册