Commit a3ddeec0 authored by: X xjqbest

fix

Parent a09255fb
@@ -59,10 +59,12 @@ class SlotReader(dg.MultiSlotDataGenerator):
     def init(self, sparse_slots, dense_slots, padding=0):
         from operator import mul
         self.sparse_slots = []
-        if sparse_slots.strip() != "#":
+        if sparse_slots.strip() != "#" and sparse_slots.strip(
+        ) != "?" and sparse_slots.strip() != "":
             self.sparse_slots = sparse_slots.strip().split(" ")
         self.dense_slots = []
-        if dense_slots.strip() != "#":
+        if dense_slots.strip() != "#" and dense_slots.strip(
+        ) != "?" and dense_slots.strip() != "":
             self.dense_slots = dense_slots.strip().split(" ")
         self.dense_slots_shape = [
             reduce(mul,
......
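The extended check means "#" (the old sentinel), "?" (the new sentinel), and an empty string are all treated as "no slots configured". A minimal standalone sketch of that contract; `parse_slots` is a hypothetical helper used only for illustration, not part of the repo:

```python
# Illustrative only: mirrors the sentinel handling in SlotReader.init above.
def parse_slots(slots_str):
    s = slots_str.strip()
    if s in ("#", "?", ""):   # all three mean "no slots configured"
        return []
    return s.split(" ")

assert parse_slots("?") == []
assert parse_slots("click slot1 slot2") == ["click", "slot1", "slot2"]
```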
@@ -78,14 +78,14 @@ class SingleInfer(TranspileTrainer):
             pipe_cmd = "python {} {} {} {}".format(reader, reader_class,
                                                    "TRAIN", self._config_yaml)
         else:
-            if sparse_slots is None:
-                sparse_slots = "#"
-            if dense_slots is None:
-                dense_slots = "#"
+            if sparse_slots == "":
+                sparse_slots = "?"
+            if dense_slots == "":
+                dense_slots = "?"
             padding = envs.get_global_env(name + "padding", 0)
             pipe_cmd = "python {} {} {} {} {} {} {} {}".format(
                 reader, "slot", "slot", self._config_yaml, "fake", \
-                sparse_slots.replace(" ", "#"), dense_slots.replace(" ", "#"), str(padding))
+                sparse_slots.replace(" ", "?"), dense_slots.replace(" ", "?"), str(padding))
         dataset = fluid.DatasetFactory().create_dataset()
         dataset.set_batch_size(envs.get_global_env(name + "batch_size"))
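Each slot list is handed to the reader subprocess as a single command-line token, so an empty list is sent as the bare "?" sentinel and spaces inside a non-empty list are also replaced with "?". A rough sketch of that encoding step; the slot names below are made up for illustration:

```python
# Hypothetical slot configuration, for illustration only.
sparse_slots = "click slot1 slot2"
dense_slots = ""                       # no dense slots configured

if sparse_slots == "":
    sparse_slots = "?"
if dense_slots == "":
    dense_slots = "?"

arg_sparse = sparse_slots.replace(" ", "?")   # "click?slot1?slot2" - one argv token
arg_dense = dense_slots.replace(" ", "?")     # "?" - empty-list sentinel
```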
@@ -290,7 +290,7 @@ class SingleInfer(TranspileTrainer):
     def load(self, is_fleet=False):
         name = "runner." + self._runner_name + "."
-        dirname = envs.get_global_env("epoch.init_model_path", None)
+        dirname = envs.get_global_env(name + "init_model_path", None)
         if dirname is None or dirname == "":
             return
         print("single_infer going to load ", dirname)
......
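The practical effect is that the inference model path is now resolved per runner instead of from the shared epoch section. Assuming a runner named single_cpu_infer (the name is made up for illustration), the config key that gets looked up changes roughly like this:

```python
# Illustration only; "single_cpu_infer" is a made-up runner name.
runner_name = "single_cpu_infer"
old_key = "epoch.init_model_path"                        # one global setting
new_key = "runner." + runner_name + ".init_model_path"   # per-runner setting
```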
@@ -73,13 +73,13 @@ class SingleTrainer(TranspileTrainer):
                                                    "TRAIN", self._config_yaml)
         else:
             if sparse_slots == "":
-                sparse_slots = "#"
+                sparse_slots = "?"
             if dense_slots == "":
-                dense_slots = "#"
+                dense_slots = "?"
             padding = envs.get_global_env(name + "padding", 0)
             pipe_cmd = "python {} {} {} {} {} {} {} {}".format(
                 reader, "slot", "slot", self._config_yaml, "fake", \
-                sparse_slots.replace(" ", "#"), dense_slots.replace(" ", "#"), str(padding))
+                sparse_slots.replace(" ", "?"), dense_slots.replace(" ", "?"), str(padding))
         dataset = fluid.DatasetFactory().create_dataset()
         dataset.set_batch_size(envs.get_global_env(name + "batch_size"))
......
@@ -32,8 +32,8 @@ elif sys.argv[2].upper() == "EVALUATE":
 else:
     reader_name = "SlotReader"
     namespace = sys.argv[4]
-    sparse_slots = sys.argv[5].replace("#", " ")
-    dense_slots = sys.argv[6].replace("#", " ")
+    sparse_slots = sys.argv[5].replace("?", " ")
+    dense_slots = sys.argv[6].replace("?", " ")
     padding = int(sys.argv[7])
 yaml_abs_path = sys.argv[3]
......
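On the reader side the same "?" convention is decoded back before SlotReader.init splits the strings on spaces. A small round-trip check, assuming made-up slot names:

```python
# Encode as the trainer does, decode as this script does.
slots = "click slot1 slot2"
encoded = (slots if slots != "" else "?").replace(" ", "?")   # trainer side
decoded = encoded.replace("?", " ")                           # reader side
assert decoded.strip().split(" ") == ["click", "slot1", "slot2"]

# An empty list travels as the bare "?" sentinel and decodes to whitespace,
# which the updated SlotReader.init treats as "no slots".
assert "?".replace("?", " ").strip() == ""
```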
@@ -36,7 +36,7 @@ class Model(ModelBase):
     def net(self, input, is_infer=False):
         self.sparse_inputs = self._sparse_data_var[1:]
-        self.dense_input = self._dense_data_var[0]
+        self.dense_input = [] #self._dense_data_var[0]
         self.label_input = self._sparse_data_var[0]

         def embedding_layer(input):
@@ -52,8 +52,8 @@ class Model(ModelBase):
             return emb_sum

         sparse_embed_seq = list(map(embedding_layer, self.sparse_inputs))
-        concated = fluid.layers.concat(
-            sparse_embed_seq + [self.dense_input], axis=1)
+        concated = fluid.layers.concat(sparse_embed_seq, axis=1)
+        #sparse_embed_seq + [self.dense_input], axis=1)
         fcs = [concated]

         hidden_layers = envs.get_global_env("hyper_parameters.fc_sizes")
......
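With the dense input commented out, the first FC layer now sees only the concatenated sparse embeddings. A shape sketch of the change, assuming 26 sparse slots, 13 dense features and an embedding size of 9 (Criteo-style numbers assumed here, not taken from this config):

```python
# Shape sketch only; all dimensions below are assumptions.
batch, num_sparse, emb_dim, dense_dim = 32, 26, 9, 13

old_width = num_sparse * emb_dim + dense_dim   # before: [batch, 247]
new_width = num_sparse * emb_dim               # after:  [batch, 234]
```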