未验证 提交 6cfc0c14 编写于 作者: D dzhwinter 提交者: GitHub

"polish code" (#9318)

* "polish code"

* "fix ci"

* "fix ci"

* "done"
上级 b55dc9a0
......@@ -48,8 +48,7 @@ def as_numpy(tensor):
assert isinstance(tensor, core.LoDTensor)
lod = tensor.lod()
if len(lod) > 0:
raise RuntimeError(
"Some of your featched tensors hold LoD information. \
raise RuntimeError("Some of your fetched tensors hold LoD information. \
They can not be completely cast to Python ndarray. \
Please set the parameter 'return_numpy' as 'False' to \
return LoDTensor itself directly.")
......@@ -180,59 +179,23 @@ def get_program_cache_key(feed, fetch_list):
class Executor(object):
def __init__(self, places):
if not isinstance(places, list) and not isinstance(places, tuple):
places = [places]
act_places = []
for each in places:
def __init__(self, place):
self.place = place
p = core.Place()
p.set_place(each)
act_places.append(p)
# TODO(dzhwinter) : only use the first place
self.executor = core.Executor(act_places[0])
self.places = places
p.set_place(place)
self.executor = core.Executor(p)
self.program_caches = dict()
def aslodtensor(self, data):
def accumulate(data):
if not isinstance(data, list):
return 1
return sum([accumulate(sub) for sub in data])
def parselod(data):
seq_lens = [accumulate(seq) for seq in data]
cur_len = 0
lod = [cur_len]
for l in seq_lens:
cur_len += l
lod.append(cur_len)
return lod
assert len(self.places) != 0
if not isinstance(data, list):
# pure tensor case
tensor = core.LoDTensor()
tensor.set(data, self.places[0])
return tensor
else:
raise RuntimeError("Current implementation lacks unittests")
# lodtensor case
lod = []
if not isinstance(data[0], list):
lod.append(parselod(data))
flattened_data = np.concatenate(data, axis=0).astype("int64")
else:
while isinstance(data[0], list):
lod.append(parselod(seq))
flattened_data = [item for seq in data for item in seq]
data = flattened_data
flattened_data = np.concatenate(data, axis=0).astype("int64")
flattened_data = flattened_data.reshape([len(flattened_data), 1])
def as_lodtensor(self, data):
if isinstance(data, list):
raise RuntimeError("Some of your feed data hold LoD information. \
They can not be completely cast from a list of Python \
ndarray to LoDTensor. Please convert data to LoDTensor \
directly before feeding the data.\
")
# single tensor case
tensor = core.LoDTensor()
tensor.set(flattened_data, self.places[0])
tensor.set_lod(lod)
tensor.set(data, self.place)
return tensor
def _get_program_cache(self, program_cache_key):
......@@ -293,7 +256,7 @@ class Executor(object):
feed_target_name = op.desc.output('Out')[0]
cur_feed = feed[feed_target_name]
if not isinstance(cur_feed, core.LoDTensor):
cur_feed = self.aslodtensor(cur_feed)
cur_feed = self.as_lodtensor(cur_feed)
idx = op.desc.attr('col')
core.set_feed_variable(scope, cur_feed, feed_var_name, idx)
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册