提交 a09255fb 编写于 作者: X xjqbest

fix

上级 37a77dcd
...@@ -59,11 +59,17 @@ class Model(object): ...@@ -59,11 +59,17 @@ class Model(object):
dataset = i dataset = i
break break
name = "dataset." + dataset["name"] + "." name = "dataset." + dataset["name"] + "."
sparse_slots = envs.get_global_env(name + "sparse_slots") sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip()
dense_slots = envs.get_global_env(name + "dense_slots") dense_slots = envs.get_global_env(name + "dense_slots", "").strip()
if sparse_slots is not None or dense_slots is not None: if sparse_slots != "" or dense_slots != "":
sparse_slots = sparse_slots.strip().split(" ") if sparse_slots == "":
dense_slots = dense_slots.strip().split(" ") sparse_slots = []
else:
sparse_slots = sparse_slots.strip().split(" ")
if dense_slots == "":
dense_slots = []
else:
dense_slots = dense_slots.strip().split(" ")
dense_slots_shape = [[ dense_slots_shape = [[
int(j) for j in i.split(":")[1].strip("[]").split(",") int(j) for j in i.split(":")[1].strip("[]").split(",")
] for i in dense_slots] ] for i in dense_slots]
...@@ -151,11 +157,17 @@ class Model(object): ...@@ -151,11 +157,17 @@ class Model(object):
def input_data(self, is_infer=False, **kwargs): def input_data(self, is_infer=False, **kwargs):
name = "dataset." + kwargs.get("dataset_name") + "." name = "dataset." + kwargs.get("dataset_name") + "."
sparse_slots = envs.get_global_env(name + "sparse_slots") sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip()
dense_slots = envs.get_global_env(name + "dense_slots") dense_slots = envs.get_global_env(name + "dense_slots", "").strip()
if sparse_slots is not None or dense_slots is not None: if sparse_slots != "" or dense_slots != "":
sparse_slots = sparse_slots.strip().split(" ") if sparse_slots == "":
dense_slots = dense_slots.strip().split(" ") sparse_slots = []
else:
sparse_slots = sparse_slots.strip().split(" ")
if dense_slots == "":
dense_slots = []
else:
dense_slots = dense_slots.strip().split(" ")
dense_slots_shape = [[ dense_slots_shape = [[
int(j) for j in i.split(":")[1].strip("[]").split(",") int(j) for j in i.split(":")[1].strip("[]").split(",")
] for i in dense_slots] ] for i in dense_slots]
......
...@@ -67,15 +67,14 @@ class SingleInfer(TranspileTrainer): ...@@ -67,15 +67,14 @@ class SingleInfer(TranspileTrainer):
def _get_dataset(self, dataset_name): def _get_dataset(self, dataset_name):
name = "dataset." + dataset_name + "." name = "dataset." + dataset_name + "."
sparse_slots = envs.get_global_env(name + "sparse_slots")
dense_slots = envs.get_global_env(name + "dense_slots")
thread_num = envs.get_global_env(name + "thread_num") thread_num = envs.get_global_env(name + "thread_num")
batch_size = envs.get_global_env(name + "batch_size") batch_size = envs.get_global_env(name + "batch_size")
reader_class = envs.get_global_env(name + "data_converter") reader_class = envs.get_global_env(name + "data_converter")
abs_dir = os.path.dirname(os.path.abspath(__file__)) abs_dir = os.path.dirname(os.path.abspath(__file__))
reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py') reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py')
sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip()
if sparse_slots is None and dense_slots is None: dense_slots = envs.get_global_env(name + "dense_slots", "").strip()
if sparse_slots == "" and dense_slots == "":
pipe_cmd = "python {} {} {} {}".format(reader, reader_class, pipe_cmd = "python {} {} {} {}".format(reader, reader_class,
"TRAIN", self._config_yaml) "TRAIN", self._config_yaml)
else: else:
...@@ -107,13 +106,13 @@ class SingleInfer(TranspileTrainer): ...@@ -107,13 +106,13 @@ class SingleInfer(TranspileTrainer):
def _get_dataloader(self, dataset_name, dataloader): def _get_dataloader(self, dataset_name, dataloader):
name = "dataset." + dataset_name + "." name = "dataset." + dataset_name + "."
sparse_slots = envs.get_global_env(name + "sparse_slots")
dense_slots = envs.get_global_env(name + "dense_slots")
thread_num = envs.get_global_env(name + "thread_num") thread_num = envs.get_global_env(name + "thread_num")
batch_size = envs.get_global_env(name + "batch_size") batch_size = envs.get_global_env(name + "batch_size")
reader_class = envs.get_global_env(name + "data_converter") reader_class = envs.get_global_env(name + "data_converter")
abs_dir = os.path.dirname(os.path.abspath(__file__)) abs_dir = os.path.dirname(os.path.abspath(__file__))
if sparse_slots is None and dense_slots is None: sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip()
dense_slots = envs.get_global_env(name + "dense_slots", "").strip()
if sparse_slots == "" and dense_slots == "":
reader = dataloader_instance.dataloader_by_name( reader = dataloader_instance.dataloader_by_name(
reader_class, dataset_name, self._config_yaml) reader_class, dataset_name, self._config_yaml)
reader_class = envs.lazy_instance_by_fliename(reader_class, reader_class = envs.lazy_instance_by_fliename(reader_class,
...@@ -228,7 +227,9 @@ class SingleInfer(TranspileTrainer): ...@@ -228,7 +227,9 @@ class SingleInfer(TranspileTrainer):
model_class = self._model[model_name][3] model_class = self._model[model_name][3]
fetch_vars = [] fetch_vars = []
fetch_alias = [] fetch_alias = []
fetch_period = 20 fetch_period = int(
envs.get_global_env("runner." + self._runner_name +
".fetch_period", 20))
metrics = model_class.get_infer_results() metrics = model_class.get_infer_results()
if metrics: if metrics:
fetch_vars = metrics.values() fetch_vars = metrics.values()
...@@ -251,14 +252,15 @@ class SingleInfer(TranspileTrainer): ...@@ -251,14 +252,15 @@ class SingleInfer(TranspileTrainer):
program = self._model[model_name][0].clone() program = self._model[model_name][0].clone()
fetch_vars = [] fetch_vars = []
fetch_alias = [] fetch_alias = []
fetch_period = 20
metrics = model_class.get_infer_results() metrics = model_class.get_infer_results()
if metrics: if metrics:
fetch_vars = metrics.values() fetch_vars = metrics.values()
fetch_alias = metrics.keys() fetch_alias = metrics.keys()
metrics_varnames = [] metrics_varnames = []
metrics_format = [] metrics_format = []
fetch_period = 20 fetch_period = int(
envs.get_global_env("runner." + self._runner_name +
".fetch_period", 20))
metrics_format.append("{}: {{}}".format("batch")) metrics_format.append("{}: {{}}".format("batch"))
for name, var in metrics.items(): for name, var in metrics.items():
metrics_varnames.append(var.name) metrics_varnames.append(var.name)
......
...@@ -61,21 +61,20 @@ class SingleTrainer(TranspileTrainer): ...@@ -61,21 +61,20 @@ class SingleTrainer(TranspileTrainer):
def _get_dataset(self, dataset_name): def _get_dataset(self, dataset_name):
name = "dataset." + dataset_name + "." name = "dataset." + dataset_name + "."
sparse_slots = envs.get_global_env(name + "sparse_slots")
dense_slots = envs.get_global_env(name + "dense_slots")
thread_num = envs.get_global_env(name + "thread_num") thread_num = envs.get_global_env(name + "thread_num")
batch_size = envs.get_global_env(name + "batch_size") batch_size = envs.get_global_env(name + "batch_size")
reader_class = envs.get_global_env(name + "data_converter") reader_class = envs.get_global_env(name + "data_converter")
abs_dir = os.path.dirname(os.path.abspath(__file__)) abs_dir = os.path.dirname(os.path.abspath(__file__))
reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py') reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py')
sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip()
if sparse_slots is None and dense_slots is None: dense_slots = envs.get_global_env(name + "dense_slots", "").strip()
if sparse_slots != "" and dense_slots != "":
pipe_cmd = "python {} {} {} {}".format(reader, reader_class, pipe_cmd = "python {} {} {} {}".format(reader, reader_class,
"TRAIN", self._config_yaml) "TRAIN", self._config_yaml)
else: else:
if sparse_slots is None: if sparse_slots == "":
sparse_slots = "#" sparse_slots = "#"
if dense_slots is None: if dense_slots == "":
dense_slots = "#" dense_slots = "#"
padding = envs.get_global_env(name + "padding", 0) padding = envs.get_global_env(name + "padding", 0)
pipe_cmd = "python {} {} {} {} {} {} {} {}".format( pipe_cmd = "python {} {} {} {} {} {} {} {}".format(
...@@ -101,13 +100,13 @@ class SingleTrainer(TranspileTrainer): ...@@ -101,13 +100,13 @@ class SingleTrainer(TranspileTrainer):
def _get_dataloader(self, dataset_name, dataloader): def _get_dataloader(self, dataset_name, dataloader):
name = "dataset." + dataset_name + "." name = "dataset." + dataset_name + "."
sparse_slots = envs.get_global_env(name + "sparse_slots") sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip()
dense_slots = envs.get_global_env(name + "dense_slots") dense_slots = envs.get_global_env(name + "dense_slots", "").strip()
thread_num = envs.get_global_env(name + "thread_num") thread_num = envs.get_global_env(name + "thread_num")
batch_size = envs.get_global_env(name + "batch_size") batch_size = envs.get_global_env(name + "batch_size")
reader_class = envs.get_global_env(name + "data_converter") reader_class = envs.get_global_env(name + "data_converter")
abs_dir = os.path.dirname(os.path.abspath(__file__)) abs_dir = os.path.dirname(os.path.abspath(__file__))
if sparse_slots is None and dense_slots is None: if sparse_slots == "" and dense_slots == "":
reader = dataloader_instance.dataloader_by_name( reader = dataloader_instance.dataloader_by_name(
reader_class, dataset_name, self._config_yaml) reader_class, dataset_name, self._config_yaml)
reader_class = envs.lazy_instance_by_fliename(reader_class, reader_class = envs.lazy_instance_by_fliename(reader_class,
...@@ -125,8 +124,8 @@ class SingleTrainer(TranspileTrainer): ...@@ -125,8 +124,8 @@ class SingleTrainer(TranspileTrainer):
def _create_dataset(self, dataset_name): def _create_dataset(self, dataset_name):
name = "dataset." + dataset_name + "." name = "dataset." + dataset_name + "."
sparse_slots = envs.get_global_env(name + "sparse_slots") sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip()
dense_slots = envs.get_global_env(name + "dense_slots") dense_slots = envs.get_global_env(name + "dense_slots", "").strip()
thread_num = envs.get_global_env(name + "thread_num") thread_num = envs.get_global_env(name + "thread_num")
batch_size = envs.get_global_env(name + "batch_size") batch_size = envs.get_global_env(name + "batch_size")
type_name = envs.get_global_env(name + "type") type_name = envs.get_global_env(name + "type")
...@@ -225,7 +224,9 @@ class SingleTrainer(TranspileTrainer): ...@@ -225,7 +224,9 @@ class SingleTrainer(TranspileTrainer):
model_class = self._model[model_name][3] model_class = self._model[model_name][3]
fetch_vars = [] fetch_vars = []
fetch_alias = [] fetch_alias = []
fetch_period = 20 fetch_period = int(
envs.get_global_env("runner." + self._runner_name +
".fetch_period", 20))
metrics = model_class.get_metrics() metrics = model_class.get_metrics()
if metrics: if metrics:
fetch_vars = metrics.values() fetch_vars = metrics.values()
...@@ -250,14 +251,15 @@ class SingleTrainer(TranspileTrainer): ...@@ -250,14 +251,15 @@ class SingleTrainer(TranspileTrainer):
loss_name=model_class.get_avg_cost().name) loss_name=model_class.get_avg_cost().name)
fetch_vars = [] fetch_vars = []
fetch_alias = [] fetch_alias = []
fetch_period = 20 fetch_period = int(
envs.get_global_env("runner." + self._runner_name +
".fetch_period", 20))
metrics = model_class.get_metrics() metrics = model_class.get_metrics()
if metrics: if metrics:
fetch_vars = metrics.values() fetch_vars = metrics.values()
fetch_alias = metrics.keys() fetch_alias = metrics.keys()
metrics_varnames = [] metrics_varnames = []
metrics_format = [] metrics_format = []
fetch_period = 20
metrics_format.append("{}: {{}}".format("batch")) metrics_format.append("{}: {{}}".format("batch"))
for name, var in metrics.items(): for name, var in metrics.items():
metrics_varnames.append(var.name) metrics_varnames.append(var.name)
...@@ -312,10 +314,11 @@ class SingleTrainer(TranspileTrainer): ...@@ -312,10 +314,11 @@ class SingleTrainer(TranspileTrainer):
if not need_save(epoch_id, save_interval, False): if not need_save(epoch_id, save_interval, False):
return return
feed_varnames = envs.get_global_env( feed_varnames = envs.get_global_env(
name + "save_inference_feed_varnames", None) name + "save_inference_feed_varnames", [])
fetch_varnames = envs.get_global_env( fetch_varnames = envs.get_global_env(
name + "save_inference_fetch_varnames", None) name + "save_inference_fetch_varnames", [])
if feed_varnames is None or fetch_varnames is None or feed_varnames == "": if feed_varnames is None or fetch_varnames is None or feed_varnames == "" or fetch_varnames == "" or \
len(feed_varnames) == 0 or len(fetch_varnames) == 0:
return return
fetch_vars = [ fetch_vars = [
fluid.default_main_program().global_block().vars[varname] fluid.default_main_program().global_block().vars[varname]
......
...@@ -68,8 +68,12 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file): ...@@ -68,8 +68,12 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file):
data_path = os.path.join(package_base, data_path.split("::")[1]) data_path = os.path.join(package_base, data_path.split("::")[1])
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
sparse = get_global_env(name + "sparse_slots") sparse = get_global_env(name + "sparse_slots", "#")
dense = get_global_env(name + "dense_slots") if sparse == "":
sparse = "#"
dense = get_global_env(name + "dense_slots", "#")
if dense == "":
dense = "#"
padding = get_global_env(name + "padding", 0) padding = get_global_env(name + "padding", 0)
reader = SlotReader(yaml_file) reader = SlotReader(yaml_file)
reader.init(sparse, dense, int(padding)) reader.init(sparse, dense, int(padding))
...@@ -158,8 +162,12 @@ def slotdataloader(readerclass, train, yaml_file): ...@@ -158,8 +162,12 @@ def slotdataloader(readerclass, train, yaml_file):
files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)] files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
sparse = get_global_env("sparse_slots", None, namespace) sparse = get_global_env("sparse_slots", "#", namespace)
dense = get_global_env("dense_slots", None, namespace) if sparse == "":
sparse = "#"
dense = get_global_env("dense_slots", "#", namespace)
if dense == "":
dense = "#"
padding = get_global_env("padding", 0, namespace) padding = get_global_env("padding", 0, namespace)
reader = SlotReader(yaml_file) reader = SlotReader(yaml_file)
reader.init(sparse, dense, int(padding)) reader.init(sparse, dense, int(padding))
......
...@@ -62,6 +62,7 @@ runner: ...@@ -62,6 +62,7 @@ runner:
save_inference_feed_varnames: [] # feed vars of save inference save_inference_feed_varnames: [] # feed vars of save inference
save_inference_fetch_varnames: [] # fetch vars of save inference save_inference_fetch_varnames: [] # fetch vars of save inference
init_model_path: "" # load model path init_model_path: "" # load model path
fetch_period: 10
- name: runner2 - name: runner2
class: single_infer class: single_infer
# num of epochs # num of epochs
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册