未验证 提交 0d12aa64 编写于 作者: S sneaxiy 提交者: GitHub

add check pass conflict tools (#38276)

上级 ac696941
......@@ -315,4 +315,8 @@ class PassManager:
@property
def names(self):
return [p.name for p in self._passes]
return [p.name for p in self.passes]
@property
def passes(self):
return tuple(self._passes)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from dist_pass_test_base import PassConflictChecker
from paddle.distributed.passes import new_pass
from model_zoo import resnet_model
class CheckPassConflictTest1(PassConflictChecker):
def pass_config(self):
return [
new_pass("fuse_all_reduce", {"max_memory_size": 1024 * 1024}),
new_pass("fuse_elewise_add_act"),
]
def test_resnet(self):
self.check_main(resnet_model, batch_size=32)
class CheckPassConflictTest2(PassConflictChecker):
def pass_config(self):
return [
new_pass("fuse_elewise_add_act"),
new_pass("fuse_all_reduce", {"max_memory_size": 1024 * 1024}),
]
def test_resnet(self):
with self.assertRaises(Exception):
self.check_main(resnet_model, batch_size=32)
if __name__ == "__main__":
unittest.main()
......@@ -15,7 +15,6 @@
import unittest
import paddle
import os
import random
import sys
import pickle
import shlex
......@@ -24,6 +23,7 @@ import inspect
import numpy as np
from collections import OrderedDict
from paddle.distributed.fleet.launch_utils import run_with_coverage
from paddle.distributed.passes.pass_base import new_pass, PassBase, PassManager
def prepare_python_path_and_return_module(path):
......@@ -58,6 +58,9 @@ def remove_path_if_exists(path):
class DistPassTestBase(unittest.TestCase):
def setUp(self):
paddle.enable_static()
if paddle.is_compiled_with_cuda():
paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
seed = int(os.environ.get('SEED', -1))
if seed <= 0:
seed = np.random.randint(low=1, high=1000000, size=[1])[0]
......@@ -80,11 +83,11 @@ class DistPassTestBase(unittest.TestCase):
def apply_passes(self, main_prog, startup_prog):
raise NotImplementedError()
def check_main(self, gpus=None, **kwargs):
def check_main(self, model=None, gpus=None, **kwargs):
no_pass_rets = self._distributed_launch(
apply_pass=False, gpus=gpus, **kwargs)
model=model, apply_pass=True, gpus=gpus, **kwargs)
pass_rets = self._distributed_launch(
apply_pass=True, gpus=gpus, **kwargs)
model=model, apply_pass=False, gpus=gpus, **kwargs)
self.check_results(no_pass_rets, pass_rets)
def check_results(self, no_pass_rets, pass_rets):
......@@ -105,7 +108,7 @@ class DistPassTestBase(unittest.TestCase):
equal_nan=self.equal_nan))
@classmethod
def _to_var_names(cls, program, names_or_vars):
def _to_var_names(cls, names_or_vars):
if not isinstance(names_or_vars, (list, tuple)):
names_or_vars = [names_or_vars]
ret_var_names = []
......@@ -116,18 +119,20 @@ class DistPassTestBase(unittest.TestCase):
ret_var_names.append(name_or_var.name)
return ret_var_names
def _run_gpu_main(self, apply_pass, dump_file, **kwargs):
def _run_gpu_main(self, model, apply_pass, dump_file, **kwargs):
gpu_id = int(os.environ.get('FLAGS_selected_gpus', 0))
place = paddle.CUDAPlace(gpu_id)
scope = paddle.static.Scope()
if model is None:
model = self.get_model
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
with paddle.static.scope_guard(scope):
with paddle.fluid.unique_name.guard():
main_prog, startup_prog, inputs, outputs, reader = self.get_model(
main_prog, startup_prog, inputs, outputs, reader = model(
place, **kwargs)
inputs = self._to_var_names(main_prog, inputs)
outputs = self._to_var_names(main_prog, outputs)
inputs = self._to_var_names(inputs)
outputs = self._to_var_names(outputs)
if apply_pass:
self.apply_passes(main_prog, startup_prog)
......@@ -161,7 +166,7 @@ class DistPassTestBase(unittest.TestCase):
int(s.strip()) for s in visible_devices.split(",") if s.strip()
]
def _distributed_launch(self, apply_pass, gpus=None, **kwargs):
def _distributed_launch(self, model, apply_pass, gpus=None, **kwargs):
if gpus is None:
gpus = self._get_default_gpu_lists()
......@@ -176,7 +181,9 @@ class DistPassTestBase(unittest.TestCase):
remove_path_if_exists(output_dir)
os.makedirs(output_dir, mode=777)
input_dump_file = os.path.join(output_dir, 'inputs')
input_dump_file = os.path.join(output_dir, 'inputs.bin')
model_dump_file = os.path.join(output_dir, 'model.bin')
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
run_with_coverage(True)
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
......@@ -189,6 +196,10 @@ class DistPassTestBase(unittest.TestCase):
with open(input_dump_file, 'wb') as f:
pickle.dump(kwargs, f)
if model is not None:
with open(model_dump_file, 'wb') as f:
pickle.dump(model, f)
cmd = [
sys.executable,
"-u",
......@@ -208,23 +219,62 @@ class DistPassTestBase(unittest.TestCase):
input_dump_file,
"--output_dir",
output_dir,
] + (["--apply_pass"] if apply_pass else [])
]
if apply_pass:
cmd += ["--apply_pass"]
if model is not None:
cmd += ["--model_file", model_dump_file]
cmd = [shlex.quote(c) for c in cmd]
prepare_python_path_and_return_module(__file__)
exitcode = os.system(' '.join(cmd))
self.assertEqual(
exitcode, 0,
"Pass failed with apply_pass = {}".format(apply_pass))
"Pass test failed with apply_pass = {}, please view log in {}".
format(apply_pass, output_dir))
results = []
for i in range(num_gpus):
dump_file = '{0}/{1}.bin'.format(output_dir, i)
self.assertTrue(
os.path.exists(dump_file),
"Pass failed with apply_pass = {}".format(apply_pass))
"Pass test failed with apply_pass = {}, please view log in {}".
format(apply_pass, output_dir))
with open(dump_file, "rb") as f:
results.append(pickle.load(f))
return results
finally:
if int(os.environ.get("DEBUG", 0)) == 0:
remove_path_if_exists(output_dir)
class PassConflictChecker(DistPassTestBase):
def setUp(self):
os.environ['DEBUG'] = '1' # to save the debug directory
super(PassConflictChecker, self).setUp()
def pass_config(self):
raise NotImplementedError()
def apply_passes(self, main_prog, startup_prog):
passes = self.pass_config()
if not isinstance(passes, (list, tuple)):
passes = [passes]
for p in passes:
self.assertTrue(isinstance(p, PassBase))
auto_pass_manager = PassManager(passes, auto_solve_conflict=True)
new_passes = auto_pass_manager.passes
self.assertEqual(
len(passes),
len(new_passes),
"After solving conflicts, the left passes are: {}".format(
auto_pass_manager.names))
for i, (p1, p2) in enumerate(zip(passes, new_passes)):
self.assertEqual(
id(p1),
id(p2),
"After solving conflicts, the {}-th pass is different: {} vs {}".
format(i, p1.name, p2.name))
auto_pass_manager.apply([main_prog], [startup_prog])
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
import paddle.distributed.fleet as fleet
from paddle.vision.models import resnet50 as resnet
import numpy as np
import paddle.nn as nn
__all__ = ['resnet_model', ]
def get_seed_from_env():
return int(os.environ.get("SEED", 0))
def resnet_model(place, batch_size, image_shape=[3, 224, 224],
num_classes=1000):
image = paddle.static.data(
shape=[batch_size] + image_shape, dtype='float32', name='image')
label = paddle.static.data(
shape=[batch_size, 1], dtype='int64', name='label')
model = resnet(pretrained=False)
loss_fn = nn.loss.CrossEntropyLoss()
pred_out = model(image)
loss = loss_fn(pred_out, label)
optimizer = paddle.optimizer.Adam(learning_rate=1e-3)
dist_strategy = fleet.DistributedStrategy()
dist_strategy.fuse_all_reduce_ops = False
dist_strategy.without_graph_optimization = True
fleet.init(is_collective=True, strategy=dist_strategy)
optimizer = fleet.distributed_optimizer(optimizer)
optimizer.minimize(loss)
rank = paddle.distributed.get_rank()
def reader():
seed = get_seed_from_env()
np.random.seed(seed + rank)
for _ in range(10):
image_np = np.random.random(size=image.shape).astype('float32')
label_np = np.random.randint(
low=0, high=num_classes, size=label.shape).astype('int64')
yield image_np, label_np
main_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
return main_program, startup_program, [image, label], [loss], reader
......@@ -44,6 +44,10 @@ def parse_args():
'--output_dir',
type=str,
help='The output directory to save the logs and output results.')
parser.add_argument(
'--model_file',
type=str,
help='The input model file which contains the dumped model function.')
return parser.parse_args()
......@@ -60,11 +64,16 @@ def run_main(args):
kwargs = pickle.load(f)
output_file = "{}/{}.bin".format(args.output_dir, rank)
if args.model_file:
with open(args.model_file, "rb") as f:
model = pickle.load(f)
else:
model = None
try:
test_obj.setUpClass()
test_obj.setUp()
test_obj._run_gpu_main(args.apply_pass, output_file, **kwargs)
test_obj._run_gpu_main(model, args.apply_pass, output_file, **kwargs)
finally:
test_obj.tearDown()
test_obj.tearDownClass()
......
......@@ -12,20 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.distributed.passes import new_pass, PassManager
import paddle.distributed.fleet as fleet
from paddle.vision.models import resnet50 as resnet
import unittest
from dist_pass_test_base import DistPassTestBase
import paddle.nn as nn
import numpy as np
from model_zoo import resnet_model
class TestFuseAllReducePass(DistPassTestBase):
def init(self):
if paddle.is_compiled_with_cuda():
paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
self.atol = 0.0
self.rtol = 0.0
......@@ -35,41 +29,10 @@ class TestFuseAllReducePass(DistPassTestBase):
new_pass("fuse_all_reduce", {"max_memory_size": 1024 * 1024})
])
pass_manager.apply([main_prog], [startup_prog])
print(pass_manager.names)
def test_bs_32(self):
self.check_main(batch_size=32)
def get_model(self, place, batch_size):
image = paddle.static.data(
shape=[batch_size, 3, 224, 224], dtype='float32', name='image')
label = paddle.static.data(
shape=[batch_size, 1], dtype='int64', name='label')
model = resnet(pretrained=False)
loss_fn = nn.loss.CrossEntropyLoss()
pred_out = model(image)
loss = loss_fn(pred_out, label)
optimizer = paddle.optimizer.Adam(learning_rate=1e-3)
dist_strategy = fleet.DistributedStrategy()
dist_strategy.fuse_all_reduce_ops = False
dist_strategy.without_graph_optimization = True
fleet.init(is_collective=True, strategy=dist_strategy)
optimizer = fleet.distributed_optimizer(optimizer)
optimizer.minimize(loss)
rank = paddle.distributed.get_rank()
def reader():
np.random.seed(self.seed + rank)
for _ in range(10):
image_np = np.random.random(size=image.shape).astype('float32')
label_np = np.random.randint(
low=0, high=1000, size=label.shape).astype('int64')
yield image_np, label_np
main_program = paddle.static.default_main_program()
startup_program = paddle.static.default_startup_program()
return main_program, startup_program, [image, label], [loss], reader
self.check_main(resnet_model, batch_size=32)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册