Commit 734812c3 authored by hutuxian, committed by Yi Liu

upgrade API for DIN (#3499)

Parent d159a6b3
@@ -14,7 +14,7 @@
 import paddle.fluid as fluid
-def din_attention(hist, target_expand, max_len, mask):
+def din_attention(hist, target_expand, mask):
     """activation weight"""
     hidden_size = hist.shape[-1]
@@ -45,9 +45,10 @@ def din_attention(hist, target_expand, max_len, mask):
     return out
-def network(item_count, cat_count, max_len):
+def network(item_count, cat_count):
     """network definition"""
+    seq_len = -1
     item_emb_size = 64
     cat_emb_size = 64
     is_sparse = False
@@ -56,60 +57,60 @@ def network(item_count, cat_count, max_len):
     item_emb_attr = fluid.ParamAttr(name="item_emb")
     cat_emb_attr = fluid.ParamAttr(name="cat_emb")
-    hist_item_seq = fluid.layers.data(
-        name="hist_item_seq", shape=[max_len, 1], dtype="int64")
+    hist_item_seq = fluid.data(
+        name="hist_item_seq", shape=[None, seq_len], dtype="int64")
-    hist_cat_seq = fluid.layers.data(
-        name="hist_cat_seq", shape=[max_len, 1], dtype="int64")
+    hist_cat_seq = fluid.data(
+        name="hist_cat_seq", shape=[None, seq_len], dtype="int64")
-    target_item = fluid.layers.data(
-        name="target_item", shape=[1], dtype="int64")
+    target_item = fluid.data(
+        name="target_item", shape=[None], dtype="int64")
-    target_cat = fluid.layers.data(
-        name="target_cat", shape=[1], dtype="int64")
+    target_cat = fluid.data(
+        name="target_cat", shape=[None], dtype="int64")
-    label = fluid.layers.data(
-        name="label", shape=[1], dtype="float32")
+    label = fluid.data(
+        name="label", shape=[None, 1], dtype="float32")
-    mask = fluid.layers.data(
-        name="mask", shape=[max_len, 1], dtype="float32")
+    mask = fluid.data(
+        name="mask", shape=[None, seq_len, 1], dtype="float32")
-    target_item_seq = fluid.layers.data(
-        name="target_item_seq", shape=[max_len, 1], dtype="int64")
+    target_item_seq = fluid.data(
+        name="target_item_seq", shape=[None, seq_len], dtype="int64")
-    target_cat_seq = fluid.layers.data(
-        name="target_cat_seq", shape=[max_len, 1], dtype="int64", lod_level=0)
+    target_cat_seq = fluid.data(
+        name="target_cat_seq", shape=[None, seq_len], dtype="int64")
-    hist_item_emb = fluid.layers.embedding(
+    hist_item_emb = fluid.embedding(
         input=hist_item_seq,
         size=[item_count, item_emb_size],
         param_attr=item_emb_attr,
         is_sparse=is_sparse)
-    hist_cat_emb = fluid.layers.embedding(
+    hist_cat_emb = fluid.embedding(
         input=hist_cat_seq,
         size=[cat_count, cat_emb_size],
         param_attr=cat_emb_attr,
         is_sparse=is_sparse)
-    target_item_emb = fluid.layers.embedding(
+    target_item_emb = fluid.embedding(
         input=target_item,
         size=[item_count, item_emb_size],
         param_attr=item_emb_attr,
         is_sparse=is_sparse)
-    target_cat_emb = fluid.layers.embedding(
+    target_cat_emb = fluid.embedding(
         input=target_cat,
         size=[cat_count, cat_emb_size],
         param_attr=cat_emb_attr,
         is_sparse=is_sparse)
-    target_item_seq_emb = fluid.layers.embedding(
+    target_item_seq_emb = fluid.embedding(
         input=target_item_seq,
         size=[item_count, item_emb_size],
         param_attr=item_emb_attr,
         is_sparse=is_sparse)
-    target_cat_seq_emb = fluid.layers.embedding(
+    target_cat_seq_emb = fluid.embedding(
         input=target_cat_seq,
         size=[cat_count, cat_emb_size],
         param_attr=cat_emb_attr,
         is_sparse=is_sparse)
-    item_b = fluid.layers.embedding(
+    item_b = fluid.embedding(
         input=target_item,
         size=[item_count, 1],
         param_attr=fluid.initializer.Constant(value=0.0))
@@ -120,7 +121,7 @@ def network(item_count, cat_count, max_len):
     target_concat = fluid.layers.concat(
         [target_item_emb, target_cat_emb], axis=1)
-    out = din_attention(hist_seq_concat, target_seq_concat, max_len, mask)
+    out = din_attention(hist_seq_concat, target_seq_concat, mask)
     out_fc = fluid.layers.fc(name="out_fc",
                              input=out,
                              size=item_emb_size + cat_emb_size,
......
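The hunks above replace fluid.layers.data / fluid.layers.embedding with fluid.data / fluid.embedding: the batch dimension becomes an explicit None and the fixed max_len becomes a variable sequence length (seq_len = -1), so the graph itself no longer depends on the padded length. Below is a minimal, self-contained sketch of that pattern, assuming a PaddlePaddle 1.6-era static-graph setup; the vocabulary size, embedding size, and variable names are illustrative and not taken from this patch.

import numpy as np
import paddle.fluid as fluid

# Placeholder with an explicit batch dim (None) and a run-time sequence
# length (-1); fluid.embedding looks the ids up in a [vocab, emb] table.
ids = fluid.data(name="ids", shape=[None, -1], dtype="int64")
emb = fluid.embedding(input=ids, size=[100, 8], is_sparse=False)
summed = fluid.layers.reduce_sum(emb, dim=1)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
for padded_len in (4, 7):  # the same program accepts differently padded batches
    batch = np.zeros((2, padded_len), dtype="int64")
    out, = exe.run(feed={"ids": batch}, fetch_list=[summed])
    print(out.shape)  # (2, 8) in both runs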
@@ -20,7 +20,7 @@ import pickle
 def pad_batch_data(input, max_len):
     res = np.array([x + [0] * (max_len - len(x)) for x in input])
-    res = res.astype("int64").reshape([-1, max_len, 1])
+    res = res.astype("int64").reshape([-1, max_len])
     return res
@@ -34,10 +34,10 @@ def make_data(b):
             [-1, max_len, 1])
     target_item_seq = np.array(
         [[x[2]] * max_len for x in b]).astype("int64").reshape(
-            [-1, max_len, 1])
+            [-1, max_len])
     target_cat_seq = np.array(
         [[x[3]] * max_len for x in b]).astype("int64").reshape(
-            [-1, max_len, 1])
+            [-1, max_len])
     res = []
     for i in range(len(b)):
         res.append([
......
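With the trailing dimension of 1 removed, pad_batch_data now returns plain [batch_size, max_len] int64 arrays, matching the [None, seq_len] shapes declared with fluid.data above. A small usage sketch with made-up histories (not data from the repository):

import numpy as np

def pad_batch_data(input, max_len):
    # Right-pad every id list with 0 and stack into a [batch, max_len] array.
    res = np.array([x + [0] * (max_len - len(x)) for x in input])
    res = res.astype("int64").reshape([-1, max_len])
    return res

batch = [[3, 7, 2], [5]]           # two hypothetical user histories
print(pad_batch_data(batch, 4))    # [[3 7 2 0]
                                   #  [5 0 0 0]]  -> shape (2, 4)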
@@ -78,7 +78,7 @@ def train():
                                                 args.num_devices)
     logger.info("reading data completes")
-    avg_cost, pred = network.network(item_count, cat_count, max_len)
+    avg_cost, pred = network.network(item_count, cat_count)
     fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByGlobalNorm(
         clip_norm=5.0))
     base_lr = args.base_lr
......
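The training entry point now builds the network without max_len; padding is handled entirely on the reader side. The gradient-clipping lines kept as context use the register-then-minimize pattern, sketched below on a tiny stand-in model (illustrative names and sizes, not the DIN model itself), assuming the same fluid static-graph API:

import numpy as np
import paddle.fluid as fluid

x = fluid.data(name="x", shape=[None, 4], dtype="float32")
y = fluid.data(name="y", shape=[None, 1], dtype="float32")
pred = fluid.layers.fc(input=x, size=1)
avg_cost = fluid.layers.reduce_mean(fluid.layers.square_error_cost(pred, y))

# Register global-norm clipping before minimize(); each gradient is rescaled
# so the combined norm of all gradients stays at or below clip_norm.
fluid.clip.set_gradient_clip(clip=fluid.clip.GradientClipByGlobalNorm(
    clip_norm=5.0))
fluid.optimizer.SGD(learning_rate=0.1).minimize(avg_cost)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
exe.run(feed={"x": np.random.rand(8, 4).astype("float32"),
              "y": np.random.rand(8, 1).astype("float32")},
        fetch_list=[avg_cost])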