From e3fb25b6faa00d319f9a2065dcfefc795a30d772 Mon Sep 17 00:00:00 2001 From: tangwei Date: Fri, 10 Apr 2020 12:44:30 +0800 Subject: [PATCH] code fix --- fleetrec/examples/ctr-dnn_train_cluster.yaml | 4 ++-- fleetrec/examples/ctr-dnn_train_single.yaml | 4 ++-- fleetrec/models/ctr_dnn/model.py | 2 +- fleetrec/trainer/transpiler_trainer.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/fleetrec/examples/ctr-dnn_train_cluster.yaml b/fleetrec/examples/ctr-dnn_train_cluster.yaml index 8eb9eb1e..53f03253 100644 --- a/fleetrec/examples/ctr-dnn_train_cluster.yaml +++ b/fleetrec/examples/ctr-dnn_train_cluster.yaml @@ -41,8 +41,8 @@ train: reader: mode: "dataset" batch_size: 2 - pipe_command: "python /paddle/fleetrec/models/ctr_dnn/dataset.py" - train_data_path: "/paddle/fleetrec/models/ctr_dnn/data/train" + pipe_command: "python /paddle/eleps/fleetrec/models/ctr_dnn/dataset.py" + train_data_path: "/paddle/eleps/fleetrec/models/ctr_dnn/data/train" model: models: "fleetrec.models.ctr_dnn.model" diff --git a/fleetrec/examples/ctr-dnn_train_single.yaml b/fleetrec/examples/ctr-dnn_train_single.yaml index 4091c78e..985adebe 100644 --- a/fleetrec/examples/ctr-dnn_train_single.yaml +++ b/fleetrec/examples/ctr-dnn_train_single.yaml @@ -35,8 +35,8 @@ train: reader: mode: "dataset" batch_size: 2 - pipe_command: "python /paddle/fleetrec/models/ctr_dnn/dataset.py" - train_data_path: "/paddle/fleetrec/models/ctr_dnn/data/train" + pipe_command: "python /paddle/eleps/fleetrec/models/ctr_dnn/dataset.py" + train_data_path: "/paddle/eleps/fleetrec/models/ctr_dnn/data/train" model: models: "fleetrec.models.ctr_dnn.model" diff --git a/fleetrec/models/ctr_dnn/model.py b/fleetrec/models/ctr_dnn/model.py index c6b39971..f65a60e8 100644 --- a/fleetrec/models/ctr_dnn/model.py +++ b/fleetrec/models/ctr_dnn/model.py @@ -21,7 +21,7 @@ from fleetrec.models.base import Model class Train(Model): def __init__(self, config): - super().__init__(config) + Model.__init__(self, config) self.namespace = "train.model" def input(self): diff --git a/fleetrec/trainer/transpiler_trainer.py b/fleetrec/trainer/transpiler_trainer.py index 2c6e6802..3d1fa934 100644 --- a/fleetrec/trainer/transpiler_trainer.py +++ b/fleetrec/trainer/transpiler_trainer.py @@ -29,7 +29,7 @@ class TranspileTrainer(Trainer): def __init__(self, config=None): Trainer.__init__(self, config) self.processor_register() - + self.model = None self.inference_models = [] self.increment_models = [] @@ -115,7 +115,7 @@ class TranspileTrainer(Trainer): models = envs.get_global_env("train.model.models") model_package = __import__(models, globals(), locals(), models.split(".")) train_model = getattr(model_package, 'Train') - self.model = train_model() + self.model = train_model(None) context['status'] = 'init_pass' def init(self, context): -- GitLab