提交 847d17db 编写于 作者: Y yaoxuefeng

fix bug

上级 51242677
...@@ -23,7 +23,6 @@ from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import f ...@@ -23,7 +23,6 @@ from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import f
from fleetrec.core.trainer import Trainer from fleetrec.core.trainer import Trainer
from fleetrec.core.utils import envs from fleetrec.core.utils import envs
from fleetrec.core.utils import dataloader_instance from fleetrec.core.utils import dataloader_instance
import fleetrec.core.din_reader as din_reader
class TranspileTrainer(Trainer): class TranspileTrainer(Trainer):
...@@ -48,7 +47,6 @@ class TranspileTrainer(Trainer): ...@@ -48,7 +47,6 @@ class TranspileTrainer(Trainer):
reader_class = envs.lazy_instance_by_fliename(reader_class, "TrainReader") reader_class = envs.lazy_instance_by_fliename(reader_class, "TrainReader")
reader_ins = reader_class(self._config_yaml) reader_ins = reader_class(self._config_yaml)
if hasattr(reader_ins,'generate_batch_from_trainfiles'): if hasattr(reader_ins,'generate_batch_from_trainfiles'):
print("++++++++hieehi+++++++++")
dataloader.set_sample_list_generator(reader) dataloader.set_sample_list_generator(reader)
else: else:
dataloader.set_sample_generator(reader, batch_size) dataloader.set_sample_generator(reader, batch_size)
......
...@@ -62,4 +62,4 @@ def dataloader(readerclass, train, yaml_file): ...@@ -62,4 +62,4 @@ def dataloader(readerclass, train, yaml_file):
if hasattr(reader, 'generate_batch_from_trainfiles'): if hasattr(reader, 'generate_batch_from_trainfiles'):
return gen_batch_reader() return gen_batch_reader()
return reader.generate_dataloader_batch(files) return gen_reader
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册