提交 8e6ea314 编写于 作者: S suweiyue

fix position id

上级 cbf4a1a3
...@@ -6,9 +6,9 @@ optimizer_type: "adam" ...@@ -6,9 +6,9 @@ optimizer_type: "adam"
lr: 0.00005 lr: 0.00005
batch_size: 32 batch_size: 32
CPU_NUM: 10 CPU_NUM: 10
epoch: 20 epoch: 3
log_per_step: 1 log_per_step: 10
save_per_step: 100 save_per_step: 1000
output_path: "./output" output_path: "./output"
ckpt_path: "./ernie_base_ckpt" ckpt_path: "./ernie_base_ckpt"
...@@ -31,6 +31,7 @@ final_fc: true ...@@ -31,6 +31,7 @@ final_fc: true
final_l2_norm: true final_l2_norm: true
loss_type: "hinge" loss_type: "hinge"
margin: 0.3 margin: 0.3
neg_type: "random_neg"
# infer config ------ # infer config ------
infer_model: "./output/last" infer_model: "./output/last"
......
...@@ -86,6 +86,7 @@ class GraphGenerator(BaseDataGenerator): ...@@ -86,6 +86,7 @@ class GraphGenerator(BaseDataGenerator):
nodes = np.unique(np.concatenate([batch_src, batch_dst, batch_neg], 0)) nodes = np.unique(np.concatenate([batch_src, batch_dst, batch_neg], 0))
subgraphs = graphsage_sample(self.graph, nodes, self.samples, ignore_edges=ignore_edges) subgraphs = graphsage_sample(self.graph, nodes, self.samples, ignore_edges=ignore_edges)
#subgraphs[0].reindex_to_parrent_nodes(subgraphs[0].nodes)
feed_dict = {} feed_dict = {}
for i in range(self.num_layers): for i in range(self.num_layers):
feed_dict.update(self.graph_wrappers[i].to_feed(subgraphs[i])) feed_dict.update(self.graph_wrappers[i].to_feed(subgraphs[i]))
...@@ -97,7 +98,7 @@ class GraphGenerator(BaseDataGenerator): ...@@ -97,7 +98,7 @@ class GraphGenerator(BaseDataGenerator):
feed_dict["user_index"] = np.array(sub_src_idx, dtype="int64") feed_dict["user_index"] = np.array(sub_src_idx, dtype="int64")
feed_dict["item_index"] = np.array(sub_dst_idx, dtype="int64") feed_dict["item_index"] = np.array(sub_dst_idx, dtype="int64")
#feed_dict["neg_item_index"] = np.array(sub_neg_idx, dtype="int64") feed_dict["neg_item_index"] = np.array(sub_neg_idx, dtype="int64")
feed_dict["term_ids"] = self.term_ids[subgraphs[0].node_feat["index"]] feed_dict["term_ids"] = self.term_ids[subgraphs[0].node_feat["index"]]
return feed_dict return feed_dict
......
...@@ -72,7 +72,7 @@ def run_predict(py_reader, ...@@ -72,7 +72,7 @@ def run_predict(py_reader,
for batch_feed_dict in py_reader(): for batch_feed_dict in py_reader():
batch += 1 batch += 1
batch_usr_feat, batch_ad_feat, batch_src_real_index = exe.run( batch_usr_feat, batch_ad_feat, _, batch_src_real_index = exe.run(
program, program,
feed=batch_feed_dict, feed=batch_feed_dict,
fetch_list=model_dict.outputs) fetch_list=model_dict.outputs)
......
...@@ -193,6 +193,7 @@ class CollectiveLearner(Learner): ...@@ -193,6 +193,7 @@ class CollectiveLearner(Learner):
def optimize(self, loss, optimizer_type, lr): def optimize(self, loss, optimizer_type, lr):
optimizer = F.optimizer.Adam(learning_rate=lr) optimizer = F.optimizer.Adam(learning_rate=lr)
dist_strategy = DistributedStrategy() dist_strategy = DistributedStrategy()
dist_strategy.enable_sequential_execution = True
optimizer = cfleet.distributed_optimizer(optimizer, strategy=dist_strategy) optimizer = cfleet.distributed_optimizer(optimizer, strategy=dist_strategy)
_, param_grads = optimizer.minimize(loss, F.default_startup_program()) _, param_grads = optimizer.minimize(loss, F.default_startup_program())
......
...@@ -50,7 +50,6 @@ transpiler_local_train(){ ...@@ -50,7 +50,6 @@ transpiler_local_train(){
} }
collective_local_train(){ collective_local_train(){
export PATH=./python27-gcc482-gpu/bin/:$PATH
echo `which python` echo `which python`
python -m paddle.distributed.launch train.py --conf $config python -m paddle.distributed.launch train.py --conf $config
python -m paddle.distributed.launch infer.py --conf $config python -m paddle.distributed.launch infer.py --conf $config
...@@ -58,8 +57,7 @@ collective_local_train(){ ...@@ -58,8 +57,7 @@ collective_local_train(){
eval $(parse_yaml $config) eval $(parse_yaml $config)
python3 ./preprocessing/dump_graph.py -i $input_data -o $graph_path --encoding $encoding \ python ./preprocessing/dump_graph.py -i $input_data -o $graph_path --encoding $encoding -l $max_seqlen --vocab_file $ernie_vocab_file
-l $max_seqlen --vocab_file $ernie_vocab_file
if [[ $learner_type == "cpu" ]];then if [[ $learner_type == "cpu" ]];then
transpiler_local_train transpiler_local_train
......
...@@ -129,7 +129,9 @@ class BaseNet(object): ...@@ -129,7 +129,9 @@ class BaseNet(object):
"user_index", shape=[None], dtype="int64", append_batch_size=False) "user_index", shape=[None], dtype="int64", append_batch_size=False)
item_index = L.data( item_index = L.data(
"item_index", shape=[None], dtype="int64", append_batch_size=False) "item_index", shape=[None], dtype="int64", append_batch_size=False)
return [user_index, item_index] neg_item_index = L.data(
"neg_item_index", shape=[None], dtype="int64", append_batch_size=False)
return [user_index, item_index, neg_item_index]
def build_embedding(self, graph_wrappers, inputs=None): def build_embedding(self, graph_wrappers, inputs=None):
num_embed = int(np.load(os.path.join(self.config.graph_path, "num_nodes.npy"))) num_embed = int(np.load(os.path.join(self.config.graph_path, "num_nodes.npy")))
...@@ -177,18 +179,58 @@ class BaseNet(object): ...@@ -177,18 +179,58 @@ class BaseNet(object):
outputs.append(src_real_index) outputs.append(src_real_index)
return inputs, outputs return inputs, outputs
def all_gather(X):
    """Collect ``X`` from every trainer and concatenate the copies on axis 0.

    The number of trainers is read from the ``PADDLE_TRAINERS_NUM``
    environment variable. With a single trainer (or when the variable is
    unset) a detached copy of ``X`` is returned. Otherwise one copy of ``X``
    is passed through ``L.collective._broadcast`` per trainer index and the
    results are concatenated along axis 0.

    Args:
        X: the variable to gather. It must support ``X * 1`` and expose a
            writable ``stop_gradient`` attribute.

    Returns:
        A variable with ``stop_gradient`` set to True: either a copy of ``X``
        (single trainer) or the axis-0 concatenation of the per-trainer
        copies.
    """
    # Default to one trainer when the launcher did not export the variable.
    # The previous default of "0" left the list below empty and crashed on
    # ``Xs[0]``.
    trainer_num = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))

    if trainer_num <= 1:
        # ``X * 1`` creates a new variable so that detaching it does not
        # alter the gradient flow of ``X`` itself.
        copy_X = X * 1
        copy_X.stop_gradient = True
        return copy_X

    Xs = []
    for i in range(trainer_num):
        copy_X = X * 1
        # NOTE(review): presumably the second argument is the root trainer
        # and the third enables synchronous mode -- confirm against the
        # installed Paddle version, `_broadcast` is a private API.
        copy_X = L.collective._broadcast(copy_X, i, True)
        # The attribute is `stop_gradient` (singular); the former
        # `stop_gradients` spelling only created an unused attribute.
        copy_X.stop_gradient = True
        Xs.append(copy_X)

    # trainer_num > 1 here, so there are always at least two pieces.
    Xs = L.concat(Xs, 0)
    Xs.stop_gradient = True
    return Xs
class BaseLoss(object): class BaseLoss(object):
def __init__(self, config): def __init__(self, config):
self.config = config self.config = config
def __call__(self, outputs): def __call__(self, outputs):
user_feat, item_feat = outputs[0], outputs[1] user_feat, item_feat, neg_item_feat = outputs[0], outputs[1], outputs[2]
loss_type = self.config.loss_type loss_type = self.config.loss_type
if self.config.neg_type == "batch_neg":
neg_item_feat = item_feat
# Calc Loss # Calc Loss
if self.config.loss_type == "hinge": if self.config.loss_type == "hinge":
pos = L.reduce_sum(user_feat * item_feat, -1, keep_dim=True) # [B, 1] pos = L.reduce_sum(user_feat * item_feat, -1, keep_dim=True) # [B, 1]
neg = L.matmul(user_feat, item_feat, transpose_y=True) # [B, B] neg = L.matmul(user_feat, neg_item_feat, transpose_y=True) # [B, B]
loss = L.reduce_mean(L.relu(neg - pos + self.config.margin)) loss = L.reduce_mean(L.relu(neg - pos + self.config.margin))
elif self.config.loss_type == "all_hinge":
pos = L.reduce_sum(user_feat * item_feat, -1, keep_dim=True) # [B, 1]
all_pos = all_gather(pos) # [B * n, 1]
all_neg_item_feat = all_gather(neg_item_feat) # [B * n, 1]
all_user_feat = all_gather(user_feat) # [B * n, 1]
neg1 = L.matmul(user_feat, all_neg_item_feat, transpose_y=True) # [B, B * n]
neg2 = L.matmul(all_user_feat, neg_item_feat, transpose_y=True) # [B *n, B]
loss1 = L.reduce_mean(L.relu(neg1 - pos + self.config.margin))
loss2 = L.reduce_mean(L.relu(neg2 - all_pos + self.config.margin))
#loss = (loss1 + loss2) / 2
loss = loss1 + loss2
elif self.config.loss_type == "softmax": elif self.config.loss_type == "softmax":
pass pass
# TODO # TODO
......
...@@ -59,6 +59,8 @@ class ErnieModel(object): ...@@ -59,6 +59,8 @@ class ErnieModel(object):
def __init__(self, def __init__(self,
src_ids, src_ids,
sentence_ids, sentence_ids,
position_ids=None,
input_mask=None,
task_ids=None, task_ids=None,
config=None, config=None,
weight_sharing=True, weight_sharing=True,
...@@ -66,8 +68,10 @@ class ErnieModel(object): ...@@ -66,8 +68,10 @@ class ErnieModel(object):
name=""): name=""):
self._set_config(config, name, weight_sharing) self._set_config(config, name, weight_sharing)
input_mask = self._build_input_mask(src_ids) if position_ids is None:
position_ids = self._build_position_ids(src_ids) position_ids = self._build_position_ids(src_ids)
if input_mask is None:
input_mask = self._build_input_mask(src_ids)
self._build_model(src_ids, position_ids, sentence_ids, task_ids, self._build_model(src_ids, position_ids, sentence_ids, task_ids,
input_mask) input_mask)
self._debug_summary(input_mask) self._debug_summary(input_mask)
......
...@@ -3,8 +3,6 @@ import paddle.fluid as F ...@@ -3,8 +3,6 @@ import paddle.fluid as F
import paddle.fluid.layers as L import paddle.fluid.layers as L
from models.base import BaseNet, BaseGNNModel from models.base import BaseNet, BaseGNNModel
from models.ernie_model.ernie import ErnieModel from models.ernie_model.ernie import ErnieModel
from models.ernie_model.ernie import ErnieGraphModel
from models.ernie_model.ernie import ErnieConfig
class ErnieSageV2(BaseNet): class ErnieSageV2(BaseNet):
...@@ -16,19 +14,52 @@ class ErnieSageV2(BaseNet): ...@@ -16,19 +14,52 @@ class ErnieSageV2(BaseNet):
return inputs + [term_ids] return inputs + [term_ids]
def gnn_layer(self, gw, feature, hidden_size, act, initializer, learning_rate, name): def gnn_layer(self, gw, feature, hidden_size, act, initializer, learning_rate, name):
def build_position_ids(src_ids, dst_ids):
    """Build position ids for a concatenated [src ; dst] token sequence.

    Source positions count up from 0. Destination positions continue after
    the source sequence, shifted back by the number of padding tokens in the
    source, so that they follow the last non-pad source token.

    Args:
        src_ids: source token ids; the code reads dims 0 and 1 as batch size
            and sequence length and reshapes positions to a trailing dim of
            1, i.e. it assumes a [B, src_seqlen, 1] layout -- TODO confirm
            against the caller.
        dst_ids: destination token ids. Not read here; the destination
            length is taken to be ``src_seqlen - 1`` (source without its
            leading CLS token).

    Returns:
        An int64 variable of shape [B, src_seqlen + dst_seqlen, 1] with
        ``stop_gradient`` set to True.
    """
    src_shape = L.shape(src_ids)
    src_batch = src_shape[0]
    src_seqlen = src_shape[1]
    dst_seqlen = src_seqlen - 1  # without cls

    src_position_ids = L.reshape(
        L.range(0, src_seqlen, 1, dtype='int32'),
        [1, src_seqlen, 1],
        inplace=True)  # [1, src_seqlen, 1]
    src_position_ids = L.expand(
        src_position_ids, [src_batch, 1, 1])  # [B, src_seqlen, 1]

    # Count padding tokens per source sequence (assumes pad id == 0).
    zero = L.fill_constant([1], dtype='int64', value=0)
    pad_mask = L.cast(L.equal(src_ids, zero), "int32")  # same shape as src_ids
    # keep_dim=True keeps the result at rank 3 ([B, 1, 1]) so that it lines
    # up with the batch axis of dst_position_ids in the subtraction below;
    # without it the reduced axis is dropped and the operands no longer align.
    src_pad_len = L.reduce_sum(pad_mask, 1, keep_dim=True)  # [B, 1, 1]

    dst_position_ids = L.reshape(
        L.range(src_seqlen, src_seqlen + dst_seqlen, 1, dtype='int32'),
        [1, dst_seqlen, 1],
        inplace=True)  # [1, dst_seqlen, 1]
    dst_position_ids = L.expand(
        dst_position_ids, [src_batch, 1, 1])  # [B, dst_seqlen, 1]
    dst_position_ids = dst_position_ids - src_pad_len  # [B, dst_seqlen, 1]

    position_ids = L.concat([src_position_ids, dst_position_ids], 1)
    position_ids = L.cast(position_ids, 'int64')
    position_ids.stop_gradient = True
    return position_ids
def ernie_send(src_feat, dst_feat, edge_feat): def ernie_send(src_feat, dst_feat, edge_feat):
"""doc""" """doc"""
# input_ids
cls = L.fill_constant_batch_size_like(src_feat["term_ids"], [-1, 1, 1], "int64", 1) cls = L.fill_constant_batch_size_like(src_feat["term_ids"], [-1, 1, 1], "int64", 1)
src_ids = L.concat([cls, src_feat["term_ids"]], 1) src_ids = L.concat([cls, src_feat["term_ids"]], 1)
dst_ids = dst_feat["term_ids"] dst_ids = dst_feat["term_ids"]
# sent_ids
sent_ids = L.concat([L.zeros_like(src_ids), L.ones_like(dst_ids)], 1) sent_ids = L.concat([L.zeros_like(src_ids), L.ones_like(dst_ids)], 1)
term_ids = L.concat([src_ids, dst_ids], 1) term_ids = L.concat([src_ids, dst_ids], 1)
# position_ids
position_ids = build_position_ids(src_ids, dst_ids)
term_ids.stop_gradient = True term_ids.stop_gradient = True
sent_ids.stop_gradient = True sent_ids.stop_gradient = True
ernie = ErnieModel( ernie = ErnieModel(
term_ids, sent_ids, term_ids, sent_ids, position_ids,
config=self.config.ernie_config) config=self.config.ernie_config)
feature = ernie.get_pooled_output() feature = ernie.get_pooled_output()
return feature return feature
......
...@@ -18,7 +18,6 @@ import paddle.fluid.layers as L ...@@ -18,7 +18,6 @@ import paddle.fluid.layers as L
from models.base import BaseNet, BaseGNNModel from models.base import BaseNet, BaseGNNModel
from models.ernie_model.ernie import ErnieModel from models.ernie_model.ernie import ErnieModel
from models.ernie_model.ernie import ErnieGraphModel from models.ernie_model.ernie import ErnieGraphModel
from models.ernie_model.ernie import ErnieConfig
from models.message_passing import copy_send from models.message_passing import copy_send
......
...@@ -53,6 +53,7 @@ def dump_graph(args): ...@@ -53,6 +53,7 @@ def dump_graph(args):
term_file = io.open(os.path.join(args.outpath, "terms.txt"), "w", encoding=args.encoding) term_file = io.open(os.path.join(args.outpath, "terms.txt"), "w", encoding=args.encoding)
terms = [] terms = []
count = 0 count = 0
item_distribution = []
with io.open(args.inpath, encoding=args.encoding) as f: with io.open(args.inpath, encoding=args.encoding) as f:
edges = [] edges = []
...@@ -66,6 +67,7 @@ def dump_graph(args): ...@@ -66,6 +67,7 @@ def dump_graph(args):
str2id[s] = count str2id[s] = count
count += 1 count += 1
term_file.write(str(col_idx) + "\t" + col + "\n") term_file.write(str(col_idx) + "\t" + col + "\n")
item_distribution.append(0)
slots.append(str2id[s]) slots.append(str2id[s])
...@@ -74,6 +76,7 @@ def dump_graph(args): ...@@ -74,6 +76,7 @@ def dump_graph(args):
neg_samples.append(slots[2:]) neg_samples.append(slots[2:])
edges.append((src, dst)) edges.append((src, dst))
edges.append((dst, src)) edges.append((dst, src))
item_distribution[dst] += 1
term_file.close() term_file.close()
edges = np.array(edges, dtype="int64") edges = np.array(edges, dtype="int64")
...@@ -82,12 +85,14 @@ def dump_graph(args): ...@@ -82,12 +85,14 @@ def dump_graph(args):
log.info("building graph...") log.info("building graph...")
graph = pgl.graph.Graph(num_nodes=num_nodes, edges=edges) graph = pgl.graph.Graph(num_nodes=num_nodes, edges=edges)
indegree = graph.indegree() indegree = graph.indegree()
graph.indegree()
graph.outdegree() graph.outdegree()
graph.dump(args.outpath) graph.dump(args.outpath)
# dump alias sample table # dump alias sample table
sqrt_indegree = np.sqrt(indegree) item_distribution = np.array(item_distribution)
distribution = 1. * sqrt_indegree / sqrt_indegree.sum() item_distribution = np.sqrt(item_distribution)
distribution = 1. * item_distribution / item_distribution.sum()
alias, events = alias_sample_build_table(distribution) alias, events = alias_sample_build_table(distribution)
np.save(os.path.join(args.outpath, "alias.npy"), alias) np.save(os.path.join(args.outpath, "alias.npy"), alias)
np.save(os.path.join(args.outpath, "events.npy"), events) np.save(os.path.join(args.outpath, "events.npy"), events)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册