提交 a3ddeec0 编写于 作者: X xjqbest

fix

上级 a09255fb
...@@ -59,10 +59,12 @@ class SlotReader(dg.MultiSlotDataGenerator): ...@@ -59,10 +59,12 @@ class SlotReader(dg.MultiSlotDataGenerator):
def init(self, sparse_slots, dense_slots, padding=0): def init(self, sparse_slots, dense_slots, padding=0):
from operator import mul from operator import mul
self.sparse_slots = [] self.sparse_slots = []
if sparse_slots.strip() != "#": if sparse_slots.strip() != "#" and sparse_slots.strip(
) != "?" and sparse_slots.strip() != "":
self.sparse_slots = sparse_slots.strip().split(" ") self.sparse_slots = sparse_slots.strip().split(" ")
self.dense_slots = [] self.dense_slots = []
if dense_slots.strip() != "#": if dense_slots.strip() != "#" and dense_slots.strip(
) != "?" and dense_slots.strip() != "":
self.dense_slots = dense_slots.strip().split(" ") self.dense_slots = dense_slots.strip().split(" ")
self.dense_slots_shape = [ self.dense_slots_shape = [
reduce(mul, reduce(mul,
......
...@@ -78,14 +78,14 @@ class SingleInfer(TranspileTrainer): ...@@ -78,14 +78,14 @@ class SingleInfer(TranspileTrainer):
pipe_cmd = "python {} {} {} {}".format(reader, reader_class, pipe_cmd = "python {} {} {} {}".format(reader, reader_class,
"TRAIN", self._config_yaml) "TRAIN", self._config_yaml)
else: else:
if sparse_slots is None: if sparse_slots == "":
sparse_slots = "#" sparse_slots = "?"
if dense_slots is None: if dense_slots == "":
dense_slots = "#" dense_slots = "?"
padding = envs.get_global_env(name + "padding", 0) padding = envs.get_global_env(name + "padding", 0)
pipe_cmd = "python {} {} {} {} {} {} {} {}".format( pipe_cmd = "python {} {} {} {} {} {} {} {}".format(
reader, "slot", "slot", self._config_yaml, "fake", \ reader, "slot", "slot", self._config_yaml, "fake", \
sparse_slots.replace(" ", "#"), dense_slots.replace(" ", "#"), str(padding)) sparse_slots.replace(" ", "?"), dense_slots.replace(" ", "?"), str(padding))
dataset = fluid.DatasetFactory().create_dataset() dataset = fluid.DatasetFactory().create_dataset()
dataset.set_batch_size(envs.get_global_env(name + "batch_size")) dataset.set_batch_size(envs.get_global_env(name + "batch_size"))
...@@ -290,7 +290,7 @@ class SingleInfer(TranspileTrainer): ...@@ -290,7 +290,7 @@ class SingleInfer(TranspileTrainer):
def load(self, is_fleet=False): def load(self, is_fleet=False):
name = "runner." + self._runner_name + "." name = "runner." + self._runner_name + "."
dirname = envs.get_global_env("epoch.init_model_path", None) dirname = envs.get_global_env(name + "init_model_path", None)
if dirname is None or dirname == "": if dirname is None or dirname == "":
return return
print("single_infer going to load ", dirname) print("single_infer going to load ", dirname)
......
...@@ -73,13 +73,13 @@ class SingleTrainer(TranspileTrainer): ...@@ -73,13 +73,13 @@ class SingleTrainer(TranspileTrainer):
"TRAIN", self._config_yaml) "TRAIN", self._config_yaml)
else: else:
if sparse_slots == "": if sparse_slots == "":
sparse_slots = "#" sparse_slots = "?"
if dense_slots == "": if dense_slots == "":
dense_slots = "#" dense_slots = "?"
padding = envs.get_global_env(name + "padding", 0) padding = envs.get_global_env(name + "padding", 0)
pipe_cmd = "python {} {} {} {} {} {} {} {}".format( pipe_cmd = "python {} {} {} {} {} {} {} {}".format(
reader, "slot", "slot", self._config_yaml, "fake", \ reader, "slot", "slot", self._config_yaml, "fake", \
sparse_slots.replace(" ", "#"), dense_slots.replace(" ", "#"), str(padding)) sparse_slots.replace(" ", "?"), dense_slots.replace(" ", "?"), str(padding))
dataset = fluid.DatasetFactory().create_dataset() dataset = fluid.DatasetFactory().create_dataset()
dataset.set_batch_size(envs.get_global_env(name + "batch_size")) dataset.set_batch_size(envs.get_global_env(name + "batch_size"))
......
...@@ -32,8 +32,8 @@ elif sys.argv[2].upper() == "EVALUATE": ...@@ -32,8 +32,8 @@ elif sys.argv[2].upper() == "EVALUATE":
else: else:
reader_name = "SlotReader" reader_name = "SlotReader"
namespace = sys.argv[4] namespace = sys.argv[4]
sparse_slots = sys.argv[5].replace("#", " ") sparse_slots = sys.argv[5].replace("?", " ")
dense_slots = sys.argv[6].replace("#", " ") dense_slots = sys.argv[6].replace("?", " ")
padding = int(sys.argv[7]) padding = int(sys.argv[7])
yaml_abs_path = sys.argv[3] yaml_abs_path = sys.argv[3]
......
...@@ -36,7 +36,7 @@ class Model(ModelBase): ...@@ -36,7 +36,7 @@ class Model(ModelBase):
def net(self, input, is_infer=False): def net(self, input, is_infer=False):
self.sparse_inputs = self._sparse_data_var[1:] self.sparse_inputs = self._sparse_data_var[1:]
self.dense_input = self._dense_data_var[0] self.dense_input = [] #self._dense_data_var[0]
self.label_input = self._sparse_data_var[0] self.label_input = self._sparse_data_var[0]
def embedding_layer(input): def embedding_layer(input):
...@@ -52,8 +52,8 @@ class Model(ModelBase): ...@@ -52,8 +52,8 @@ class Model(ModelBase):
return emb_sum return emb_sum
sparse_embed_seq = list(map(embedding_layer, self.sparse_inputs)) sparse_embed_seq = list(map(embedding_layer, self.sparse_inputs))
concated = fluid.layers.concat( concated = fluid.layers.concat(sparse_embed_seq, axis=1)
sparse_embed_seq + [self.dense_input], axis=1) #sparse_embed_seq + [self.dense_input], axis=1)
fcs = [concated] fcs = [concated]
hidden_layers = envs.get_global_env("hyper_parameters.fc_sizes") hidden_layers = envs.get_global_env("hyper_parameters.fc_sizes")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册