提交 80b01801 编写于 作者: H hutuxian 提交者: Yi Liu

Upgrade API for gnn (#3503)

上级 734812c3
......@@ -55,7 +55,7 @@ def infer(args):
test_data = reader.Data(args.test_path, False)
place = fluid.CUDAPlace(0) if args.use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
loss, acc, py_reader, feed_datas = network.network(items_num, args.hidden_size, args.step)
loss, acc, py_reader, feed_datas = network.network(items_num, args.hidden_size, args.step, batch_size)
exe.run(fluid.default_startup_program())
infer_program = fluid.default_main_program().clone(for_test=True)
......
......@@ -19,36 +19,36 @@ import paddle.fluid as fluid
import paddle.fluid.layers as layers
def network(items_num, hidden_size, step):
def network(items_num, hidden_size, step, bs):
stdv = 1.0 / math.sqrt(hidden_size)
items = layers.data(
items = fluid.data(
name="items",
shape=[1, 1],
dtype="int64") #[batch_size, uniq_max, 1]
seq_index = layers.data(
shape=[bs, -1],
dtype="int64") #[batch_size, uniq_max]
seq_index = fluid.data(
name="seq_index",
shape=[1],
dtype="int32") #[batch_size, seq_max]
last_index = layers.data(
shape=[bs, -1, 2],
dtype="int32") #[batch_size, seq_max, 2]
last_index = fluid.data(
name="last_index",
shape=[1],
dtype="int32") #[batch_size, 1]
adj_in = layers.data(
shape=[bs, 2],
dtype="int32") #[batch_size, 2]
adj_in = fluid.data(
name="adj_in",
shape=[1,1],
shape=[bs, -1, -1],
dtype="float32") #[batch_size, seq_max, seq_max]
adj_out = layers.data(
adj_out = fluid.data(
name="adj_out",
shape=[1,1],
shape=[bs, -1, -1],
dtype="float32") #[batch_size, seq_max, seq_max]
mask = layers.data(
mask = fluid.data(
name="mask",
shape=[1, 1],
shape=[bs, -1, 1],
dtype="float32") #[batch_size, seq_max, 1]
label = layers.data(
label = fluid.data(
name="label",
shape=[1],
shape=[bs, 1],
dtype="int64") #[batch_size, 1]
datas = [items, seq_index, last_index, adj_in, adj_out, mask, label]
......@@ -57,19 +57,17 @@ def network(items_num, hidden_size, step):
feed_datas = fluid.layers.read_file(py_reader)
items, seq_index, last_index, adj_in, adj_out, mask, label = feed_datas
items_emb = layers.embedding(
items_emb = fluid.embedding(
input=items,
param_attr=fluid.ParamAttr(
name="emb",
initializer=fluid.initializer.Uniform(
low=-stdv, high=stdv)),
size=[items_num, hidden_size]) #[batch_size, uniq_max, h]
items_emb_shape = layers.shape(items_emb)
pre_state = items_emb
for i in range(step):
pre_state = layers.reshape(
x=pre_state, shape=[-1, 1, hidden_size], actual_shape=items_emb_shape)
pre_state = layers.reshape(x=pre_state, shape=[bs, -1, hidden_size])
state_in = layers.fc(
input=pre_state,
name="state_in",
......@@ -104,24 +102,12 @@ def network(items_num, hidden_size, step):
bias_attr=False)
pre_state, _, _ = fluid.layers.gru_unit(
input=gru_fc,
hidden=layers.reshape(
x=pre_state, shape=[-1, hidden_size]),
hidden=layers.reshape(x=pre_state, shape=[-1, hidden_size]),
size=3 * hidden_size)
final_state = pre_state #[batch_size * uniq_max, h]
seq_origin_shape = layers.assign(np.array([0,0,hidden_size-1]).astype("int32"))
seq_origin_shape += layers.shape(layers.unsqueeze(seq_index,[2])) #value: [batch_size, seq_max, h]
seq_origin_shape.stop_gradient = True
seq_index = layers.reshape(seq_index, shape=[-1])
seq = layers.gather(final_state, seq_index) #[batch_size * seq_max, h]
last = layers.gather(final_state, last_index) #[batch_size, h]
seq = layers.reshape(
seq, shape=[-1, 1, hidden_size], actual_shape=seq_origin_shape) #[batch_size, seq_max, h]
last = layers.reshape(
last, shape=[-1, hidden_size]) #[batch_size, h]
final_state = layers.reshape(pre_state, shape=[bs, -1, hidden_size])
seq = layers.gather_nd(final_state, seq_index)
last = layers.gather_nd(final_state, last_index)
seq_fc = layers.fc(
input=seq,
......@@ -184,13 +170,13 @@ def network(items_num, hidden_size, step):
low=-stdv, high=stdv))) #[batch_size, h]
all_vocab = layers.create_global_var(
shape=[items_num - 1, 1],
shape=[items_num - 1],
value=0,
dtype="int64",
persistable=True,
name="all_vocab")
all_emb = layers.embedding(
all_emb = fluid.embedding(
input=all_vocab,
param_attr=fluid.ParamAttr(
name="emb",
......
......@@ -64,19 +64,19 @@ class Data():
adj_out.append(np.divide(adj.transpose(), u_deg_out).transpose())
seq_index.append(
[np.where(node == i)[0][0] + id * max_uniq_len for i in e[0]])
[[id, np.where(node == i)[0][0]] for i in e[0]])
last_index.append(
np.where(node == e[0][last_id[id]])[0][0] + id * max_uniq_len)
[id, np.where(node == e[0][last_id[id]])[0][0]])
label.append(e[1] - 1)
mask.append([[1] * (last_id[id] + 1) + [0] *
(max_seq_len - last_id[id] - 1)])
id += 1
items = np.array(items).astype("int64").reshape((batch_size, -1, 1))
items = np.array(items).astype("int64").reshape((batch_size, -1))
seq_index = np.array(seq_index).astype("int32").reshape(
(batch_size, -1))
(batch_size, -1, 2))
last_index = np.array(last_index).astype("int32").reshape(
(batch_size))
(batch_size, 2))
adj_in = np.array(adj_in).astype("float32").reshape(
(batch_size, max_uniq_len, max_uniq_len))
adj_out = np.array(adj_out).astype("float32").reshape(
......@@ -110,8 +110,10 @@ class Data():
cur_batch = remain_data[i:i + batch_size]
yield self.make_data(cur_batch, batch_size)
else:
cur_batch = remain_data[i:]
yield self.make_data(cur_batch, group_remain % batch_size)
# Due to fixed batch_size, discard the remaining ins
return
#cur_batch = remain_data[i:]
#yield self.make_data(cur_batch, group_remain % batch_size)
return _reader
......
......@@ -72,7 +72,7 @@ def train():
batch_size = args.batch_size
items_num = reader.read_config(args.config_path)
loss, acc, py_reader, feed_datas = network.network(items_num, args.hidden_size,
args.step)
args.step, batch_size)
data_reader = reader.Data(args.train_path, True)
logger.info("load data complete")
......@@ -96,7 +96,7 @@ def train():
all_vocab = fluid.global_scope().var("all_vocab").get_tensor()
all_vocab.set(
np.arange(1, items_num).astype("int64").reshape((-1, 1)), place)
np.arange(1, items_num).astype("int64").reshape((-1)), place)
feed_list = [e.name for e in feed_datas]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册