train.py 1.8 KB
Newer Older
O
overlordmax 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
import numpy as np
import os
import paddle.fluid as fluid
from net import wide_deep
import logging
import paddle
import args
import utils
import time

# Root logging format: timestamp, level name, then the message text.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Logger used by this script; records below INFO are filtered out.
logger = logging.getLogger("fluid")
logger.setLevel(logging.INFO)

def train(args, train_data_path):
    """Train the network built by ``wide_deep`` and save it after each epoch.

    The loss is minimized with an Adagrad optimizer at a fixed learning rate
    of 0.01. One log record is emitted per batch with the (0-based) epoch
    index, the wall-clock time of the executor call, and the fetched loss,
    accuracy and AUC values. After every epoch the default main program is
    saved under ``<args.model_dir>/epoch_<epoch + 1>/checkpoint``.

    Args:
        args: Run configuration object. The attributes read here are
            ``batch_size``, ``hidden1_units``, ``hidden2_units``,
            ``hidden3_units``, ``use_gpu``, ``epochs`` and ``model_dir``.
        train_data_path: Training data location, forwarded unchanged to
            ``utils.CriteoDataset().train``.
    """
    network = wide_deep()
    inputs = network.input_data()
    dataset = utils.CriteoDataset()
    batched_reader = paddle.batch(
        dataset.train(train_data_path),
        batch_size=args.batch_size,
    )

    # The model returns five values; only loss, acc and auc are used below.
    loss, acc, auc, _batch_auc, _auc_states = network.model(
        inputs,
        args.hidden1_units,
        args.hidden2_units,
        args.hidden3_units,
    )
    fluid.optimizer.AdagradOptimizer(learning_rate=0.01).minimize(loss)

    # Run on the first GPU when requested, otherwise on the CPU.
    if args.use_gpu:
        place = fluid.CUDAPlace(0)
    else:
        place = fluid.CPUPlace()
    executor = fluid.Executor(place)
    executor.run(fluid.default_startup_program())
    feeder = fluid.DataFeeder(feed_list=inputs, place=place)

    log_template = "epoch:{}, batch_time:{:.5f}s, loss:{:.5f}, acc:{:.5f}, auc:{:.5f}"

    for epoch in range(args.epochs):
        for batch in batched_reader():
            started = time.time()
            loss_val, acc_val, auc_val = executor.run(
                program=fluid.default_main_program(),
                feed=feeder.feed(batch),
                fetch_list=[loss.name, acc.name, auc.name],
                return_numpy=True,
            )
            elapsed = time.time() - started

            # Each fetched result is indexed at position 0 to get the value
            # that is formatted into the log line.
            loss_scalar = np.array(loss_val)[0]
            acc_scalar = np.array(acc_val)[0]
            auc_scalar = np.array(auc_val)[0]
            logger.info(
                log_template.format(
                    epoch, elapsed, loss_scalar, acc_scalar, auc_scalar
                )
            )

        # Checkpoint directories are numbered from 1, unlike the logged epoch.
        epoch_tag = 'epoch_' + str(epoch + 1)
        checkpoint_path = os.path.join(args.model_dir, epoch_tag, "checkpoint")
        fluid.io.save(fluid.default_main_program(), checkpoint_path)
  
if __name__ == "__main__":
    # Bind the parsed options to their own name: assigning the result back to
    # `args` would rebind the imported `args` module at module level and make
    # the module unreachable for the rest of the script.
    parsed_args = args.parse_args()
    train(parsed_args, parsed_args.train_data_path)