提交 6c167379 编写于 作者: W wuzewu

update run command

上级 c8193062
......@@ -82,7 +82,8 @@ class RunCommand(BaseCommand):
assert len(input_data_format) == len(expect_data_format)
for key, value in expect_data_format.items():
assert key in input_data_format
assert value == hub.DataType.type(input_data_format[key]['type'])
assert value['type'] == hub.DataType.type(
input_data_format[key]['type'])
# get data dict
origin_data = csv_reader.read(self.args.dataset)
......@@ -93,10 +94,11 @@ class RunCommand(BaseCommand):
map(type_reader.read, origin_data[value['key']]))
# run module with data
module(
sign_name=self.args.signature,
data=input_data,
**yaml_config['config'])
print(
module(
sign_name=self.args.signature,
data=input_data,
**yaml_config['config']))
command = RunCommand.instance()
......@@ -331,23 +331,47 @@ class Module:
module_info.map.data['summary'])
def __call__(self, sign_name, data, **kwargs):
def _get_reader_and_feeder(data_format, data, place):
def _reader():
nonlocal process_data
for item in zip(*process_data):
yield item
process_data = []
feed_name_list = []
for key in data_format:
process_data.append([value['processed'] for value in data[key]])
feed_name_list.append(data_format[key]['feed_key'])
feeder = fluid.DataFeeder(feed_list=feed_name_list, place=place)
return _reader, feeder
feed_dict, fetch_dict, program = self.context(sign_name, for_test=True)
#TODO(wuzewu): more option
reader = self.processor.reader(
sign_name=sign_name, data_dict=data, **kwargs)
feed_name_list = list(
set([value.name for key, value in feed_dict.items()]))
fetch_list = list(set([value for key, value in fetch_dict.items()]))
with fluid.program_guard(program):
result = []
index = 0
place = fluid.CPUPlace()
exe = fluid.Executor(place=place)
feeder = fluid.DataFeeder(feed_list=feed_name_list, place=place)
data = self.processor.preprocess(
sign_name=sign_name, data_dict=data)
data_format = self.processor.data_format(sign_name=sign_name)
reader, feeder = _get_reader_and_feeder(data_format, data, place)
reader = paddle.batch(reader, batch_size=2)
for batch in reader():
data_out = exe.run(
feed=feeder.feed(batch),
fetch_list=fetch_list,
return_numpy=False)
self.processor.postprocess(sign_name, data_out, **kwargs)
sub_data = {
key: value[index:index + len(batch)]
for key, value in data.items()
}
result += self.processor.postprocess(sign_name, data_out,
sub_data, **kwargs)
index += len(batch)
return result
def context(self,
sign_name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册