Commit a09255fb authored by xjqbest

fix

Parent 37a77dcd
@@ -59,11 +59,17 @@ class Model(object):
                 dataset = i
                 break
         name = "dataset." + dataset["name"] + "."
-        sparse_slots = envs.get_global_env(name + "sparse_slots")
-        dense_slots = envs.get_global_env(name + "dense_slots")
-        if sparse_slots is not None or dense_slots is not None:
-            sparse_slots = sparse_slots.strip().split(" ")
-            dense_slots = dense_slots.strip().split(" ")
+        sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip()
+        dense_slots = envs.get_global_env(name + "dense_slots", "").strip()
+        if sparse_slots != "" or dense_slots != "":
+            if sparse_slots == "":
+                sparse_slots = []
+            else:
+                sparse_slots = sparse_slots.strip().split(" ")
+            if dense_slots == "":
+                dense_slots = []
+            else:
+                dense_slots = dense_slots.strip().split(" ")
             dense_slots_shape = [[
                 int(j) for j in i.split(":")[1].strip("[]").split(",")
             ] for i in dense_slots]
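The new branches above matter because str.split never returns an empty list: splitting an empty string on a space yields [''], which would register one bogus, nameless slot, and the old code raised AttributeError whenever only one of the two keys was configured (None has no strip). A minimal sketch of the pitfall, illustrative only:

    sparse_slots = ""  # e.g. only dense_slots was configured
    assert "".split(" ") == [""]  # a bare split would invent an empty slot name
    # the new branch maps the empty string to a truly empty list instead
    slots = [] if sparse_slots == "" else sparse_slots.strip().split(" ")
    assert slots == []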
@@ -151,11 +157,17 @@ class Model(object):
     def input_data(self, is_infer=False, **kwargs):
         name = "dataset." + kwargs.get("dataset_name") + "."
-        sparse_slots = envs.get_global_env(name + "sparse_slots")
-        dense_slots = envs.get_global_env(name + "dense_slots")
-        if sparse_slots is not None or dense_slots is not None:
-            sparse_slots = sparse_slots.strip().split(" ")
-            dense_slots = dense_slots.strip().split(" ")
+        sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip()
+        dense_slots = envs.get_global_env(name + "dense_slots", "").strip()
+        if sparse_slots != "" or dense_slots != "":
+            if sparse_slots == "":
+                sparse_slots = []
+            else:
+                sparse_slots = sparse_slots.strip().split(" ")
+            if dense_slots == "":
+                dense_slots = []
+            else:
+                dense_slots = dense_slots.strip().split(" ")
             dense_slots_shape = [[
                 int(j) for j in i.split(":")[1].strip("[]").split(",")
             ] for i in dense_slots]
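For context on the dense_slots_shape comprehension kept above: each dense slot is declared as name:[dims], and the comprehension peels off the bracketed shape. A hedged sketch with made-up slot declarations:

    # "dense_var" and "emb" are hypothetical slot declarations
    dense_slots = ["dense_var:[13]", "emb:[4,8]"]
    dense_slots_shape = [[
        int(j) for j in i.split(":")[1].strip("[]").split(",")
    ] for i in dense_slots]
    assert dense_slots_shape == [[13], [4, 8]]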
......
@@ -67,15 +67,14 @@ class SingleInfer(TranspileTrainer):
     def _get_dataset(self, dataset_name):
         name = "dataset." + dataset_name + "."
-        sparse_slots = envs.get_global_env(name + "sparse_slots")
-        dense_slots = envs.get_global_env(name + "dense_slots")
         thread_num = envs.get_global_env(name + "thread_num")
         batch_size = envs.get_global_env(name + "batch_size")
         reader_class = envs.get_global_env(name + "data_converter")
         abs_dir = os.path.dirname(os.path.abspath(__file__))
         reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py')
-        if sparse_slots is None and dense_slots is None:
+        sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip()
+        dense_slots = envs.get_global_env(name + "dense_slots", "").strip()
+        if sparse_slots == "" and dense_slots == "":
             pipe_cmd = "python {} {} {} {}".format(reader, reader_class,
                                                    "TRAIN", self._config_yaml)
         else:
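The branch above builds the pipe command that the Paddle dataset runs to feed data through the user's converter. A sketch of the string it produces, with placeholder paths (both names below are hypothetical):

    reader = "/opt/rec/utils/dataset_instance.py"  # hypothetical install path
    reader_class = "criteo_reader.py"              # hypothetical converter
    pipe_cmd = "python {} {} {} {}".format(reader, reader_class,
                                           "TRAIN", "config.yaml")
    # -> python /opt/rec/utils/dataset_instance.py criteo_reader.py TRAIN config.yaml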
@@ -107,13 +106,13 @@ class SingleInfer(TranspileTrainer):
     def _get_dataloader(self, dataset_name, dataloader):
         name = "dataset." + dataset_name + "."
-        sparse_slots = envs.get_global_env(name + "sparse_slots")
-        dense_slots = envs.get_global_env(name + "dense_slots")
         thread_num = envs.get_global_env(name + "thread_num")
         batch_size = envs.get_global_env(name + "batch_size")
         reader_class = envs.get_global_env(name + "data_converter")
         abs_dir = os.path.dirname(os.path.abspath(__file__))
-        if sparse_slots is None and dense_slots is None:
+        sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip()
+        dense_slots = envs.get_global_env(name + "dense_slots", "").strip()
+        if sparse_slots == "" and dense_slots == "":
             reader = dataloader_instance.dataloader_by_name(
                 reader_class, dataset_name, self._config_yaml)
             reader_class = envs.lazy_instance_by_fliename(reader_class,
@@ -228,7 +227,9 @@ class SingleInfer(TranspileTrainer):
         model_class = self._model[model_name][3]
         fetch_vars = []
         fetch_alias = []
-        fetch_period = 20
+        fetch_period = int(
+            envs.get_global_env("runner." + self._runner_name +
+                                ".fetch_period", 20))
         metrics = model_class.get_infer_results()
         if metrics:
             fetch_vars = metrics.values()
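The hard-coded fetch_period of 20 becomes a per-runner setting, with 20 kept as the fallback. A minimal sketch of the assumed lookup semantics, using a local stand-in for envs.get_global_env (the dotted keys suggest the real module flattens the YAML config):

    # hypothetical flattened config with no fetch_period set
    _config = {"runner.runner2.class": "single_infer"}

    def get_global_env(key, default=None):
        return _config.get(key, default)

    fetch_period = int(get_global_env("runner.runner2.fetch_period", 20))
    assert fetch_period == 20  # key missing -> default applies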
@@ -251,14 +252,15 @@ class SingleInfer(TranspileTrainer):
         program = self._model[model_name][0].clone()
         fetch_vars = []
         fetch_alias = []
-        fetch_period = 20
         metrics = model_class.get_infer_results()
         if metrics:
             fetch_vars = metrics.values()
             fetch_alias = metrics.keys()
         metrics_varnames = []
         metrics_format = []
-        fetch_period = 20
+        fetch_period = int(
+            envs.get_global_env("runner." + self._runner_name +
+                                ".fetch_period", 20))
         metrics_format.append("{}: {{}}".format("batch"))
         for name, var in metrics.items():
             metrics_varnames.append(var.name)
......
@@ -61,21 +61,20 @@ class SingleTrainer(TranspileTrainer):
     def _get_dataset(self, dataset_name):
         name = "dataset." + dataset_name + "."
-        sparse_slots = envs.get_global_env(name + "sparse_slots")
-        dense_slots = envs.get_global_env(name + "dense_slots")
         thread_num = envs.get_global_env(name + "thread_num")
         batch_size = envs.get_global_env(name + "batch_size")
         reader_class = envs.get_global_env(name + "data_converter")
         abs_dir = os.path.dirname(os.path.abspath(__file__))
         reader = os.path.join(abs_dir, '../utils', 'dataset_instance.py')
-        if sparse_slots is None and dense_slots is None:
+        sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip()
+        dense_slots = envs.get_global_env(name + "dense_slots", "").strip()
+        if sparse_slots == "" and dense_slots == "":
             pipe_cmd = "python {} {} {} {}".format(reader, reader_class,
                                                    "TRAIN", self._config_yaml)
         else:
-            if sparse_slots is None:
+            if sparse_slots == "":
                 sparse_slots = "#"
-            if dense_slots is None:
+            if dense_slots == "":
                 dense_slots = "#"
             padding = envs.get_global_env(name + "padding", 0)
             pipe_cmd = "python {} {} {} {} {} {} {} {}".format(
@@ -101,13 +100,13 @@ class SingleTrainer(TranspileTrainer):
     def _get_dataloader(self, dataset_name, dataloader):
         name = "dataset." + dataset_name + "."
-        sparse_slots = envs.get_global_env(name + "sparse_slots")
-        dense_slots = envs.get_global_env(name + "dense_slots")
+        sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip()
+        dense_slots = envs.get_global_env(name + "dense_slots", "").strip()
         thread_num = envs.get_global_env(name + "thread_num")
         batch_size = envs.get_global_env(name + "batch_size")
         reader_class = envs.get_global_env(name + "data_converter")
         abs_dir = os.path.dirname(os.path.abspath(__file__))
-        if sparse_slots is None and dense_slots is None:
+        if sparse_slots == "" and dense_slots == "":
             reader = dataloader_instance.dataloader_by_name(
                 reader_class, dataset_name, self._config_yaml)
             reader_class = envs.lazy_instance_by_fliename(reader_class,
@@ -125,8 +124,8 @@ class SingleTrainer(TranspileTrainer):
     def _create_dataset(self, dataset_name):
         name = "dataset." + dataset_name + "."
-        sparse_slots = envs.get_global_env(name + "sparse_slots")
-        dense_slots = envs.get_global_env(name + "dense_slots")
+        sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip()
+        dense_slots = envs.get_global_env(name + "dense_slots", "").strip()
         thread_num = envs.get_global_env(name + "thread_num")
         batch_size = envs.get_global_env(name + "batch_size")
         type_name = envs.get_global_env(name + "type")
@@ -225,7 +224,9 @@ class SingleTrainer(TranspileTrainer):
         model_class = self._model[model_name][3]
         fetch_vars = []
         fetch_alias = []
-        fetch_period = 20
+        fetch_period = int(
+            envs.get_global_env("runner." + self._runner_name +
+                                ".fetch_period", 20))
         metrics = model_class.get_metrics()
         if metrics:
             fetch_vars = metrics.values()
@@ -250,14 +251,15 @@ class SingleTrainer(TranspileTrainer):
             loss_name=model_class.get_avg_cost().name)
         fetch_vars = []
         fetch_alias = []
-        fetch_period = 20
+        fetch_period = int(
+            envs.get_global_env("runner." + self._runner_name +
+                                ".fetch_period", 20))
         metrics = model_class.get_metrics()
         if metrics:
             fetch_vars = metrics.values()
             fetch_alias = metrics.keys()
         metrics_varnames = []
         metrics_format = []
-        fetch_period = 20
         metrics_format.append("{}: {{}}".format("batch"))
         for name, var in metrics.items():
             metrics_varnames.append(var.name)
@@ -312,10 +314,11 @@ class SingleTrainer(TranspileTrainer):
         if not need_save(epoch_id, save_interval, False):
             return
         feed_varnames = envs.get_global_env(
-            name + "save_inference_feed_varnames", None)
+            name + "save_inference_feed_varnames", [])
         fetch_varnames = envs.get_global_env(
-            name + "save_inference_fetch_varnames", None)
-        if feed_varnames is None or fetch_varnames is None or feed_varnames == "":
+            name + "save_inference_fetch_varnames", [])
+        if feed_varnames is None or fetch_varnames is None or feed_varnames == "" or fetch_varnames == "" or \
+                len(feed_varnames) == 0 or len(fetch_varnames) == 0:
             return
         fetch_vars = [
             fluid.default_main_program().global_block().vars[varname]
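The widened guard accepts every way a user can leave the varnames unset: key missing, empty string, or empty list (the new [] default is what makes the len checks meaningful). A sketch of the same predicate as a standalone function:

    def should_skip(feed_varnames, fetch_varnames):
        # skip saving the inference model unless both lists are non-empty
        return (feed_varnames is None or fetch_varnames is None or
                feed_varnames == "" or fetch_varnames == "" or
                len(feed_varnames) == 0 or len(fetch_varnames) == 0)

    assert should_skip([], ["predict"])         # empty feed list -> skip
    assert not should_skip(["x"], ["predict"])  # both present -> save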
......
@@ -68,8 +68,12 @@ def slotdataloader_by_name(readerclass, dataset_name, yaml_file):
         data_path = os.path.join(package_base, data_path.split("::")[1])
     files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
-    sparse = get_global_env(name + "sparse_slots")
-    dense = get_global_env(name + "dense_slots")
+    sparse = get_global_env(name + "sparse_slots", "#")
+    if sparse == "":
+        sparse = "#"
+    dense = get_global_env(name + "dense_slots", "#")
+    if dense == "":
+        dense = "#"
     padding = get_global_env(name + "padding", 0)
     reader = SlotReader(yaml_file)
     reader.init(sparse, dense, int(padding))
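Here the defaulting is two-step: a missing key now falls back to "#" directly, and an explicitly empty value is remapped to "#" as well, so reader.init never sees an empty string. A small sketch of that normalization (slots_or_sentinel is a hypothetical helper, not part of the commit):

    def slots_or_sentinel(value):
        # "#" is the sentinel meaning "no slots of this kind" (assumed convention)
        return "#" if value in (None, "") else value

    assert slots_or_sentinel(None) == "#"
    assert slots_or_sentinel("") == "#"
    assert slots_or_sentinel("slot1 slot2") == "slot1 slot2"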
@@ -158,8 +162,12 @@ def slotdataloader(readerclass, train, yaml_file):
     files = [str(data_path) + "/%s" % x for x in os.listdir(data_path)]
-    sparse = get_global_env("sparse_slots", None, namespace)
-    dense = get_global_env("dense_slots", None, namespace)
+    sparse = get_global_env("sparse_slots", "#", namespace)
+    if sparse == "":
+        sparse = "#"
+    dense = get_global_env("dense_slots", "#", namespace)
+    if dense == "":
+        dense = "#"
     padding = get_global_env("padding", 0, namespace)
     reader = SlotReader(yaml_file)
     reader.init(sparse, dense, int(padding))
......
@@ -62,6 +62,7 @@ runner:
   save_inference_feed_varnames: [] # feed vars of save inference
   save_inference_fetch_varnames: [] # fetch vars of save inference
   init_model_path: "" # load model path
+  fetch_period: 10
 - name: runner2
   class: single_infer
   # num of epochs
......