未验证 提交 4b8d65ab 编写于 作者: C chengduo 提交者: GitHub

Add return_numpy for PE (#11792)

上级 59837ff1
...@@ -78,6 +78,8 @@ def as_numpy(tensor): ...@@ -78,6 +78,8 @@ def as_numpy(tensor):
Returns: Returns:
numpy.ndarray numpy.ndarray
""" """
if isinstance(tensor, core.LoDTensorArray):
return [as_numpy(t) for t in tensor]
if isinstance(tensor, list): if isinstance(tensor, list):
return [as_numpy(t) for t in tensor] return [as_numpy(t) for t in tensor]
assert isinstance(tensor, core.LoDTensor) assert isinstance(tensor, core.LoDTensor)
......
...@@ -160,7 +160,7 @@ class ParallelExecutor(object): ...@@ -160,7 +160,7 @@ class ParallelExecutor(object):
build_strategy, num_trainers, trainer_id) build_strategy, num_trainers, trainer_id)
self.scope = scope self.scope = scope
def run(self, fetch_list, feed=None, feed_dict=None): def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True):
""" """
Run a parallel executor with fetch_list. Run a parallel executor with fetch_list.
...@@ -196,6 +196,8 @@ class ParallelExecutor(object): ...@@ -196,6 +196,8 @@ class ParallelExecutor(object):
to each device. Default None. to each device. Default None.
feed_dict: Alias for feed parameter, for backward compatibility. feed_dict: Alias for feed parameter, for backward compatibility.
This parameter has been deprecated. Default None. This parameter has been deprecated. Default None.
return_numpy(bool): Whether converts the fetched tensor to numpy.
Default: True.
Returns: Returns:
List: The fetched result list. List: The fetched result list.
...@@ -270,6 +272,9 @@ class ParallelExecutor(object): ...@@ -270,6 +272,9 @@ class ParallelExecutor(object):
if self.is_dist: if self.is_dist:
self.bcast_params() self.bcast_params()
if return_numpy:
return executor.as_numpy(arr)
return [arr[i] for i in range(len(arr))] return [arr[i] for i in range(len(arr))]
def bcast_params(self): def bcast_params(self):
......
...@@ -81,7 +81,6 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -81,7 +81,6 @@ class TestParallelExecutorBase(unittest.TestCase):
begin = time.time() begin = time.time()
first_loss, = run_executor( first_loss, = run_executor(
exe=exe, feed=feed_dict, fetch_list=[loss.name]) exe=exe, feed=feed_dict, fetch_list=[loss.name])
first_loss = np.array(first_loss)
for i in xrange(iter): for i in xrange(iter):
run_executor(exe=exe, feed=feed_dict, fetch_list=[]) run_executor(exe=exe, feed=feed_dict, fetch_list=[])
...@@ -94,8 +93,6 @@ class TestParallelExecutorBase(unittest.TestCase): ...@@ -94,8 +93,6 @@ class TestParallelExecutorBase(unittest.TestCase):
print "%.4f Instance per second" % ( print "%.4f Instance per second" % (
(batch_size * iter + 2) / (end - begin)) (batch_size * iter + 2) / (end - begin))
last_loss = np.array(last_loss)
print first_loss, last_loss print first_loss, last_loss
# self.assertGreater(first_loss[0], last_loss[0]) # self.assertGreater(first_loss[0], last_loss[0])
return first_loss, last_loss return first_loss, last_loss
...@@ -169,9 +169,8 @@ class TestCRFModel(unittest.TestCase): ...@@ -169,9 +169,8 @@ class TestCRFModel(unittest.TestCase):
data = train_data() data = train_data()
for i in xrange(10): for i in xrange(10):
cur_batch = next(data) cur_batch = next(data)
print map(np.array, print pe.run(feed=feeder.feed(cur_batch),
pe.run(feed=feeder.feed(cur_batch), fetch_list=[avg_cost.name])[0]
fetch_list=[avg_cost.name]))[0]
@unittest.skip(reason="CI hangs") @unittest.skip(reason="CI hangs")
def test_update_sparse_parameter_all_reduce(self): def test_update_sparse_parameter_all_reduce(self):
......
...@@ -75,7 +75,9 @@ class TestFetchOp(unittest.TestCase): ...@@ -75,7 +75,9 @@ class TestFetchOp(unittest.TestCase):
fetch_list.append(k) fetch_list.append(k)
for data in train_inputs: for data in train_inputs:
ret = pe.run(fetch_list, feed=feeder.feed(data)) ret = pe.run(fetch_list,
feed=feeder.feed(data),
return_numpy=True)
for i in range(len(fetch_list)): for i in range(len(fetch_list)):
assert not math.isnan(np.sum(ret[i])) and \ assert not math.isnan(np.sum(ret[i])) and \
not math.isinf(np.sum(ret[i])) not math.isinf(np.sum(ret[i]))
...@@ -128,7 +130,7 @@ class TestFeedParallel(unittest.TestCase): ...@@ -128,7 +130,7 @@ class TestFeedParallel(unittest.TestCase):
use_cuda=use_cuda, loss_name=loss.name, main_program=main) use_cuda=use_cuda, loss_name=loss.name, main_program=main)
for batch_id, data in enumerate(reader()): for batch_id, data in enumerate(reader()):
loss_np = np.array(pe.run(feed=data, fetch_list=[loss.name])[0]) loss_np = pe.run(feed=data, fetch_list=[loss.name])[0]
print batch_id, loss_np print batch_id, loss_np
if batch_id == 2: if batch_id == 2:
break break
......
...@@ -70,10 +70,9 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase): ...@@ -70,10 +70,9 @@ class ParallelExecutorTestingDuringTraining(unittest.TestCase):
for i in xrange(5): for i in xrange(5):
test_loss, = test_exe.run([loss.name], feed=feed_dict) test_loss, = test_exe.run([loss.name], feed=feed_dict)
test_loss = np.array(test_loss)
train_loss, = train_exe.run([loss.name], feed=feed_dict) train_loss, = train_exe.run([loss.name], feed=feed_dict)
train_loss = np.array(train_loss)
self.assertTrue( self.assertTrue(
np.allclose( np.allclose(
train_loss, test_loss, atol=1e-8), train_loss, test_loss, atol=1e-8),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册