未验证 提交 8652a899 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #15279 from panyx0718/api

convert all tests to new CompiledProgram API
...@@ -382,9 +382,11 @@ class Executor(object): ...@@ -382,9 +382,11 @@ class Executor(object):
""" """
Close this executor. Close this executor.
You can no long use this executor after calling this method. You can no longer use this executor after calling this method.
For the distributed training, this method would free the resource on PServers related to For the distributed training, this method would free the resource on PServers related to
the current Trainer. the current Trainer.
TODO(typhoonzero): Define "no longer use" meaning? Can user create
a new Executor for the same program and run?
TODO(panyx0718): Why ParallelExecutor doesn't have close? TODO(panyx0718): Why ParallelExecutor doesn't have close?
Example: Example:
...@@ -397,7 +399,7 @@ class Executor(object): ...@@ -397,7 +399,7 @@ class Executor(object):
self.executor.close() self.executor.close()
self._closed = True self._closed = True
def _run_parallel(self, scope, feed, fetch_list, fetch_var_name, def _run_parallel(self, program, scope, feed, fetch_list, fetch_var_name,
return_numpy): return_numpy):
if isinstance(feed, dict): if isinstance(feed, dict):
feed_tensor_dict = dict() feed_tensor_dict = dict()
...@@ -413,7 +415,7 @@ class Executor(object): ...@@ -413,7 +415,7 @@ class Executor(object):
self.executor.feed_and_split_tensor_into_local_scopes( self.executor.feed_and_split_tensor_into_local_scopes(
feed_tensor_dict) feed_tensor_dict)
elif isinstance(feed, list) or isinstance(feed, tuple): elif isinstance(feed, list) or isinstance(feed, tuple):
if len(feed) != len(self._places): if len(feed) != len(program._places):
raise ValueError( raise ValueError(
"Feed a list of tensor, the list should be the same size as places" "Feed a list of tensor, the list should be the same size as places"
) )
...@@ -428,7 +430,7 @@ class Executor(object): ...@@ -428,7 +430,7 @@ class Executor(object):
tensor = each[feed_name] tensor = each[feed_name]
if not isinstance(tensor, core.LoDTensor): if not isinstance(tensor, core.LoDTensor):
tmp = core.LoDTensor() tmp = core.LoDTensor()
tmp.set(tensor, self._places[i]) tmp.set(tensor, program._places[i])
tensor = tmp tensor = tmp
res_dict[feed_name] = tensor res_dict[feed_name] = tensor
res.append(res_dict) res.append(res_dict)
...@@ -462,7 +464,7 @@ class Executor(object): ...@@ -462,7 +464,7 @@ class Executor(object):
Args: Args:
program(Program|CompiledProgram): the program that need to run, program(Program|CompiledProgram): the program that need to run,
if not provided, then default_main_program will be used. if not provided, then default_main_program (not compiled) will be used.
feed(dict): feed variable map, e.g. {"image": ImageData, "label": LabelData} feed(dict): feed variable map, e.g. {"image": ImageData, "label": LabelData}
fetch_list(list): a list of variable or variable names that user want to get, run will return them according to this list. fetch_list(list): a list of variable or variable names that user want to get, run will return them according to this list.
feed_var_name(str): the name for the input variable of feed Operator. feed_var_name(str): the name for the input variable of feed Operator.
...@@ -525,6 +527,7 @@ class Executor(object): ...@@ -525,6 +527,7 @@ class Executor(object):
self.executor = program._executor self.executor = program._executor
if program._is_data_parallel: if program._is_data_parallel:
return self._run_parallel( return self._run_parallel(
program,
scope=scope, scope=scope,
feed=feed, feed=feed,
fetch_list=fetch_list, fetch_list=fetch_list,
......
...@@ -22,6 +22,7 @@ import unittest ...@@ -22,6 +22,7 @@ import unittest
import paddle import paddle
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import compiler
def train(network, use_cuda, use_parallel_executor, batch_size=32, pass_num=2): def train(network, use_cuda, use_parallel_executor, batch_size=32, pass_num=2):
...@@ -57,19 +58,19 @@ def train(network, use_cuda, use_parallel_executor, batch_size=32, pass_num=2): ...@@ -57,19 +58,19 @@ def train(network, use_cuda, use_parallel_executor, batch_size=32, pass_num=2):
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
train_cp = compiler.CompiledProgram(fluid.default_main_program())
if use_parallel_executor: if use_parallel_executor:
train_exe = fluid.ParallelExecutor( train_cp = train_cp.with_data_parallel(loss_name=cost.name)
use_cuda=use_cuda, loss_name=cost.name)
fetch_list = [cost.name] fetch_list = [cost.name]
else: else:
train_exe = exe
fetch_list = [cost] fetch_list = [cost]
for pass_id in six.moves.xrange(pass_num): for pass_id in six.moves.xrange(pass_num):
batch_id = 0 batch_id = 0
for data in reader(): for data in reader():
train_exe.run(feed=data, exe.run(train_cp,
fetch_list=fetch_list if batch_id % 4 == 0 else []) feed=data,
fetch_list=fetch_list if batch_id % 4 == 0 else [])
batch_id += 1 batch_id += 1
if batch_id > 16: if batch_id > 16:
break break
......
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import paddle.dataset.conll05 as conll05 import paddle.dataset.conll05 as conll05
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import compiler
import paddle.fluid.core as core import paddle.fluid.core as core
import unittest import unittest
import paddle import paddle
...@@ -157,10 +158,8 @@ class TestCRFModel(unittest.TestCase): ...@@ -157,10 +158,8 @@ class TestCRFModel(unittest.TestCase):
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup) exe.run(startup)
pe = fluid.ParallelExecutor( train_cp = compiler.CompiledProgram(main).with_data_parallel(
use_cuda=use_cuda, loss_name=avg_cost.name, build_strategy=build_strategy)
loss_name=avg_cost.name,
build_strategy=build_strategy)
feeder = fluid.DataFeeder( feeder = fluid.DataFeeder(
feed_list=[ feed_list=[
...@@ -172,8 +171,9 @@ class TestCRFModel(unittest.TestCase): ...@@ -172,8 +171,9 @@ class TestCRFModel(unittest.TestCase):
data = train_data() data = train_data()
for i in range(10): for i in range(10):
cur_batch = next(data) cur_batch = next(data)
print(pe.run(feed=feeder.feed(cur_batch), print(exe.run(train_cp,
fetch_list=[avg_cost.name])[0]) feed=feeder.feed(cur_batch),
fetch_list=[avg_cost.name])[0])
def _new_build_strategy(self, use_reduce=False): def _new_build_strategy(self, use_reduce=False):
build_strategy = fluid.BuildStrategy() build_strategy = fluid.BuildStrategy()
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import compiler
import unittest import unittest
import logging import logging
import six import six
...@@ -36,21 +37,18 @@ class TestBase(unittest.TestCase): ...@@ -36,21 +37,18 @@ class TestBase(unittest.TestCase):
with fluid.program_guard(main_prog, startup_prog): with fluid.program_guard(main_prog, startup_prog):
with fluid.scope_guard(scope): with fluid.scope_guard(scope):
loss = network_func() loss = network_func()
fluid.Executor( exe = fluid.Executor(
fluid.CUDAPlace(0) fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace())
if use_gpu else fluid.CPUPlace()).run(startup_prog) exe.run(startup_prog)
for _ in six.moves.xrange(iter): for _ in six.moves.xrange(iter):
exe_strategy = fluid.ExecutionStrategy() exe_strategy = fluid.ExecutionStrategy()
exe_strategy._dry_run = True exe_strategy._dry_run = True
exe_strategy.use_experimental_executor = use_experimental_executor exe_strategy.use_experimental_executor = use_experimental_executor
pe = fluid.ParallelExecutor( train_cp = compiler.CompiledProgram(main_prog).with_data_parallel(
use_cuda=use_gpu, loss_name=loss.name, exec_strategy=exe_strategy)
loss_name=loss.name,
main_program=main_prog,
exec_strategy=exe_strategy)
for _ in six.moves.xrange(iter_per_pe): for _ in six.moves.xrange(iter_per_pe):
pe.run([]) exe.run(train_cp)
class TestMNISTDryRun(TestBase): class TestMNISTDryRun(TestBase):
......
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import math import math
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import compiler
import paddle.fluid.core as core import paddle.fluid.core as core
import unittest import unittest
import numpy as np import numpy as np
...@@ -58,12 +59,13 @@ class TestFetchAndFeed(unittest.TestCase): ...@@ -58,12 +59,13 @@ class TestFetchAndFeed(unittest.TestCase):
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup) exe.run(startup)
pe = fluid.ParallelExecutor( train_cp = compiler.CompiledProgram(main_program).with_data_parallel(
use_cuda=use_cuda, loss_name=loss.name, main_program=main_program) loss_name=loss.name)
run_parallel_exe(main_program, pe, use_cuda, data, label, loss)
def run_parallel_exe_with_fetch(self, main, pe, use_cuda, data, label, run_parallel_exe(train_cp, exe, use_cuda, data, label, loss)
loss):
def run_parallel_exe_with_fetch(self, compiled_program, exe, use_cuda, data,
label, loss):
def get_data(batch_size=8): def get_data(batch_size=8):
np.random.seed(5) np.random.seed(5)
while True: while True:
...@@ -78,7 +80,7 @@ class TestFetchAndFeed(unittest.TestCase): ...@@ -78,7 +80,7 @@ class TestFetchAndFeed(unittest.TestCase):
# conv2d_1.b_0@GRAD. Those variables should not be pruned. # conv2d_1.b_0@GRAD. Those variables should not be pruned.
# fluid.memory_optimize(main) # fluid.memory_optimize(main)
fetch_list = [] fetch_list = []
all_vars = main.global_block().vars all_vars = compiled_program._program.global_block().vars
for k, v in all_vars.items(): for k, v in all_vars.items():
if ('tmp' not in k) and ( if ('tmp' not in k) and (
...@@ -89,14 +91,18 @@ class TestFetchAndFeed(unittest.TestCase): ...@@ -89,14 +91,18 @@ class TestFetchAndFeed(unittest.TestCase):
for batch_id, img_label in enumerate(get_data()): for batch_id, img_label in enumerate(get_data()):
img, l = img_label img, l = img_label
train_inputs = {data.name: img, label.name: l} train_inputs = {data.name: img, label.name: l}
ret = pe.run(fetch_list, feed=train_inputs, return_numpy=True) ret = exe.run(compiled_program,
fetch_list=fetch_list,
feed=train_inputs,
return_numpy=True)
for i in range(len(fetch_list)): for i in range(len(fetch_list)):
assert not math.isnan(np.sum(ret[i])) and \ assert not math.isnan(np.sum(ret[i])) and \
not math.isinf(np.sum(ret[i])) not math.isinf(np.sum(ret[i]))
if batch_id == 2: if batch_id == 2:
break break
def run_parallel_exe_with_feed(self, main, pe, use_cuda, data, label, loss): def run_parallel_exe_with_feed(self, compiled_program, exe, use_cuda, data,
label, loss):
def get_data(batch_size=8): def get_data(batch_size=8):
np.random.seed(5) np.random.seed(5)
while True: while True:
...@@ -114,7 +120,9 @@ class TestFetchAndFeed(unittest.TestCase): ...@@ -114,7 +120,9 @@ class TestFetchAndFeed(unittest.TestCase):
reader = feeder.decorate_reader(get_data, multi_devices=True) reader = feeder.decorate_reader(get_data, multi_devices=True)
for batch_id, data in enumerate(reader()): for batch_id, data in enumerate(reader()):
loss_np = pe.run(feed=data, fetch_list=[loss.name])[0] loss_np = exe.run(compiled_program,
feed=data,
fetch_list=[loss.name])[0]
print(batch_id, loss_np) print(batch_id, loss_np)
if batch_id == 2: if batch_id == 2:
break break
......
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
from paddle.fluid import compiler
import numpy as np import numpy as np
import unittest import unittest
import os import os
...@@ -61,22 +62,21 @@ class TestPassBuilder(unittest.TestCase): ...@@ -61,22 +62,21 @@ class TestPassBuilder(unittest.TestCase):
exe.run(startup) exe.run(startup)
feed_dict = {'image': image, 'label': label} feed_dict = {'image': image, 'label': label}
train_exe = fluid.ParallelExecutor( train_cp = compiler.CompiledProgram(main).with_data_parallel(
use_cuda=use_cuda, loss_name=loss.name, build_strategy=build_strategy)
test_cp = compiler.CompiledProgram(test_program).with_data_parallel(
loss_name=loss.name, loss_name=loss.name,
main_program=main, build_strategy=build_strategy,
build_strategy=build_strategy) share_vars_from=train_cp)
test_exe = fluid.ParallelExecutor(
use_cuda=use_cuda,
main_program=test_program,
share_vars_from=train_exe,
build_strategy=build_strategy)
for i in range(5): for i in range(5):
test_loss, = test_exe.run([loss.name], feed=feed_dict) _ = exe.run(train_cp, fetch_list=[loss.name], feed=feed_dict)
test_loss, = exe.run(test_cp,
train_loss, = train_exe.run([loss.name], feed=feed_dict) fetch_list=[loss.name],
feed=feed_dict)
train_loss = exe.run(train_cp,
fetch_list=[loss.name],
feed=feed_dict)
avg_test_loss_val = np.array(test_loss).mean() avg_test_loss_val = np.array(test_loss).mean()
if math.isnan(float(avg_test_loss_val)): if math.isnan(float(avg_test_loss_val)):
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import os import os
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import compiler
import paddle import paddle
import unittest import unittest
import six import six
...@@ -140,9 +141,10 @@ def test_main(use_cuda, use_py_func_op, use_parallel_executor): ...@@ -140,9 +141,10 @@ def test_main(use_cuda, use_py_func_op, use_parallel_executor):
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
train_cp = compiler.CompiledProgram(fluid.default_main_program())
if use_parallel_executor: if use_parallel_executor:
exe = fluid.ParallelExecutor( train_cp = train_cp.with_data_parallel(loss_name=loss.name)
use_cuda=use_cuda, loss_name=loss.name)
fetch_list = [loss.name] fetch_list = [loss.name]
else: else:
fetch_list = [loss] fetch_list = [loss]
...@@ -150,9 +152,10 @@ def test_main(use_cuda, use_py_func_op, use_parallel_executor): ...@@ -150,9 +152,10 @@ def test_main(use_cuda, use_py_func_op, use_parallel_executor):
ret = [] ret = []
for epoch_id in six.moves.range(2): for epoch_id in six.moves.range(2):
for d in r(): for d in r():
L, = exe.run(feed=feeder.feed(d), fetch_list=fetch_list) L, = exe.run(train_cp,
feed=feeder.feed(d),
fetch_list=fetch_list)
ret.append(L) ret.append(L)
return np.array(ret) return np.array(ret)
......
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import unittest import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import compiler
import paddle.fluid.core as core import paddle.fluid.core as core
import numpy as np import numpy as np
import threading import threading
...@@ -188,18 +189,18 @@ class TestPyReaderUsingExecutor(unittest.TestCase): ...@@ -188,18 +189,18 @@ class TestPyReaderUsingExecutor(unittest.TestCase):
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
startup_exe = fluid.Executor(place) exe = fluid.Executor(place)
startup_exe.run(startup_program) exe.run(startup_program)
train_cp = compiler.CompiledProgram(main_program)
if use_parallel_executor: if use_parallel_executor:
main_exe = fluid.ParallelExecutor(use_cuda, loss_name=loss.name) train_cp = train_cp.with_data_parallel(loss_name=loss.name)
if use_cuda: if use_cuda:
self.batch_size_times = core.get_cuda_device_count() self.batch_size_times = core.get_cuda_device_count()
else: else:
self.batch_size_times = int( self.batch_size_times = int(
os.environ.get('CPU_NUM', multiprocessing.cpu_count())) os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
else: else:
main_exe = startup_exe
self.batch_size_times = 1 self.batch_size_times = 1
reader = self.tensor_reader(use_decorate_paddle_reader) reader = self.tensor_reader(use_decorate_paddle_reader)
...@@ -214,7 +215,8 @@ class TestPyReaderUsingExecutor(unittest.TestCase): ...@@ -214,7 +215,8 @@ class TestPyReaderUsingExecutor(unittest.TestCase):
self.outputs = [] self.outputs = []
for _ in range(self.iterations): for _ in range(self.iterations):
fetches = main_exe.run(fetch_list=[in_data.name, label.name]) fetches = exe.run(train_cp,
fetch_list=[in_data.name, label.name])
fetches = [as_numpy(fetch) for fetch in fetches] fetches = [as_numpy(fetch) for fetch in fetches]
self.outputs.append(fetches) self.outputs.append(fetches)
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import print_function from __future__ import print_function
import os import os
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import compiler
import paddle import paddle
import numpy as np import numpy as np
import unittest import unittest
...@@ -74,20 +75,13 @@ class TestReaderReset(unittest.TestCase): ...@@ -74,20 +75,13 @@ class TestReaderReset(unittest.TestCase):
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup_prog) exe.run(startup_prog)
build_strategy = fluid.BuildStrategy() train_cp = compiler.CompiledProgram(main_prog).with_data_parallel()
exec_strategy = fluid.ExecutionStrategy()
parallel_exe = fluid.ParallelExecutor(
use_cuda=self.use_cuda,
main_program=main_prog,
build_strategy=build_strategy,
exec_strategy=exec_strategy)
data_appeared = [False] * self.total_ins_num
pass_count = 0 pass_count = 0
while (True): while (True):
try: try:
data_val, label_val = parallel_exe.run(fetch_list, data_val, label_val = exe.run(train_cp,
return_numpy=True) fetch_list=fetch_list,
return_numpy=True)
ins_num = data_val.shape[0] ins_num = data_val.shape[0]
broadcasted_label = np.ones((ins_num, ) + tuple( broadcasted_label = np.ones((ins_num, ) + tuple(
self.ins_shape)) * label_val.reshape((ins_num, 1)) self.ins_shape)) * label_val.reshape((ins_num, 1))
......
...@@ -22,6 +22,7 @@ import paddle ...@@ -22,6 +22,7 @@ import paddle
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import compiler
def get_places(): def get_places():
...@@ -111,17 +112,17 @@ class TestWeightDecay(unittest.TestCase): ...@@ -111,17 +112,17 @@ class TestWeightDecay(unittest.TestCase):
if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce if use_reduce else fluid.BuildStrategy.ReduceStrategy.AllReduce
build_strategy.memory_optimize = use_ir_memory_optimize build_strategy.memory_optimize = use_ir_memory_optimize
parallel_exe = fluid.ParallelExecutor( train_cp = compiler.CompiledProgram(fluid.default_main_program(
use_cuda, )).with_data_parallel(
loss_name=loss.name, loss_name=loss.name,
exec_strategy=exec_strategy, exec_strategy=exec_strategy,
build_strategy=build_strategy) build_strategy=build_strategy)
loss_set = [] loss_set = []
for data in self.train_data: for data in self.train_data:
out = parallel_exe.run(feed=feeder.feed(data), out = exe.run(train_cp,
fetch_list=[loss.name]) feed=feeder.feed(data),
print("loss %s" % (np.average(out))) fetch_list=[loss.name])
loss_set.append(np.average(out)) loss_set.append(np.average(out))
return loss_set return loss_set
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册