Unverified commit 5cae5fdd, authored by yuehuayingxueluo, committed by GitHub

Fix bugs in pass_base.py (#50136)

* fix the processing order of passes in pass_base.py

* fix processing order

* add _PASS_PROCESS_ORDER_LIST

* delete some passes from _PASS_PROCESS_ORDER_LIST

* add assert in pass_base.py

* remove fuse_optimizer

* add _fusion_opt_list_rule

* add test_pass_base_list.py

* fix some bugs

* add fused_attention

* add some passes to the list

* fix ci bug

* fix ci bug
Parent 7edfac9e
@@ -53,6 +53,7 @@ class PassBase(ABC):
_BEFORE_WHITE_LISTS_DICT = {}
_AFTER_WHITE_LISTS_DICT = {}
_PASS_PROCESS_ORDER_LIST = []
name = None
@@ -176,6 +177,16 @@ def _fusion_opt_last_rule(pass_before, pass_after):
return True
def _fusion_opt_list_rule(pass_before, pass_after):
if (
pass_before._type() == PassType.FUSION_OPT
and pass_after._type() == PassType.FUSION_OPT
):
return _get_list_index(pass_before) < _get_list_index(pass_after)
else:
return True
def _make_rule_from_white_lists_dict(
before_white_lists_dict, after_white_lists_dict
):
@@ -216,6 +227,13 @@ def _make_rule_from_white_lists_dict(
return rule
def _get_list_index(in_pass):
assert (
in_pass.name in PassBase._PASS_PROCESS_ORDER_LIST
), "Pass {} is not in _PASS_PROCESS_ORDER_LIST".format(in_pass.name)
return PassBase._PASS_PROCESS_ORDER_LIST.index(in_pass.name)
# The key-value pair (k, [v1, v2, ..., vn]) means that pass k can be
# applied before any of the passes [v1, v2, ..., vn] is applied
PassBase._BEFORE_WHITE_LISTS_DICT = {
@@ -229,8 +247,19 @@ PassBase._AFTER_WHITE_LISTS_DICT = {
# Add more white lists here
}
# The index of a pass in this list represents the order in which the pass is processed.
PassBase._PASS_PROCESS_ORDER_LIST = [
"fuse_relu_depthwise_conv",
"fuse_bn_add_act",
"fuse_bn_act",
"fused_attention",
"fuse_gemm_epilogue",
"fuse_optimizer",
]
PassBase._COMMON_RULES = [
_fusion_opt_last_rule,
_fusion_opt_list_rule,
lambda pass_before, pass_after: type(pass_before) != type(pass_after),
_make_rule_from_white_lists_dict(
PassBase._BEFORE_WHITE_LISTS_DICT, PassBase._AFTER_WHITE_LISTS_DICT
@@ -267,7 +296,17 @@ def _find_longest_path(edges):
for j in range(n):
if dists[i][j] > dists[i][k] + dists[k][j]:
dists[i][j] = dists[i][k] + dists[k][j]
paths[i][j] = paths[i][k] + paths[k][j]
if paths[i][k]:
assert paths[i][k][-1] == k
else:
continue
if paths[k][j]:
assert paths[k][j][0] == k
else:
continue
paths[i][j] = (
paths[i][k] + paths[k][j][1:] if paths[k][j] else []
)
if dists[i][j] < min_dist:
min_dist = dists[i][j]
min_path = paths[i][j]
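The last hunk above is the core fix in _find_longest_path: during the Floyd-Warshall style relaxation, paths[i][k] ends with the pivot node k and paths[k][j] starts with it, so the old concatenation paths[i][k] + paths[k][j] duplicated k in the reconstructed path. The sketch below is a minimal, self-contained illustration of the corrected merge, finding the longest path as the shortest path over negated weights; the function name and the toy graph are illustrative only, not the Paddle implementation.

import math


def find_longest_path_sketch(n, edges):
    # edges: list of (u, v, weight) tuples over an acyclic constraint graph.
    # The longest path is recovered as the shortest path over negated weights,
    # mirroring the relaxation loop patched in pass_base.py.
    dists = [[math.inf] * n for _ in range(n)]
    paths = [[[] for _ in range(n)] for _ in range(n)]
    for u, v, w in edges:
        dists[u][v] = -w
        paths[u][v] = [u, v]

    min_dist = 0
    min_path = []
    for k in range(n):
        for i in range(n):
            for j in range(n):
                if dists[i][j] > dists[i][k] + dists[k][j]:
                    dists[i][j] = dists[i][k] + dists[k][j]
                    # paths[i][k] ends with k and paths[k][j] starts with k,
                    # so drop the duplicated pivot when joining the two halves.
                    paths[i][j] = paths[i][k] + paths[k][j][1:]
                if dists[i][j] < min_dist:
                    min_dist = dists[i][j]
                    min_path = paths[i][j]
    return min_path


# 0 -> 1 -> 2 is longer than the direct edge 0 -> 2, and node 1 appears once.
print(find_longest_path_sketch(3, [(0, 1, 1), (1, 2, 1), (0, 2, 1)]))  # [0, 1, 2]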
@@ -81,6 +81,8 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
set_tests_properties(test_tuning_recompute PROPERTIES TIMEOUT 240)
py_test_modules(test_fused_linear_pass MODULES test_fused_linear_pass)
set_tests_properties(test_fused_linear_pass PROPERTIES TIMEOUT 20)
py_test_modules(test_pass_base_list MODULES test_pass_base_list)
set_tests_properties(test_pass_base_list PROPERTIES TIMEOUT 20)
py_test_modules(test_while_op_completion MODULES test_while_op_completion
ENVS ${dist_ENVS})
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import sys
import unittest
import numpy as np
from get_gpt_model import FakeDataset, generate_model
import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
sys.path.append("..")
from test_sparse_addmm_op import get_cuda_version
def apply_pass(use_fused_passes=False, fused_passes_list=[]):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
fused_passes = strategy.fused_passes
fused_passes.enable = use_fused_passes
fused_passes.fused_passes_list = fused_passes_list
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class TestFusedPassBaseList(unittest.TestCase):
def setUp(self):
self.rtol = 1e-5
self.atol = 1e-8
self.batch_size = 1
self.batch_num = 1
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2021)
np.random.seed(2021)
random.seed(2021)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_fused_passes=False, fused_passes_list=[]):
reset_prog()
strategy = apply_pass(use_fused_passes, fused_passes_list)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("serial")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses, rtol=None, atol=None):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=rtol or self.rtol,
atol=atol or self.atol,
err_msg='pass {} has wrong results!\nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses
),
)
def test_passes(self):
losses = []
if get_cuda_version() >= 11060:
for use_fused_passes in [True, False]:
engine = self.get_engine(
use_fused_passes,
[
"fuse_bn_act",
"fused_attention",
"fuse_optimizer",
"fuse_gemm_epilogue",
"fuse_bn_add_act",
"fuse_relu_depthwise_conv",
],
)
history = engine.fit(
self.dataset, 3, batch_size=self.batch_size
)
losses.append(np.array(history.history["loss"]))
self.check_results(losses[0], losses[1])
if __name__ == "__main__":
unittest.main()
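The test above deliberately hands fused_passes_list to the engine in a scrambled order: with use_fused_passes enabled, the new _fusion_opt_list_rule constrains fusion-type passes to follow the order defined in _PASS_PROCESS_ORDER_LIST, and the resulting losses are compared against a run without fused passes. The following is a standalone sketch of that ordering behaviour under the same list; the helper names are illustrative copies for demonstration, not the PassBase API itself.

# Illustration of ordering fusion passes by their index in a process-order
# list, as _fusion_opt_list_rule and _get_list_index do in pass_base.py.
PASS_PROCESS_ORDER_LIST = [
    "fuse_relu_depthwise_conv",
    "fuse_bn_add_act",
    "fuse_bn_act",
    "fused_attention",
    "fuse_gemm_epilogue",
    "fuse_optimizer",
]


def get_list_index(name):
    assert (
        name in PASS_PROCESS_ORDER_LIST
    ), "Pass {} is not in PASS_PROCESS_ORDER_LIST".format(name)
    return PASS_PROCESS_ORDER_LIST.index(name)


def fusion_opt_list_rule(name_before, name_after):
    # True means name_before is allowed to be applied before name_after.
    return get_list_index(name_before) < get_list_index(name_after)


scrambled = [
    "fuse_bn_act",
    "fused_attention",
    "fuse_optimizer",
    "fuse_gemm_epilogue",
    "fuse_bn_add_act",
    "fuse_relu_depthwise_conv",
]
# Sorting by list index restores the enforced order, regardless of the
# order in which the passes were requested.
ordered = sorted(scrambled, key=get_list_index)
print(ordered)
# Every adjacent pair in the sorted result satisfies the rule.
assert all(fusion_opt_list_rule(a, b) for a, b in zip(ordered, ordered[1:]))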