Commit 8228c894 authored by yaoxuefeng

update din yaml

Parent 0ffc1dfe
@@ -36,7 +36,9 @@ hyper_parameters:
item_emb_size: 64
cat_emb_size: 64
is_sparse: False
-config_path: "{workspace}/data/config.txt"
+item_count: 63001
+cat_count: 801
act: "sigmoid"
@@ -52,6 +54,7 @@ runner:
save_inference_interval: 1
save_checkpoint_path: "increment"
save_inference_path: "inference"
+print_interval: 1
- name: infer_runner
trainer_class: single_infer
epochs: 1
......
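The yaml hunk above drops the external config_path ("{workspace}/data/config.txt") and instead declares item_count and cat_count directly as hyper-parameters; the model.py diff below reads them with envs.get_global_env. A minimal sketch of that lookup pattern, assuming a plain-dict stand-in for PaddleRec's envs helper (the real helper resolves keys from the loaded yaml, not from this dict):

# hypothetical stand-in for paddlerec.core.utils.envs.get_global_env,
# used only to illustrate the yaml-driven lookup introduced by this commit
_GLOBAL_ENV = {
    "hyper_parameters.item_count": 63001,
    "hyper_parameters.cat_count": 801,
    "hyper_parameters.is_sparse": False,
}

def get_global_env(key, default=None):
    # return the value configured in the yaml, falling back to the default
    return _GLOBAL_ENV.get(key, default)

item_count = get_global_env("hyper_parameters.item_count", 63001)
cat_count = get_global_env("hyper_parameters.cat_count", 801)
print(item_count, cat_count)  # 63001 801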
@@ -31,10 +31,11 @@ class Model(ModelBase):
self.is_sparse = envs.get_global_env("hyper_parameters.is_sparse",
                                     False)
# significant for speeding up the training process
-self.config_path = envs.get_global_env("hyper_parameters.config_path",
-                                       "data/config.txt")
self.use_DataLoader = envs.get_global_env(
    "hyper_parameters.use_DataLoader", False)
+self.item_count = envs.get_global_env("hyper_parameters.item_count",
+                                      63001)
+self.cat_count = envs.get_global_env("hyper_parameters.cat_count", 801)
def input_data(self, is_infer=False, **kwargs):
seq_len = -1
@@ -74,13 +75,6 @@ class Model(ModelBase):
] + [label] + [mask] + [target_item_seq] + [target_cat_seq]
return train_inputs

-def config_read(self, config_path):
-    with open(config_path, "r") as fin:
-        user_count = int(fin.readline().strip())
-        item_count = int(fin.readline().strip())
-        cat_count = int(fin.readline().strip())
-    return user_count, item_count, cat_count
def din_attention(self, hist, target_expand, mask):
"""activation weight"""
@@ -121,50 +115,48 @@ class Model(ModelBase):
target_item_seq = inputs[6]
target_cat_seq = inputs[7]
-user_count, item_count, cat_count = self.config_read(self.config_path)
item_emb_attr = fluid.ParamAttr(name="item_emb")
cat_emb_attr = fluid.ParamAttr(name="cat_emb")
hist_item_emb = fluid.embedding(
input=hist_item_seq,
-size=[item_count, self.item_emb_size],
+size=[self.item_count, self.item_emb_size],
param_attr=item_emb_attr,
is_sparse=self.is_sparse)
hist_cat_emb = fluid.embedding(
input=hist_cat_seq,
-size=[cat_count, self.cat_emb_size],
+size=[self.cat_count, self.cat_emb_size],
param_attr=cat_emb_attr,
is_sparse=self.is_sparse)
target_item_emb = fluid.embedding(
input=target_item,
-size=[item_count, self.item_emb_size],
+size=[self.item_count, self.item_emb_size],
param_attr=item_emb_attr,
is_sparse=self.is_sparse)
target_cat_emb = fluid.embedding(
input=target_cat,
-size=[cat_count, self.cat_emb_size],
+size=[self.cat_count, self.cat_emb_size],
param_attr=cat_emb_attr,
is_sparse=self.is_sparse)
target_item_seq_emb = fluid.embedding(
input=target_item_seq,
-size=[item_count, self.item_emb_size],
+size=[self.item_count, self.item_emb_size],
param_attr=item_emb_attr,
is_sparse=self.is_sparse)
target_cat_seq_emb = fluid.embedding(
input=target_cat_seq,
-size=[cat_count, self.cat_emb_size],
+size=[self.cat_count, self.cat_emb_size],
param_attr=cat_emb_attr,
is_sparse=self.is_sparse)
item_b = fluid.embedding(
input=target_item,
-size=[item_count, 1],
+size=[self.item_count, 1],
param_attr=fluid.initializer.Constant(value=0.0))
hist_seq_concat = fluid.layers.concat(
......
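With the counts exposed as self.item_count / self.cat_count, every embedding table above is sized from the yaml hyper-parameters instead of values parsed out of config.txt at build time. A minimal Paddle 1.x static-graph sketch of that pattern (the tensor name and standalone variables are illustrative; sizes copied from the yaml hunk above):

import paddle.fluid as fluid

item_count, item_emb_size = 63001, 64   # hyper_parameters.item_count / item_emb_size
is_sparse = False

# one target item id per example
target_item = fluid.data(name="target_item", shape=[-1], dtype="int64")

# embedding rows come straight from the configured item_count
target_item_emb = fluid.embedding(
    input=target_item,
    size=[item_count, item_emb_size],
    param_attr=fluid.ParamAttr(name="item_emb"),
    is_sparse=is_sparse)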