Commit 71d07ee9 authored by suweiyue

1. Remove id2str.npy; 2. store term_ids as uint16; 3. fix the ErnieSageV3 shape bug

Parent f78f4639
......@@ -49,7 +49,7 @@ ernie_config:
max_position_embeddings: 513
num_attention_heads: 12
num_hidden_layers: 12
- sent_type_vocab_size: 4
+ sent_type_vocab_size: 2
task_type_vocab_size: 3
vocab_size: 18000
use_task_id: false
......
......@@ -49,7 +49,7 @@ ernie_config:
max_position_embeddings: 513
num_attention_heads: 12
num_hidden_layers: 12
- sent_type_vocab_size: 4
+ sent_type_vocab_size: 2
task_type_vocab_size: 3
vocab_size: 18000
use_task_id: false
......
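Both config files get the same change: `sent_type_vocab_size` sizes the sentence-type (segment) embedding table, so every segment id fed to ERNIE must lie in `[0, sent_type_vocab_size)`. A minimal sketch of that constraint, using values from this config (the inline dict below is illustrative, not the repo's yaml loader):

```python
# Hedged sketch of what sent_type_vocab_size controls: it is the size of the
# sentence-type (segment) embedding table, so segment ids must stay below it.
# The dict mirrors the yaml keys above; the batch of ids is toy data.
import numpy as np

ernie_config = {"sent_type_vocab_size": 2, "vocab_size": 18000}

sent_ids = np.zeros([4, 40], dtype="int64")        # toy batch: all tokens in segment 0
assert sent_ids.max() < ernie_config["sent_type_vocab_size"], \
    "segment id out of range for the sentence-type embedding table"
```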
......@@ -80,7 +80,7 @@ class GraphGenerator(BaseDataGenerator):
batch_neg = sampled_batch_neg
if self.phase == "train":
- ignore_edges = set()
+ ignore_edges = np.concatenate([np.stack([batch_src, batch_dst], 1), np.stack([batch_dst, batch_src], 1)], 0)
else:
ignore_edges = set()
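During training the positive `(src, dst)` pairs and their reverses are collected so that the very edges being predicted can be excluded when neighbors are sampled for the subgraph; at inference nothing is excluded. A small sketch of the array being built (the `graphsage_sample(..., ignore_edges=...)` call it would typically feed is assumed from PGL's sampling API and left commented out):

```python
# Hedged sketch of the ignore_edges construction above: both directions of
# every positive edge are stacked into a [2 * batch_size, 2] array so the
# neighbor sampler can skip them and the label edge does not leak into the
# subgraph.
import numpy as np

batch_src = np.array([0, 1, 4], dtype="int64")
batch_dst = np.array([2, 3, 5], dtype="int64")

ignore_edges = np.concatenate(
    [np.stack([batch_src, batch_dst], 1),          # (src, dst)
     np.stack([batch_dst, batch_src], 1)], 0)      # (dst, src)
assert ignore_edges.shape == (2 * len(batch_src), 2)

# subgraphs = graphsage_sample(graph, nodes, samples, ignore_edges=ignore_edges)
```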
......@@ -99,7 +99,7 @@ class GraphGenerator(BaseDataGenerator):
feed_dict["user_index"] = np.array(sub_src_idx, dtype="int64")
feed_dict["item_index"] = np.array(sub_dst_idx, dtype="int64")
feed_dict["neg_item_index"] = np.array(sub_neg_idx, dtype="int64")
feed_dict["term_ids"] = self.term_ids[subgraphs[0].node_feat["index"]]
feed_dict["term_ids"] = self.term_ids[subgraphs[0].node_feat["index"]].astype(np.int64)
return feed_dict
def __call__(self):
......
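The generator indexes the memory-mapped `term_ids` array (now stored as uint16 by the dump script further below) and widens the slice to int64 only when the batch is fed, since Paddle embedding lookups expect int64 ids. A toy round trip with illustrative file names:

```python
# Hedged sketch of the uint16 -> int64 round trip done per batch: a vocab of
# 18000 fits comfortably in uint16, so the on-disk array is 4x smaller than
# int64, and the cast back happens only on the rows actually fed.
import numpy as np

term_ids = np.random.randint(0, 18000, size=(1000, 40)).astype(np.uint16)
np.save("term_ids.npy", term_ids)                        # illustrative path

loaded = np.load("term_ids.npy", mmap_mode="r")          # stays uint16 on disk
node_index = np.array([0, 5, 7])                         # e.g. subgraph node_feat["index"]
batch = loaded[node_index].astype(np.int64)              # widen at feed time
assert batch.dtype == np.int64 and batch.shape == (3, 40)
```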
......@@ -59,8 +59,7 @@ def run_predict(py_reader,
log_per_step=1,
args=None):
if args.input_type == "text":
- id2str = np.load(os.path.join(args.graph_path, "id2str.npy"), mmap_mode="r")
+ id2str = io.open(os.path.join(args.graph_path, "terms.txt"), encoding=args.encoding).readlines()
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
trainer_count = int(os.getenv("PADDLE_TRAINERS_NUM", "1"))
......@@ -82,7 +81,7 @@ def run_predict(py_reader,
for ufs, _, sri in zip(batch_usr_feat, batch_ad_feat, batch_src_real_index):
if args.input_type == "text":
- sri = id2str[int(sri)]
+ sri = id2str[int(sri)].strip("\n")
line = "{}\t{}\n".format(sri, tostr(ufs))
fout.write(line)
......
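With `id2str.npy` gone, prediction reads `terms.txt` directly; `readlines()` keeps the trailing newline on every term, which is why the extra `strip("\n")` is needed before the term is written on the same line as its embedding. A self-contained illustration (the list stands in for the file contents):

```python
# Minimal illustration of the strip("\n") fix: readlines() keeps newlines, and
# without the strip the "<term>\t<embedding>" output would be split across two
# lines. The list below is toy data standing in for terms.txt.
id2str = ["paddle graph learning\n", "ernie sage\n"]   # what readlines() would return

sri = id2str[1].strip("\n")                 # drop the newline kept by readlines()
line = "{}\t{}\n".format(sri, "0.1 0.2 0.3")
assert "\n\t" not in line and line.count("\n") == 1
```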
......@@ -342,7 +342,7 @@ class ErnieGraphModel(ErnieModel):
L.range(
0, slot_seqlen, 1, dtype='int32'), [1, slot_seqlen, 1],
inplace=True) # [1, slot_seqlen, 1]
- a_position_ids = L.expand(a_position_ids, [src_batch, 1, 1]) # [B, slot_seqlen * num_b, 1]
+ a_position_ids = L.expand(a_position_ids, [src_batch, 1, 1]) # [B, slot_seqlen, 1]
zero = L.fill_constant([1], dtype='int64', value=0)
input_mask = L.cast(L.equal(src_ids[:, :slot_seqlen], zero), "int32") # assume pad id == 0 [B, slot_seqlen, 1]
......
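The comment is corrected to match what the code actually produces: position ids for the first slot are built once with shape `[1, slot_seqlen, 1]` and then tiled across the batch, giving `[B, slot_seqlen, 1]`. A numpy stand-in for the fluid ops, to make the shapes concrete:

```python
# Hedged numpy stand-in for the L.range / L.reshape / L.expand sequence above,
# checking the corrected shape comment: tiling over the batch dimension yields
# [src_batch, slot_seqlen, 1], not [B, slot_seqlen * num_b, 1].
import numpy as np

slot_seqlen, src_batch = 40, 8
a_position_ids = np.arange(slot_seqlen, dtype="int32").reshape([1, slot_seqlen, 1])
a_position_ids = np.tile(a_position_ids, [src_batch, 1, 1])   # numpy analogue of L.expand
assert a_position_ids.shape == (src_batch, slot_seqlen, 1)
```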
......@@ -455,18 +455,6 @@ def graph_encoder(enc_input,
attn_bias = build_graph_attn_bias(input_mask, n_head, enc_input.dtype, slot_seqlen)
#attn_bias = build_attn_bias(input_mask, n_head, enc_input.dtype)
- # d_batch = d_shape[0]
- # d_seqlen = d_shape[1]
- # pad_idx = L.where(
- # L.cast(L.reshape(input_mask, [d_batch, d_seqlen]), 'bool'))
- # attn_bias = L.matmul(
- # input_mask, input_mask, transpose_y=True) # [batch, seq, seq]
- # attn_bias = (1. - attn_bias) * -10000.
- # attn_bias = L.stack([attn_bias] * n_head, 1)
- # if attn_bias.dtype != enc_input.dtype:
- # attn_bias = L.cast(attn_bias, enc_input.dtype)
def to_2d(t_3d):
t_2d = L.gather_nd(t_3d, pad_idx)
return t_2d
......
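The dead commented-out block is dropped; it described the dense attention-bias recipe that `build_graph_attn_bias` now encapsulates in a slot-aware form: pairwise visibility from the padding mask, turned into a large negative bias and repeated per head. A hedged numpy sketch of that recipe, with shapes and constants taken from the deleted comments:

```python
# Hedged numpy sketch of the attention-bias recipe the deleted comments
# described; build_graph_attn_bias is assumed to implement a slot-aware
# variant of the same idea.
import numpy as np

batch, seqlen, n_head = 2, 5, 4
input_mask = np.array([[1, 1, 1, 0, 0],
                       [1, 1, 1, 1, 0]], dtype="float32").reshape([batch, seqlen, 1])

attn_bias = np.matmul(input_mask, input_mask.transpose([0, 2, 1]))   # [batch, seq, seq]
attn_bias = (1.0 - attn_bias) * -10000.0                             # pad positions -> -1e4
attn_bias = np.stack([attn_bias] * n_head, 1)                        # [batch, n_head, seq, seq]
assert attn_bias.shape == (batch, n_head, seqlen, seqlen)
```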
......@@ -24,7 +24,6 @@ from models.message_passing import copy_send
class ErnieSageV3(BaseNet):
def __init__(self, config):
super(ErnieSageV3, self).__init__(config)
- self.config.layer_type = "ernie_recv_sum"
def build_inputs(self):
inputs = super(ErnieSageV3, self).build_inputs()
......@@ -35,11 +34,10 @@ class ErnieSageV3(BaseNet):
def gnn_layer(self, gw, feature, hidden_size, act, initializer, learning_rate, name):
def ernie_recv(feat):
"""doc"""
- # TODO maxlen 400
- #pad_value = L.cast(L.assign(input=np.array([0], dtype=np.int32)), "int64")
+ num_neighbor = self.config.samples[0]
pad_value = L.zeros([1], "int64")
- out, _ = L.sequence_pad(feat, pad_value=pad_value, maxlen=10)
- out = L.reshape(out, [0, 400])
+ out, _ = L.sequence_pad(feat, pad_value=pad_value, maxlen=num_neighbor)
+ out = L.reshape(out, [0, self.config.max_seqlen*num_neighbor])
return out
def erniesage_v3_aggregator(gw, feature, hidden_size, act, initializer, learning_rate, name):
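The hunk above is the "v3 shape bug" fix from the commit message: `ernie_recv` pads each destination node's bag of neighbor token sequences to exactly `samples[0]` neighbors and flattens them, so the hard-coded `maxlen=10` / `[0, 400]` is replaced by values derived from the config. A numpy sketch of the resulting tensor shape (toy sizes, not the real config):

```python
# Hedged shape sketch of the fixed ernie_recv: a variable number of neighbor
# token sequences per node is padded with zeros to num_neighbor sequences and
# flattened, giving a fixed [num_nodes, max_seqlen * num_neighbor] tensor for
# the downstream ERNIE encoder. Sizes are toy values.
import numpy as np

max_seqlen, num_neighbor = 40, 10            # config.max_seqlen, config.samples[0]
neighbors_per_node = [3, 7]                  # variable fan-in per destination node

rows = []
for n in neighbors_per_node:
    feats = np.ones([n, max_seqlen], dtype="int64")                   # received term ids
    pad = np.zeros([num_neighbor - n, max_seqlen], dtype="int64")     # pad_value = 0
    rows.append(np.concatenate([feats, pad], 0).reshape([-1]))
out = np.stack(rows, 0)
assert out.shape == (len(neighbors_per_node), max_seqlen * num_neighbor)
```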
......@@ -73,7 +71,7 @@ class ErnieSageV3(BaseNet):
act,
initializer,
learning_rate=fc_lr,
name="%s_%s" % (self.config.layer_type, i))
name="%s_%s" % ("erniesage_v3", i))
features.append(feature)
return features
......@@ -85,17 +83,16 @@ class ErnieSageV3(BaseNet):
ernie = ErnieGraphModel(
src_ids=feat,
config=ernie_config,
- slot_seqlen=self.config.max_seqlen,
- name="student_")
+ slot_seqlen=self.config.max_seqlen)
feat = ernie.get_pooled_output()
fc_lr = self.config.lr / 0.001
- feat= L.fc(feat,
- self.config.hidden_size,
- act="relu",
- param_attr=F.ParamAttr(name=name + "_l",
- learning_rate=fc_lr),
- )
- feat = L.l2_normalize(feat, axis=1)
+ # feat = L.fc(feat,
+ #             self.config.hidden_size,
+ #             act="relu",
+ #             param_attr=F.ParamAttr(name=name + "_l",
+ #                                    learning_rate=fc_lr),
+ #             )
+ #feat = L.l2_normalize(feat, axis=1)
if self.config.final_fc:
feat = L.fc(feat,
......
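With the intermediate fc and `l2_normalize` commented out (and the `student_` prefix dropped from `ErnieGraphModel`), the pooled ERNIE output flows through unchanged, and a projection is applied only when `config.final_fc` is set. A minimal sketch of that simplified head under those assumptions (parameter names are illustrative, fluid API as used in this file):

```python
# Hedged sketch of the simplified aggregation head after this change: the
# pooled ERNIE feature is used as-is and only optionally projected. Parameter
# names below are illustrative, not the repo's.
import paddle.fluid as F
import paddle.fluid.layers as L

def take_final_feature_sketch(feat, config, name):
    if config.final_fc:                                   # single optional projection
        feat = L.fc(feat,
                    config.hidden_size,
                    param_attr=F.ParamAttr(name=name + "_final_fc"),
                    name=name + "_final_fc")
    return feat
```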
......@@ -36,7 +36,7 @@ from tokenization import FullTokenizer
def term2id(string, tokenizer, max_seqlen):
- string = string.split("\t")[1]
+ #string = string.split("\t")[1]
tokens = tokenizer.tokenize(string)
ids = tokenizer.convert_tokens_to_ids(tokens)
ids = ids[:max_seqlen-1]
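`term2id` no longer splits off the `"<id>\t<term>"` prefix itself, because its caller (`dump_node_feat` below) now passes the term text directly; the function then tokenizes and truncates to `max_seqlen-1`. A hedged usage sketch, mirroring the `pool.map(partial(...))` call below (vocab path and input text are toy values; `FullTokenizer` comes from the repo's `tokenization` module):

```python
# Hedged usage sketch of term2id after the change: the caller passes the raw
# term text (dump_node_feat already splits the "<id>\t<term>" line), so the
# in-function split("\t")[1] is disabled. term2id is the function defined
# above in this module; vocab path and text are toy values.
from functools import partial
from tokenization import FullTokenizer

tokenizer = FullTokenizer("vocab.txt")                     # hypothetical vocab path
to_ids = partial(term2id, tokenizer=tokenizer, max_seqlen=40)
ids = to_ids("graph neural networks with ernie")           # term text, no "<id>\t" prefix
assert len(ids) <= 40
```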
......@@ -99,19 +99,13 @@ def dump_graph(args):
np.save(os.path.join(args.outpath, "neg_samples.npy"), np.array(neg_samples))
log.info("End Build Graph")
- def dump_id2str_map(args):
- log.info("Dump id2str map starting...")
- id2str = np.array([line.strip("\n") for line in open(os.path.join(args.outpath, "terms.txt"), "r", encoding=args.encoding)])
- np.save(os.path.join(args.outpath, "id2str.npy"), id2str)
- log.info("Dump id2str map done.")
def dump_node_feat(args):
log.info("Dump node feat starting...")
- id2str = np.load(os.path.join(args.outpath, "id2str.npy"), mmap_mode="r")
+ id2str = [line.strip("\n").split("\t")[1] for line in io.open(os.path.join(args.outpath, "terms.txt"), encoding=args.encoding)]
pool = multiprocessing.Pool()
tokenizer = FullTokenizer(args.vocab_file)
term_ids = pool.map(partial(term2id, tokenizer=tokenizer, max_seqlen=args.max_seqlen), id2str)
np.save(os.path.join(args.outpath, "term_ids.npy"), np.array(term_ids))
np.save(os.path.join(args.outpath, "term_ids.npy"), np.array(term_ids, np.uint16))
log.info("Dump node feat done.")
pool.terminate()
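Casting `term_ids` to uint16 before saving quarters the file size relative to int64, and it is lossless only while every token id fits in 16 bits, which holds here because `vocab_size` is 18000. A quick check of that condition:

```python
# Hedged sketch of the condition behind the uint16 cast: exact as long as the
# largest token id is below 2**16, which a vocab_size of 18000 satisfies; a
# larger vocabulary would silently wrap.
import numpy as np

vocab_size = 18000
assert vocab_size <= np.iinfo(np.uint16).max + 1          # 65536 representable ids

term_ids = np.array([[17999, 3, 0], [25, 1, 2]])           # toy ids from the tokenizer
compact = term_ids.astype(np.uint16)                       # what np.array(term_ids, np.uint16) stores
assert (compact.astype(np.int64) == term_ids).all()        # round trip is exact
```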
......@@ -124,5 +118,4 @@ if __name__ == "__main__":
parser.add_argument("-o", "--outpath", type=str, default=None)
args = parser.parse_args()
dump_graph(args)
- dump_id2str_map(args)
dump_node_feat(args)