提交 e3fb25b6 编写于 作者: T tangwei

code fix

上级 bfcd53da
...@@ -41,8 +41,8 @@ train: ...@@ -41,8 +41,8 @@ train:
reader: reader:
mode: "dataset" mode: "dataset"
batch_size: 2 batch_size: 2
pipe_command: "python /paddle/fleetrec/models/ctr_dnn/dataset.py" pipe_command: "python /paddle/eleps/fleetrec/models/ctr_dnn/dataset.py"
train_data_path: "/paddle/fleetrec/models/ctr_dnn/data/train" train_data_path: "/paddle/eleps/fleetrec/models/ctr_dnn/data/train"
model: model:
models: "fleetrec.models.ctr_dnn.model" models: "fleetrec.models.ctr_dnn.model"
......
...@@ -35,8 +35,8 @@ train: ...@@ -35,8 +35,8 @@ train:
reader: reader:
mode: "dataset" mode: "dataset"
batch_size: 2 batch_size: 2
pipe_command: "python /paddle/fleetrec/models/ctr_dnn/dataset.py" pipe_command: "python /paddle/eleps/fleetrec/models/ctr_dnn/dataset.py"
train_data_path: "/paddle/fleetrec/models/ctr_dnn/data/train" train_data_path: "/paddle/eleps/fleetrec/models/ctr_dnn/data/train"
model: model:
models: "fleetrec.models.ctr_dnn.model" models: "fleetrec.models.ctr_dnn.model"
......
...@@ -21,7 +21,7 @@ from fleetrec.models.base import Model ...@@ -21,7 +21,7 @@ from fleetrec.models.base import Model
class Train(Model): class Train(Model):
def __init__(self, config): def __init__(self, config):
super().__init__(config) Model.__init__(self, config)
self.namespace = "train.model" self.namespace = "train.model"
def input(self): def input(self):
......
...@@ -29,7 +29,7 @@ class TranspileTrainer(Trainer): ...@@ -29,7 +29,7 @@ class TranspileTrainer(Trainer):
def __init__(self, config=None): def __init__(self, config=None):
Trainer.__init__(self, config) Trainer.__init__(self, config)
self.processor_register() self.processor_register()
self.model = None
self.inference_models = [] self.inference_models = []
self.increment_models = [] self.increment_models = []
...@@ -115,7 +115,7 @@ class TranspileTrainer(Trainer): ...@@ -115,7 +115,7 @@ class TranspileTrainer(Trainer):
models = envs.get_global_env("train.model.models") models = envs.get_global_env("train.model.models")
model_package = __import__(models, globals(), locals(), models.split(".")) model_package = __import__(models, globals(), locals(), models.split("."))
train_model = getattr(model_package, 'Train') train_model = getattr(model_package, 'Train')
self.model = train_model() self.model = train_model(None)
context['status'] = 'init_pass' context['status'] = 'init_pass'
def init(self, context): def init(self, context):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册