提交 fd92c826 编写于 作者: T tangwei

bug fix

上级 eb04c187
...@@ -80,7 +80,7 @@ class ModelBase(object): ...@@ -80,7 +80,7 @@ class ModelBase(object):
self._data_var = [] self._data_var = []
self._fetch_interval = 20 self._fetch_interval = 20
def get_input(self): def get_inputs(self):
return self._data_var return self._data_var
def get_cost_op(self): def get_cost_op(self):
......
...@@ -52,9 +52,11 @@ class Model(ModelBase): ...@@ -52,9 +52,11 @@ class Model(ModelBase):
self.dense_input = dense_input() self.dense_input = dense_input()
self.label_input = label_input() self.label_input = label_input()
self._data_var.append(self.dense_input)
for input in self.sparse_inputs: for input in self.sparse_inputs:
self._data_var.append(input) self._data_var.append(input)
self._data_var.append(self.dense_input)
self._data_var.append(self.label_input) self._data_var.append(self.label_input)
def net(self): def net(self):
......
...@@ -38,7 +38,7 @@ class TranspileTrainer(Trainer): ...@@ -38,7 +38,7 @@ class TranspileTrainer(Trainer):
def _get_dataset(self): def _get_dataset(self):
namespace = "train.reader" namespace = "train.reader"
inputs = self.model.inputs() inputs = self.model.get_inputs()
threads = envs.get_global_env("train.threads", None) threads = envs.get_global_env("train.threads", None)
batch_size = envs.get_global_env("batch_size", None, namespace) batch_size = envs.get_global_env("batch_size", None, namespace)
reader_class = envs.get_global_env("class", None, namespace) reader_class = envs.get_global_env("class", None, namespace)
...@@ -115,7 +115,7 @@ class TranspileTrainer(Trainer): ...@@ -115,7 +115,7 @@ class TranspileTrainer(Trainer):
def instance(self, context): def instance(self, context):
models = envs.get_global_env("train.model.models") models = envs.get_global_env("train.model.models")
model_class = envs.lazy_instance(models, "TrainModel") model_class = envs.lazy_instance(models, "Model")
self.model = model_class(None) self.model = model_class(None)
context['status'] = 'init_pass' context['status'] = 'init_pass'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册