train.py 3.1 KB
Newer Older
O
overlordmax 已提交
1 2 3 4 5 6 7 8 9 10 11 12
import numpy as np
import os
import paddle.fluid as fluid
import logging
import args
import random
import time
from evaluator import BiRNN

# Configure the root handler once so every record carries a timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Named logger used by train() and the __main__ entry point; INFO so the
# per-batch progress lines are emitted.
logger = logging.getLogger("fluid")
logger.setLevel(level=logging.INFO)
O
overlordmax 已提交
13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
# Running counter for synthetic user ids. It is shared by every Dataset
# instance and every reader pass and only ever increases, so ids handed out
# within one process never repeat.
user_id = 0


class Dataset(object):
    """Synthetic data source that yields one random batch per reader pass.

    No files are read. Each pass of the reader yields exactly one list
    ``[user_ids, item_ids, lengths, labels]`` where

    * ``user_ids`` is ``batch_size`` single-element lists of consecutive ids,
    * ``item_ids`` is a ``batch_size x item_len`` nested list of ints drawn
      uniformly from ``[0, item_vocab)``,
    * ``lengths`` is ``[item_len] * batch_size``,
    * ``labels`` is a ``batch_size x item_len`` nested list of 0/1 ints.

    Args:
        batch_size: users per generated batch; defaults to ``args.batch_size``.
        item_vocab: exclusive upper bound for item ids; defaults to
            ``args.item_vocab``.
        item_len: items per user; defaults to ``args.item_len``.

    Defaults are looked up on the module-level ``args`` name each time the
    reader runs. That name only holds parsed options after the ``__main__``
    block rebinds it, so pass the values explicitly when using this class
    from other code.
    """

    def __init__(self, batch_size=None, item_vocab=None, item_len=None):
        self._batch_size = batch_size
        self._item_vocab = item_vocab
        self._item_len = item_len

    def _resolve(self, value, name):
        """Return *value*, or ``args.<name>`` when *value* is None."""
        # The global ``args`` is only touched when no explicit value was given.
        return value if value is not None else getattr(args, name)

    def _reader_creator(self):
        def reader():
            global user_id
            batch_size = self._resolve(self._batch_size, "batch_size")
            item_vocab = self._resolve(self._item_vocab, "item_vocab")
            item_len = self._resolve(self._item_len, "item_len")

            user_slot_name = [[uid] for uid in range(user_id, user_id + batch_size)]
            # Advance by the number of ids actually issued (0 for a non-positive batch).
            user_id += len(user_slot_name)

            # Draw items before labels to keep the original RNG consumption order.
            item_slot_name = np.random.randint(item_vocab, size=(batch_size, item_len)).tolist()
            length = [item_len] * batch_size
            label = np.random.randint(2, size=(batch_size, item_len)).tolist()

            yield [user_slot_name, item_slot_name, length, label]
        return reader

    def get_train_data(self):
        """Return a reader callable; each call yields a single synthetic batch."""
        return self._reader_creator()
O
overlordmax 已提交
36 37 38 39 40 41 42 43 44 45 46 47 48

def train(args):
    """Build the BiRNN graph, train it on synthetic data and save the weights.

    Args:
        args: parsed options namespace. This function reads ``item_len``,
            ``hidden_size``, ``batch_size``, ``sample_size``, ``item_vocab``,
            ``embd_dim``, ``base_lr``, ``use_gpu``, ``epochs`` and
            ``model_dir`` from it.
    """
    birnn = BiRNN()
    feeds = birnn.input_data(args.item_len)
    # net() returns (loss, auc, batch_auc, auc_states); only the first two are used here.
    loss, auc_val, _, _ = birnn.net(
        feeds,
        args.hidden_size,
        args.batch_size * args.sample_size,
        args.item_vocab,
        args.embd_dim,
    )
    fluid.optimizer.Adam(learning_rate=args.base_lr, epsilon=1e-4).minimize(loss)

    if args.use_gpu:
        place = fluid.CUDAPlace(0)
    else:
        place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # Wire the synthetic reader into an iterable DataLoader bound to `place`.
    batched_reader = fluid.io.batch(Dataset().get_train_data(), batch_size=args.batch_size)
    loader = fluid.io.DataLoader.from_generator(
        feed_list=feeds, capacity=args.batch_size, iterable=True)
    loader.set_sample_list_generator(batched_reader, places=place)

    main_program = fluid.default_main_program()
    for epoch in range(args.epochs):
        for _ in range(args.sample_size):
            for batch_id, feed_data in enumerate(loader()):
                tic = time.time()
                loss_val, auc = exe.run(
                    program=main_program,
                    feed=feed_data,
                    fetch_list=[loss.name, auc_val],
                    return_numpy=True,
                )
                elapsed = time.time() - tic
                logger.info("epoch: {}, batch_id: {}, batch_time: {:.5f}s, loss: {:.5f}, auc: {:.5f}".format(
                    epoch, batch_id, elapsed, float(np.array(loss_val)), float(np.array(auc))))

    # The checkpoint is always written under "epoch_1", whatever args.epochs is.
    # NOTE(review): presumably a separate inference script loads this fixed
    # path; confirm before changing it.
    checkpoint_path = os.path.join(args.model_dir, 'epoch_' + str(1), "checkpoint")
    fluid.save(main_program, checkpoint_path)
O
overlordmax 已提交
70 71 72

if __name__ == "__main__":
    # Rebinding the module-level name `args` (the imported `args` module) to
    # the parsed namespace is deliberate. When not given explicit values, the
    # Dataset reader looks up batch_size/item_vocab/item_len on this
    # module-level `args`.
    args = args.parse_args()
    # Implicit string concatenation, instead of a backslash continuation
    # inside the literal, keeps source indentation out of the logged message.
    logger.info(
        "use_gpu: {}, batch_size: {}, model_dir: {}, embd_dim: {}, hidden_size: {}, "
        "item_vocab: {}, user_vocab: {}, item_len: {}, sample_size: {}, base_lr: {}, "
        "epochs: {}".format(
            args.use_gpu, args.batch_size, args.model_dir, args.embd_dim,
            args.hidden_size, args.item_vocab, args.user_vocab, args.item_len,
            args.sample_size, args.base_lr, args.epochs))

    train(args)