提交 fd92c826 编写于 作者: T tangwei

bug fix

上级 eb04c187
......@@ -80,7 +80,7 @@ class ModelBase(object):
self._data_var = []
self._fetch_interval = 20
def get_input(self):
def get_inputs(self):
return self._data_var
def get_cost_op(self):
......
......@@ -52,9 +52,11 @@ class Model(ModelBase):
self.dense_input = dense_input()
self.label_input = label_input()
self._data_var.append(self.dense_input)
for input in self.sparse_inputs:
self._data_var.append(input)
self._data_var.append(self.dense_input)
self._data_var.append(self.label_input)
def net(self):
......
......@@ -38,7 +38,7 @@ class TranspileTrainer(Trainer):
def _get_dataset(self):
namespace = "train.reader"
inputs = self.model.inputs()
inputs = self.model.get_inputs()
threads = envs.get_global_env("train.threads", None)
batch_size = envs.get_global_env("batch_size", None, namespace)
reader_class = envs.get_global_env("class", None, namespace)
......@@ -115,7 +115,7 @@ class TranspileTrainer(Trainer):
def instance(self, context):
models = envs.get_global_env("train.model.models")
model_class = envs.lazy_instance(models, "TrainModel")
model_class = envs.lazy_instance(models, "Model")
self.model = model_class(None)
context['status'] = 'init_pass'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册