diff --git a/models/ctr_dnn/model.py b/models/ctr_dnn/model.py index d08be491fa8cf8c33ad969889841b8bef4143461..e62ef34c40b00a576605ad85f0a48447666f35fa 100644 --- a/models/ctr_dnn/model.py +++ b/models/ctr_dnn/model.py @@ -124,6 +124,12 @@ class Train(object): return self.metrics + def metric_extras(self): + self.metric_vars = [self.metrics[0]] + self.metric_alias = ["AUC"] + self.fetch_interval_batchs = 10 + return (self.metric_vars, self.metric_alias, self.fetch_interval_batchs) + def optimizer(self): learning_rate = envs.get_global_env("hyper_parameters.learning_rate", None ,self.namespace) optimizer = fluid.optimizer.Adam(learning_rate, lazy_mode=True) diff --git a/trainer/single_train.py b/trainer/single_train.py index 594574f15fbaff209bd1ded4f904aa64bedb811d..49b76350f6da4b5f02076b85563c1ea1ccb86619 100644 --- a/trainer/single_train.py +++ b/trainer/single_train.py @@ -65,6 +65,7 @@ class SingleTrainer(Trainer): self.model.input() self.model.net() self.metrics = self.model.metrics() + self.metric_extras = self.model.metric_extras() loss = self.model.avg_loss() optimizer = self.model.optimizer() @@ -80,10 +81,10 @@ class SingleTrainer(Trainer): context['is_exit'] = True def infer(self, context): - print("Need to be implement") context['is_exit'] = True def terminal(self, context): + print("clean up and exit") context['is_exit'] = True @@ -162,16 +163,14 @@ class SingleTrainerWithDataset(SingleTrainer): epochs = envs.get_global_env("train.epochs") - print("fetch_list: {}".format(len(self.metrics))) - for i in range(epochs): self.exe.train_from_dataset(program=fluid.default_main_program(), dataset=dataset, - fetch_list=self.metrics, - fetch_info=["auc ", "batch auc"], - print_period=1) + fetch_list=self.metric_extras[0], + fetch_info=self.metric_extras[1], + print_period=self.metric_extras[2]) context['status'] = 'infer_pass' -def infer(self, context): - context['status'] = 'terminal_pass' + def infer(self, context): + context['status'] = 'terminal_pass' diff --git a/utils/envs.py b/utils/envs.py index b4ea34575dad0b5c7fda5a27224bec37abc85b40..a0cd40bc18bc508daf5f07bcde7f041a88277aa6 100644 --- a/utils/envs.py +++ b/utils/envs.py @@ -53,7 +53,7 @@ def pretty_print_envs(): max_k = max(max_k, len(k)) max_v = max(max_v, len(str(v))) - h_format = "{{:^{}s}}{{:<{}s}}\n".format(max_k, max_v) + h_format = "{{:^{}s}}{}{{:<{}s}}\n".format(max_k, " " * spacing, max_v) l_format = "{{:<{}s}}{{}}{{:<{}s}}\n".format(max_k, max_v) length = max_k + max_v + spacing