diff --git a/fluid/recommendation/ctr/network_conf.py b/fluid/recommendation/ctr/network_conf.py index 827bc72d26cf5b7ef8368477cca01898c09f0fa8..c3b596ae2e19949c99ba1a0085558a5668b64f7f 100644 --- a/fluid/recommendation/ctr/network_conf.py +++ b/fluid/recommendation/ctr/network_conf.py @@ -34,9 +34,9 @@ def DeepFM(factor_size, infer=False): cost = fluid.layers.cross_entropy(input=predict, label=label) avg_cost = fluid.layers.reduce_sum(cost) accuracy = fluid.layers.accuracy(input=predict, label=label) - auc_var, cur_auc_var, auc_states = fluid.layers.auc(input=predict, label=label, num_thresholds=2**12) + auc_var, batch_auc_var, auc_states = fluid.layers.auc(input=predict, label=label, num_thresholds=2**12, slide_steps=20) data_list.append(label) - return avg_cost, data_list + return avg_cost, data_list, auc_var, batch_auc_var else: return predict, data_list diff --git a/fluid/recommendation/ctr/train.py b/fluid/recommendation/ctr/train.py index d7a6ff08b64130895a841b457527dcbe2ce94992..6b40d0bfd41c115c23c34c5e74dcc96158329327 100644 --- a/fluid/recommendation/ctr/train.py +++ b/fluid/recommendation/ctr/train.py @@ -55,7 +55,7 @@ def train(): if not os.path.isdir(args.model_output_dir): os.mkdir(args.model_output_dir) - loss, data_list = DeepFM(args.factor_size) + loss, data_list, auc_var, batch_auc_var = DeepFM(args.factor_size) optimizer = fluid.optimizer.Adam(learning_rate=1e-4) optimize_ops, params_grads = optimizer.minimize(loss) @@ -71,13 +71,14 @@ def train(): exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) - for data in train_reader(): - loss_var = exe.run( - fluid.default_main_program(), - feed=feeder.feed(data), - fetch_list=[loss] - ) - print(loss_var) + for pass_id in range(args.num_passes): + for data in train_reader(): + loss_val, auc_val, batch_auc_val = exe.run( + fluid.default_main_program(), + feed=feeder.feed(data), + fetch_list=[loss, auc_var, batch_auc_var] + ) + print('loss :' + str(loss_val) + " auc : " + str(auc_val) + " batch_auc : " + str(batch_auc_val)) if __name__ == '__main__':