未验证 提交 57adbb6c 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Merge pull request #58 from wangxiao1021/api

fix bugs
...@@ -168,10 +168,10 @@ class Trainer(object): ...@@ -168,10 +168,10 @@ class Trainer(object):
if not self._lock_prog: if not self._lock_prog:
with fluid.program_guard(train_prog, train_init_prog): with fluid.program_guard(train_prog, train_init_prog):
net_inputs = reader_helper.create_net_inputs(input_attrs, async=False) net_inputs = reader_helper.create_net_inputs(input_attrs, is_async=False)
bb_output_vars = backbone.build(net_inputs) bb_output_vars = backbone.build(net_inputs)
else: else:
net_inputs = reader_helper.create_net_inputs(input_attrs, async=False) net_inputs = reader_helper.create_net_inputs(input_attrs, is_async=False)
bb_output_vars = backbone.build(net_inputs) bb_output_vars = backbone.build(net_inputs)
self._net_inputs = net_inputs self._net_inputs = net_inputs
assert sorted(bb_output_vars.keys()) == sorted(backbone.outputs_attr.keys()) assert sorted(bb_output_vars.keys()) == sorted(backbone.outputs_attr.keys())
......
...@@ -117,7 +117,7 @@ def _zero_batch_x(attrs, batch_size): ...@@ -117,7 +117,7 @@ def _zero_batch_x(attrs, batch_size):
return [np.zeros(shape=shape, dtype=dtype) for shape, dtype in pos_attrs] return [np.zeros(shape=shape, dtype=dtype) for shape, dtype in pos_attrs]
def create_net_inputs(input_attrs, async=False, iterator_fn=None, dev_count=1, n_prefetch=1): def create_net_inputs(input_attrs, is_async=False, iterator_fn=None, dev_count=1, n_prefetch=1):
inputs = [] inputs = []
ret = {} ret = {}
for name, shape, dtype in input_attrs: for name, shape, dtype in input_attrs:
...@@ -125,7 +125,7 @@ def create_net_inputs(input_attrs, async=False, iterator_fn=None, dev_count=1, n ...@@ -125,7 +125,7 @@ def create_net_inputs(input_attrs, async=False, iterator_fn=None, dev_count=1, n
ret[name] = p ret[name] = p
inputs.append(p) inputs.append(p)
if async: if is_async:
assert iterator_fn is not None, "iterator_fn is needed for building async input layer." assert iterator_fn is not None, "iterator_fn is needed for building async input layer."
reader = fluid.io.PyReader(inputs, capacity=dev_count, iterable=False) reader = fluid.io.PyReader(inputs, capacity=dev_count, iterable=False)
reader.decorate_batch_generator(iterator_fn) reader.decorate_batch_generator(iterator_fn)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册