Commit 05b45e1f authored by Yu Yang

Remove reader logic

Parent 797e89ec
@@ -21,30 +21,16 @@ class Inference(object):
         self.__gradient_machine__ = gm
         self.__data_types__ = topo.data_type()
 
-    def iter_infer(self, input=None, batch_size=None, reader=None,
-                   feeding=None):
+    def iter_infer(self, input, feeding=None):
         feeder = DataFeeder(self.__data_types__, feeding)
-        if reader is None:
-            assert input is not None and isinstance(input, collections.Iterable)
-            if not isinstance(input, collections.Iterable):
-                raise TypeError("When reader is None, input should be whole "
-                                "inference data and should be iterable")
-            if batch_size is None:
-                if not hasattr(input, '__len__'):
-                    raise ValueError("Should set batch size when input data "
-                                     "don't contain length.")
-                batch_size = len(input)
-
-            def __reader_impl__():
-                for each_sample in input:
-                    yield each_sample
-
-            reader = minibatch.batch(__reader_impl__, batch_size=batch_size)
-        else:
-            if input is not None:
-                raise ValueError("User should set either input or reader, "
-                                 "should not set them both.")
+        batch_size = len(input)
+
+        def __reader_impl__():
+            for each_sample in input:
+                yield each_sample
+
+        reader = minibatch.batch(__reader_impl__, batch_size=batch_size)
 
         self.__gradient_machine__.start()
         for data_batch in reader():
             yield self.__gradient_machine__.forwardTest(feeder(data_batch))
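For context, a minimal sketch of how a caller might use the simplified method after this change. Only the iter_infer(input, feeding=None) signature and the internal batching via minibatch.batch come from the diff; the inferer object, the sample values, and the feeding map below are hypothetical placeholders.

# Hypothetical usage sketch -- names here are illustrative, not part of the commit.
# After this change, iter_infer() no longer takes a `reader` or `batch_size`:
# the caller passes the whole inference set as an in-memory sequence, and the
# method wraps it in a minibatch reader with batch_size = len(input).

samples = [([0.1, 0.2, 0.3],),        # each element is one sample tuple,
           ([0.4, 0.5, 0.6],)]        # laid out to match the topology's inputs
feeding = {'data': 0}                 # assumed mapping: layer name -> tuple index

# `inferer` is assumed to be an already constructed paddle.v2 Inference object.
for batch_result in inferer.iter_infer(input=samples, feeding=feeding):
    print(batch_result)

Because batch_size is now derived from len(input), the input must be a sized, in-memory sequence of sample tuples rather than a reader callable or a plain generator, and the whole data set is forwarded as a single batch.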