提交 734812c3 · 作者: hutuxian · 提交者: Yi Liu

upgrade API for DIN (#3499)

上级 d159a6b3
......@@ -14,7 +14,7 @@
import paddle.fluid as fluid
def din_attention(hist, target_expand, max_len, mask):
def din_attention(hist, target_expand, mask):
"""activation weight"""
hidden_size = hist.shape[-1]
......@@ -45,9 +45,10 @@ def din_attention(hist, target_expand, max_len, mask):
return out
def network(item_count, cat_count, max_len):
def network(item_count, cat_count):
"""network definition"""
seq_len = -1
item_emb_size = 64
cat_emb_size = 64
is_sparse = False
......@@ -56,60 +57,60 @@ def network(item_count, cat_count, max_len):
item_emb_attr = fluid.ParamAttr(name="item_emb")
cat_emb_attr = fluid.ParamAttr(name="cat_emb")
hist_item_seq = fluid.layers.data(
name="hist_item_seq", shape=[max_len, 1], dtype="int64")
hist_cat_seq = fluid.layers.data(
name="hist_cat_seq", shape=[max_len, 1], dtype="int64")
target_item = fluid.layers.data(
name="target_item", shape=[1], dtype="int64")
target_cat = fluid.layers.data(
name="target_cat", shape=[1], dtype="int64")
label = fluid.layers.data(
name="label", shape=[1], dtype="float32")
mask = fluid.layers.data(
name="mask", shape=[max_len, 1], dtype="float32")
target_item_seq = fluid.layers.data(
name="target_item_seq", shape=[max_len, 1], dtype="int64")
target_cat_seq = fluid.layers.data(
name="target_cat_seq", shape=[max_len, 1], dtype="int64", lod_level=0)
hist_item_emb = fluid.layers.embedding(
hist_item_seq = fluid.data(
name="hist_item_seq", shape=[None, seq_len], dtype="int64")
hist_cat_seq = fluid.data(
name="hist_cat_seq", shape=[None, seq_len], dtype="int64")
target_item = fluid.data(
name="target_item", shape=[None], dtype="int64")
target_cat = fluid.data(
name="target_cat", shape=[None], dtype="int64")
label = fluid.data(
name="label", shape=[None, 1], dtype="float32")
mask = fluid.data(
name="mask", shape=[None, seq_len, 1], dtype="float32")
target_item_seq = fluid.data(
name="target_item_seq", shape=[None, seq_len], dtype="int64")
target_cat_seq = fluid.data(
name="target_cat_seq", shape=[None, seq_len], dtype="int64")
hist_item_emb = fluid.embedding(
input=hist_item_seq,
size=[item_count, item_emb_size],
param_attr=item_emb_attr,
is_sparse=is_sparse)
hist_cat_emb = fluid.layers.embedding(
hist_cat_emb = fluid.embedding(
input=hist_cat_seq,
size=[cat_count, cat_emb_size],
param_attr=cat_emb_attr,
is_sparse=is_sparse)
target_item_emb = fluid.layers.embedding(
target_item_emb = fluid.embedding(
input=target_item,
size=[item_count, item_emb_size],
param_attr=item_emb_attr,
is_sparse=is_sparse)
target_cat_emb = fluid.layers.embedding(
target_cat_emb = fluid.embedding(
input=target_cat,
size=[cat_count, cat_emb_size],
param_attr=cat_emb_attr,
is_sparse=is_sparse)
target_item_seq_emb = fluid.layers.embedding(
target_item_seq_emb = fluid.embedding(
input=target_item_seq,
size=[item_count, item_emb_size],
param_attr=item_emb_attr,
is_sparse=is_sparse)
target_cat_seq_emb = fluid.layers.embedding(
target_cat_seq_emb = fluid.embedding(
input=target_cat_seq,
size=[cat_count, cat_emb_size],
param_attr=cat_emb_attr,
is_sparse=is_sparse)
item_b = fluid.layers.embedding(
item_b = fluid.embedding(
input=target_item,
size=[item_count, 1],
param_attr=fluid.initializer.Constant(value=0.0))
......@@ -120,7 +121,7 @@ def network(item_count, cat_count, max_len):
target_concat = fluid.layers.concat(
[target_item_emb, target_cat_emb], axis=1)
out = din_attention(hist_seq_concat, target_seq_concat, max_len, mask)
out = din_attention(hist_seq_concat, target_seq_concat, mask)
out_fc = fluid.layers.fc(name="out_fc",
input=out,
size=item_emb_size + cat_emb_size,
......
......@@ -20,7 +20,7 @@ import pickle
def pad_batch_data(input, max_len):
    """Right-pad each sequence with zeros and stack the batch as int64.

    Args:
        input: Iterable of Python lists of integer ids, one list per sample.
        max_len: Length every sequence is padded to. It is expected to be
            at least the length of the longest sequence; longer sequences
            are not truncated.

    Returns:
        A numpy array of dtype int64 with shape ``[-1, max_len]``, i.e.
        one row per sample when every sequence fits in ``max_len``.
    """
    # `[0] * k` is an empty list for k <= 0, so a sequence that already has
    # max_len elements is passed through unchanged.
    padded = [x + [0] * (max_len - len(x)) for x in input]
    # Only the 2-D reshape is kept: the earlier reshape to
    # [-1, max_len, 1] was immediately overwritten, so dropping it does not
    # change the returned array.
    return np.array(padded).astype("int64").reshape([-1, max_len])
......@@ -34,10 +34,10 @@ def make_data(b):
[-1, max_len, 1])
target_item_seq = np.array(
[[x[2]] * max_len for x in b]).astype("int64").reshape(
[-1, max_len, 1])
[-1, max_len])
target_cat_seq = np.array(
[[x[3]] * max_len for x in b]).astype("int64").reshape(
[-1, max_len, 1])
[-1, max_len])
res = []
for i in range(len(b)):
res.append([
......
......@@ -78,7 +78,7 @@ def train():
args.num_devices)
logger.info("reading data completes")
avg_cost, pred = network.network(item_count, cat_count, max_len)
avg_cost, pred = network.network(item_count, cat_count)
fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByGlobalNorm(
clip_norm=5.0))
base_lr = args.base_lr
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册