From 405ec7ed6c1ca1a37c2a0f1edb59cd017ea31473 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Thu, 9 Apr 2020 02:12:28 +0000 Subject: [PATCH] add ctr metrics --- models/ctr_dnn/model.py | 6 ++++++ trainer/single_train.py | 15 +++++++-------- utils/envs.py | 2 +- 3 files changed, 14 insertions(+), 9 deletions(-) diff --git a/models/ctr_dnn/model.py b/models/ctr_dnn/model.py index d08be491..e62ef34c 100644 --- a/models/ctr_dnn/model.py +++ b/models/ctr_dnn/model.py @@ -124,6 +124,12 @@ class Train(object): return self.metrics + def metric_extras(self): + self.metric_vars = [self.metrics[0]] + self.metric_alias = ["AUC"] + self.fetch_interval_batchs = 10 + return (self.metric_vars, self.metric_alias, self.fetch_interval_batchs) + def optimizer(self): learning_rate = envs.get_global_env("hyper_parameters.learning_rate", None ,self.namespace) optimizer = fluid.optimizer.Adam(learning_rate, lazy_mode=True) diff --git a/trainer/single_train.py b/trainer/single_train.py index 594574f1..49b76350 100644 --- a/trainer/single_train.py +++ b/trainer/single_train.py @@ -65,6 +65,7 @@ class SingleTrainer(Trainer): self.model.input() self.model.net() self.metrics = self.model.metrics() + self.metric_extras = self.model.metric_extras() loss = self.model.avg_loss() optimizer = self.model.optimizer() @@ -80,10 +81,10 @@ class SingleTrainer(Trainer): context['is_exit'] = True def infer(self, context): - print("Need to be implement") context['is_exit'] = True def terminal(self, context): + print("clean up and exit") context['is_exit'] = True @@ -162,16 +163,14 @@ class SingleTrainerWithDataset(SingleTrainer): epochs = envs.get_global_env("train.epochs") - print("fetch_list: {}".format(len(self.metrics))) - for i in range(epochs): self.exe.train_from_dataset(program=fluid.default_main_program(), dataset=dataset, - fetch_list=self.metrics, - fetch_info=["auc ", "batch auc"], - print_period=1) + fetch_list=self.metric_extras[0], + fetch_info=self.metric_extras[1], + print_period=self.metric_extras[2]) context['status'] = 'infer_pass' -def infer(self, context): - context['status'] = 'terminal_pass' + def infer(self, context): + context['status'] = 'terminal_pass' diff --git a/utils/envs.py b/utils/envs.py index b4ea3457..a0cd40bc 100644 --- a/utils/envs.py +++ b/utils/envs.py @@ -53,7 +53,7 @@ def pretty_print_envs(): max_k = max(max_k, len(k)) max_v = max(max_v, len(str(v))) - h_format = "{{:^{}s}}{{:<{}s}}\n".format(max_k, max_v) + h_format = "{{:^{}s}}{}{{:<{}s}}\n".format(max_k, " " * spacing, max_v) l_format = "{{:<{}s}}{{}}{{:<{}s}}\n".format(max_k, max_v) length = max_k + max_v + spacing -- GitLab