提交 95f142b1 编写于 作者: X Xin Pan

resolve conflict

test=develop
上级 abdd9411
......@@ -59,12 +59,13 @@ class TestFetchAndFeed(unittest.TestCase):
exe = fluid.Executor(place)
exe.run(startup)
pe = fluid.ParallelExecutor(
use_cuda=use_cuda, loss_name=loss.name, main_program=main_program)
run_parallel_exe(main_program, pe, use_cuda, data, label, loss)
train_cp = compiler.CompiledProgram(main_program).with_data_parallel(
loss_name=loss.name)
def run_parallel_exe_with_fetch(self, main, pe, use_cuda, data, label,
loss):
run_parallel_exe(train_cp, exe, use_cuda, data, label, loss)
def run_parallel_exe_with_fetch(self, compiled_program, exe, use_cuda, data,
label, loss):
def get_data(batch_size=8):
np.random.seed(5)
while True:
......@@ -79,7 +80,7 @@ class TestFetchAndFeed(unittest.TestCase):
# conv2d_1.b_0@GRAD. Those variables should not be pruned.
# fluid.memory_optimize(main)
fetch_list = []
all_vars = main.global_block().vars
all_vars = compiled_program._program.global_block().vars
for k, v in all_vars.items():
if ('tmp' not in k) and (
......@@ -90,14 +91,18 @@ class TestFetchAndFeed(unittest.TestCase):
for batch_id, img_label in enumerate(get_data()):
img, l = img_label
train_inputs = {data.name: img, label.name: l}
ret = pe.run(fetch_list, feed=train_inputs, return_numpy=True)
ret = exe.run(compiled_program,
fetch_list=fetch_list,
feed=train_inputs,
return_numpy=True)
for i in range(len(fetch_list)):
assert not math.isnan(np.sum(ret[i])) and \
not math.isinf(np.sum(ret[i]))
if batch_id == 2:
break
def run_parallel_exe_with_feed(self, main, pe, use_cuda, data, label, loss):
def run_parallel_exe_with_feed(self, compiled_program, exe, use_cuda, data,
label, loss):
def get_data(batch_size=8):
np.random.seed(5)
while True:
......@@ -115,7 +120,9 @@ class TestFetchAndFeed(unittest.TestCase):
reader = feeder.decorate_reader(get_data, multi_devices=True)
for batch_id, data in enumerate(reader()):
loss_np = exe.run(train_cp, feed=data, fetch_list=[loss.name])[0]
loss_np = exe.run(compiled_program,
feed=data,
fetch_list=[loss.name])[0]
print(batch_id, loss_np)
if batch_id == 2:
break
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册