diff --git a/fleetrec/models/base.py b/fleetrec/models/base.py index 23269674828b234b8ef8867d4eb2cdcaec2f310d..06188186c74fb1b140fc6222e9eea5d7be8025a0 100644 --- a/fleetrec/models/base.py +++ b/fleetrec/models/base.py @@ -80,7 +80,7 @@ class ModelBase(object): self._data_var = [] self._fetch_interval = 20 - def get_input(self): + def get_inputs(self): return self._data_var def get_cost_op(self): diff --git a/fleetrec/models/ctr_dnn/model.py b/fleetrec/models/ctr_dnn/model.py index df96eb6fb25e7c91f8352175921a52fcac5c560d..f771eaf5780ad71d3c8b989feb05c51c135c6253 100644 --- a/fleetrec/models/ctr_dnn/model.py +++ b/fleetrec/models/ctr_dnn/model.py @@ -52,9 +52,11 @@ class Model(ModelBase): self.dense_input = dense_input() self.label_input = label_input() + self._data_var.append(self.dense_input) + for input in self.sparse_inputs: self._data_var.append(input) - self._data_var.append(self.dense_input) + self._data_var.append(self.label_input) def net(self): diff --git a/fleetrec/trainer/transpiler_trainer.py b/fleetrec/trainer/transpiler_trainer.py index 558c8a96ccabc1fbe215dae05cb10af7b554d9fc..06ddc63ca12146d6efda16dba90216a941702737 100644 --- a/fleetrec/trainer/transpiler_trainer.py +++ b/fleetrec/trainer/transpiler_trainer.py @@ -38,7 +38,7 @@ class TranspileTrainer(Trainer): def _get_dataset(self): namespace = "train.reader" - inputs = self.model.inputs() + inputs = self.model.get_inputs() threads = envs.get_global_env("train.threads", None) batch_size = envs.get_global_env("batch_size", None, namespace) reader_class = envs.get_global_env("class", None, namespace) @@ -115,7 +115,7 @@ class TranspileTrainer(Trainer): def instance(self, context): models = envs.get_global_env("train.model.models") - model_class = envs.lazy_instance(models, "TrainModel") + model_class = envs.lazy_instance(models, "Model") self.model = model_class(None) context['status'] = 'init_pass'