提交 150b2886 编写于 作者: M malin10

Merge branch 'yaml1' of https://github.com/xjqbest/PaddleRec into modify_yaml

......@@ -159,6 +159,8 @@ class Model(object):
name = "dataset." + kwargs.get("dataset_name") + "."
sparse_slots = envs.get_global_env(name + "sparse_slots", "").strip()
dense_slots = envs.get_global_env(name + "dense_slots", "").strip()
self._sparse_data_var_map = {}
self._dense_data_var_map = {}
if sparse_slots != "" or dense_slots != "":
if sparse_slots == "":
sparse_slots = []
......@@ -181,12 +183,14 @@ class Model(object):
dtype="float32")
data_var_.append(l)
self._dense_data_var.append(l)
self._dense_data_var_map[dense_slots[i]] = l
self._sparse_data_var = []
for name in sparse_slots:
l = fluid.layers.data(
name=name, shape=[1], lod_level=1, dtype="int64")
data_var_.append(l)
self._sparse_data_var.append(l)
self._sparse_data_var_map[name] = l
return data_var_
else:
......
......@@ -59,10 +59,12 @@ class SlotReader(dg.MultiSlotDataGenerator):
def init(self, sparse_slots, dense_slots, padding=0):
from operator import mul
self.sparse_slots = []
if sparse_slots.strip() != "#":
if sparse_slots.strip() != "#" and sparse_slots.strip(
) != "?" and sparse_slots.strip() != "":
self.sparse_slots = sparse_slots.strip().split(" ")
self.dense_slots = []
if dense_slots.strip() != "#":
if dense_slots.strip() != "#" and dense_slots.strip(
) != "?" and dense_slots.strip() != "":
self.dense_slots = dense_slots.strip().split(" ")
self.dense_slots_shape = [
reduce(mul,
......
......@@ -78,14 +78,14 @@ class SingleInfer(TranspileTrainer):
pipe_cmd = "python {} {} {} {}".format(reader, reader_class,
"TRAIN", self._config_yaml)
else:
if sparse_slots is None:
sparse_slots = "#"
if dense_slots is None:
dense_slots = "#"
if sparse_slots == "":
sparse_slots = "?"
if dense_slots == "":
dense_slots = "?"
padding = envs.get_global_env(name + "padding", 0)
pipe_cmd = "python {} {} {} {} {} {} {} {}".format(
reader, "slot", "slot", self._config_yaml, "fake", \
sparse_slots.replace(" ", "#"), dense_slots.replace(" ", "#"), str(padding))
sparse_slots.replace(" ", "?"), dense_slots.replace(" ", "?"), str(padding))
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_batch_size(envs.get_global_env(name + "batch_size"))
......@@ -290,7 +290,7 @@ class SingleInfer(TranspileTrainer):
def load(self, is_fleet=False):
name = "runner." + self._runner_name + "."
dirname = envs.get_global_env("epoch.init_model_path", None)
dirname = envs.get_global_env(name + "init_model_path", None)
if dirname is None or dirname == "":
return
print("single_infer going to load ", dirname)
......
......@@ -73,13 +73,13 @@ class SingleTrainer(TranspileTrainer):
"TRAIN", self._config_yaml)
else:
if sparse_slots == "":
sparse_slots = "#"
sparse_slots = "?"
if dense_slots == "":
dense_slots = "#"
dense_slots = "?"
padding = envs.get_global_env(name + "padding", 0)
pipe_cmd = "python {} {} {} {} {} {} {} {}".format(
reader, "slot", "slot", self._config_yaml, "fake", \
sparse_slots.replace(" ", "#"), dense_slots.replace(" ", "#"), str(padding))
sparse_slots.replace(" ", "?"), dense_slots.replace(" ", "?"), str(padding))
dataset = fluid.DatasetFactory().create_dataset()
dataset.set_batch_size(envs.get_global_env(name + "batch_size"))
......
......@@ -32,8 +32,8 @@ elif sys.argv[2].upper() == "EVALUATE":
else:
reader_name = "SlotReader"
namespace = sys.argv[4]
sparse_slots = sys.argv[5].replace("#", " ")
dense_slots = sys.argv[6].replace("#", " ")
sparse_slots = sys.argv[5].replace("?", " ")
dense_slots = sys.argv[6].replace("?", " ")
padding = int(sys.argv[7])
yaml_abs_path = sys.argv[3]
......
```
# 全局配置
debug: false
workspace: "."
# 用户可以配多个dataset,executor里不同阶段可以用不同的dataset
dataset:
- name: sample_1
type: DataLoader #或者QueueDataset
batch_size: 5
data_path: "{workspace}/data/train"
# 用户自定义reader
data_converter: "{workspace}/rsc15_reader.py"
- name: sample_2
type: QueueDataset #或者DataLoader
batch_size: 5
data_path: "{workspace}/data/train"
# 用户可以配置sparse_slots和dense_slots,无需再定义data_converter
sparse_slots: "click ins_weight 6001 6002 6003 6005 6006 6007 6008 6009"
dense_slots: "readlist:9"
#示例一,用户自定义参数,用于组网配置
hyper_parameters:
#优化器
optimizer:
class: Adam
learning_rate: 0.001
strategy: "{workspace}/conf/config_fleet.py"
# 用户自定义配置
vocab_size: 1000
hid_size: 100
my_key1: 233
my_key2: 0.1
mode: runner1
runner:
- name: runner1 # 示例一,train
trainer_class: single_train
epochs: 10
device: cpu
init_model_path: ""
save_checkpoint_interval: 2
save_inference_interval: 4
# 下面是保存模型路径配置
save_checkpoint_path: "xxxx"
save_inference_path: "xxxx"
- name: runner2 # 示例二,infer
trainer_class: single_infer
epochs: 1
device: cpu
init_model_path: "afs:/xxx/xxx"
phase:
- name: phase1
model: "{workspace}/model.py"
dataset_name: sample_1
thread_num: 1
```
......@@ -36,7 +36,7 @@ class Model(ModelBase):
def net(self, input, is_infer=False):
self.sparse_inputs = self._sparse_data_var[1:]
self.dense_input = self._dense_data_var[0]
self.dense_input = [] #self._dense_data_var[0]
self.label_input = self._sparse_data_var[0]
def embedding_layer(input):
......@@ -52,8 +52,8 @@ class Model(ModelBase):
return emb_sum
sparse_embed_seq = list(map(embedding_layer, self.sparse_inputs))
concated = fluid.layers.concat(
sparse_embed_seq + [self.dense_input], axis=1)
concated = fluid.layers.concat(sparse_embed_seq, axis=1)
#sparse_embed_seq + [self.dense_input], axis=1)
fcs = [concated]
hidden_layers = envs.get_global_env("hyper_parameters.fc_sizes")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册