Commit 432abdc0 authored by: R root

run tdm

Parent ce56b608
# -*- coding=utf-8 -*-
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import six
import argparse
def str2bool(v):
"""
str2bool
"""
# because argparse does not support to parse "true, False" as python
# boolean directly
return v.lower() in ("true", "t", "1")
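# For reference: str2bool("True"), str2bool("t") and str2bool("1") evaluate to True;
# any other spelling (e.g. "False", "0") evaluates to False.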
class ArgumentGroup(object):
"""
ArgumentGroup
"""
def __init__(self, parser, title, des):
"""
init
"""
self._group = parser.add_argument_group(title=title, description=des)
def add_arg(self, name, type, default, help, **kwargs):
"""
add_arg
"""
type = str2bool if type == bool else type
# if type == list: # by dwk
# self._group.add_argument("--" + name, nargs='+', type=int)
# else:
self._group.add_argument(
"--" + name,
default=default,
type=type,
help=help + ' Default: %(default)s.',
**kwargs)
def parse_args():
"""
parse_args
"""
# global
parser = argparse.ArgumentParser("main")
main_g = ArgumentGroup(parser, "main", "global conf")
main_g.add_arg("random_seed", int, 0, "random_seed")
main_g.add_arg("cpu_num", int, 1, "cpu_num")
main_g.add_arg("is_local", bool, False,
"whether to perform local training")
main_g.add_arg("is_cloud", bool, False, "")
main_g.add_arg("is_test", bool, False, "")
main_g.add_arg("sync_mode", str, "async","distributed traing mode")
main_g.add_arg("need_trace", bool, False, "")
main_g.add_arg("need_detail", bool, False, "")
# model
model_g = ArgumentGroup(
parser, "model", "options to init, resume and save model.")
model_g.add_arg("epoch_num", int, 3, "number of epochs for train")
model_g.add_arg("batch_size", int, 16, "batch size for train")
model_g.add_arg("learning_rate", float, 5e-5,
"learning rate for global training")
model_g.add_arg("layer_size", int, 4, "layer size")
model_g.add_arg("node_nums", int, 26, "tree node nums")
model_g.add_arg("node_emb_size", int, 64, "node embedding size")
model_g.add_arg("query_emb_size", int, 768, "input query embedding size")
model_g.add_arg("neg_sampling_list", list, [1, 2, 3, 4], "nce sample nums at every layer")
model_g.add_arg("child_nums", int, 2, "child node of ancestor node")
model_g.add_arg("topK", int, 2, "best recall result nums")
model_g = ArgumentGroup(
parser, "path", "files path of data & model.")
model_g.add_arg("train_files_path", str, "./data/train", "train data path")
model_g.add_arg("test_files_path", str, "./data/test", "test data path")
model_g.add_arg("model_files_path", str, "./models", "model data path")
# build tree and warm up
model_g.add_arg("build_tree_init_path", str,
"./data/gen_tree/demo_fake_input.txt", "build tree embedding path")
model_g.add_arg("warm-up", bool, False,
"warm up, builing new tree.")
model_g.add_arg("rebuild_tree_per_epochs", int, -1,
"re-build tree per epochs, -1 means don't re-building")
model_g.add_arg("tree_info_init_path", str,
"./thirdparty/tree_info.txt", "embedding file path")
model_g.add_arg("tree_travel_init_path", str,
"./thirdparty/travel_list.txt", "TDM tree travel file path")
model_g.add_arg("tree_layer_init_path", str,
"./thirdparty/layer_list.txt", "TDM tree layer file path")
model_g.add_arg("tree_id2item_init_path", str,
"./thirdparty/id2item.json", "item_id to item(feasign) mapping file path")
model_g.add_arg("load_model", bool, False,
"whether load model(paddle persistables model)")
model_g.add_arg("save_init_model", bool, False,
"whether save init model(paddle persistables model)")
model_g.add_arg("init_model_files_path", str, "./models/epoch_0",
"init model params by paddle model files for training")
args = parser.parse_args()
return args
def print_arguments(args):
"""
print arguments
"""
print('----------- Configuration Arguments -----------')
for arg, value in sorted(six.iteritems(vars(args))):
print('%s: %s' % (arg, value))
print('------------------------------------------------')
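A minimal sketch of how these helpers might be driven from a training entry script (the script name and flag values below are illustrative assumptions, not taken from this commit):

# Hypothetical usage sketch, e.g. saved as run_args_demo.py and invoked with
#   python run_args_demo.py --epoch_num 1 --cpu_num 2 --is_local True
from args import parse_args, print_arguments

if __name__ == "__main__":
    args = parse_args()       # parse the flags registered above
    print_arguments(args)     # dump the resolved configuration
    print(args.epoch_num, args.batch_size, args.is_local)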
# -*- coding=utf8 -*-
"""
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
import json
import pickle
import time
import random
import os
import numpy as np
import sys
import paddle.fluid.incubate.data_generator as dg
class TDMDataset(dg.MultiSlotStringDataGenerator):
"""
DacDataset: inheritance MultiSlotDataGeneratior, Implement data reading
Help document: http://wiki.baidu.com/pages/viewpage.action?pageId=728820675
"""
    def infer_reader(self, infer_file_list, batch):
        """
        Read test_data line by line & yield batch
        """
        def local_iter():
            """Read file line by line"""
            for fname in infer_file_list:
                with open(fname, "r") as fin:
                    for line in fin:
                        one_data = (line.strip('\n')).split('\t')
                        input_emb = one_data[0].split(' ')
                        yield [input_emb]

        import paddle
        batch_iter = paddle.batch(local_iter, batch)
        return batch_iter
    def generate_sample(self, line):
        """
        Read the data line by line and process it as a dictionary
        """
        def iterator():
            """
            This function needs to be implemented by the user, based on data format
            """
            features = (line.strip('\n')).split('\t')
            input_emb = features[0].split(' ')
            item_label = [features[1]]
            feature_name = ["input_emb", "item_label"]
            yield zip(feature_name, [input_emb] + [item_label])

        return iterator


if __name__ == "__main__":
    d = TDMDataset()
    d.run_from_stdin()
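A brief, hedged note on how this generator is typically driven; the file names, paths, and sample line below are illustrative assumptions, not taken from this commit:

# Hypothetical usage sketch (script name and data paths are assumptions):
#
# 1) Training side: Paddle's dataset pipeline feeds raw lines to this script
#    over stdin, roughly equivalent to
#        cat ./data/train/part-0 | python dataset_generator.py
#    where each line looks like "<space-separated query emb>\t<item_label>".
#
# 2) Inference side: consume the batched reader in-process, e.g.
#        reader = TDMDataset().infer_reader(["./data/test/part-0"], batch=16)
#        for batch_data in reader():
#            # each sample is [input_emb], a list holding one list of strings
#            print(len(batch_data), len(batch_data[0][0]))
#            break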
@@ -25,7 +25,7 @@ import paddle
import paddle.fluid as fluid
from args import print_arguments, parse_args
-from utils import tdm_sampler_prepare, tdm_child_prepare, save_item_emb, gen_tree_main
+from utils import tdm_sampler_prepare, tdm_child_prepare
from train_network import TdmTrainNet
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s")
0,0,0,1,2
0,1,0,3,4
0,1,0,5,6
0,2,1,7,8
0,2,1,9,10
0,2,2,11,12
0,2,2,13,0
0,3,3,14,15
0,3,3,16,17
0,3,4,18,19
0,3,4,20,21
0,3,5,22,23
0,3,5,24,25
12,3,6,0,0
0,4,7,0,0
1,4,7,0,0
2,4,8,0,0
3,4,8,0,0
4,4,9,0,0
5,4,9,0,0
6,4,10,0,0
7,4,10,0,0
8,4,11,0,0
9,4,11,0,0
10,4,12,0,0
11,4,12,0,0
\ No newline at end of file
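The 26 rows above appear to be the demo tree_info table (consistent with the node_nums=26 default in args.py), one row per tree node. Reading the fields as [item_id, layer_id, parent_node_id, child_node_1, child_node_2] is an assumption inferred from the data, not documented in this commit; a minimal parsing sketch under that assumption:

# Hypothetical parser for the tree_info rows above, assuming the column layout
# [item_id, layer_id, parent_node_id, child_node_1, child_node_2].
def load_tree_info(path="./thirdparty/tree_info.txt"):
    nodes = []
    with open(path, "r") as fin:
        for row_id, line in enumerate(fin):
            fields = [int(x) for x in line.strip().split(',')]
            nodes.append({
                "node_id": row_id,        # implicit: the row index
                "item_id": fields[0],     # leaf item feasign, 0 for inner nodes
                "layer_id": fields[1],
                "parent": fields[2],
                "children": fields[3:],   # 0 entries mean "no child"
            })
    return nodes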
@@ -19,7 +19,7 @@ import math
import argparse
import numpy as np
import paddle.fluid as fluid
-from utils import tdm_sampler_prepare, tdm_child_prepare, tdm_warm_start_prepare, tdm_item_rerank, trace_var
+from utils import tdm_sampler_prepare, tdm_child_prepare, trace_var
class TdmTrainNet(object):