From a4815f33aa55dff6eebd98dcceadb6528226c931 Mon Sep 17 00:00:00 2001 From: Superjom Date: Fri, 14 Jul 2017 13:25:35 +0800 Subject: [PATCH] fix infer bug --- ctr/infer.py | 2 +- ctr/network_conf.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/ctr/infer.py b/ctr/infer.py index 8ed79130..721c6b01 100644 --- a/ctr/infer.py +++ b/ctr/infer.py @@ -50,7 +50,7 @@ class CTRInferer(object): dnn_layer_dims, dnn_input_dim, lr_input_dim, - model_type=args.model_type, + model_type=ModelType(args.model_type), is_infer=True) # load parameter logger.info("load model parameters from %s" % param_path) diff --git a/ctr/network_conf.py b/ctr/network_conf.py index 1dc45610..6b38f57c 100644 --- a/ctr/network_conf.py +++ b/ctr/network_conf.py @@ -95,8 +95,9 @@ class CTRmodel(object): # use sigmoid function to approximate ctr rate, a float value between 0 and 1. act=paddle.activation.Sigmoid()) - self.train_cost = paddle.layer.multi_binary_label_cross_entropy_cost( - input=self.output, label=self.click) + if not self.is_infer: + self.train_cost = paddle.layer.multi_binary_label_cross_entropy_cost( + input=self.output, label=self.click) return self.output def _build_regression_model(self, dnn, lr): -- GitLab