Unverified commit 6efc2888, authored by Shijie, committed by GitHub

Fix fuse_gemm_epilogue (#47805)

* Fix fuse_gemm_epilogue

* update tests

* Update CMakeLists.txt

* Update CMakeLists.txt

* Update CMakeLists.txt

* fix random seed

* use assert_allclose

* Update test_dist_fuse_gemm_epilogue_pass.py

* Update cpp_pass.py

* Update test_dist_fuse_gemm_epilogue_pass.py

* fix codestyle

* update seed and atol
Parent commit: 4c38b87e
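For context, the fuse_gemm_epilogue pass rewrites a matmul + elementwise_add (bias) chain, optionally followed by an activation, into a single fused_gemm_epilogue op, and the matching backward ops into fused_gemm_epilogue_grad. A minimal sketch of an eligible subgraph, modelled on the MultiFCLayer used in the new test below (the variable names here are illustrative only, not part of the commit):

import paddle
import paddle.nn as nn

paddle.enable_static()

# nn.Linear lowers to a matmul plus an elementwise bias add in the static graph;
# together with the trailing activation this is the pattern the pass fuses.
x = paddle.static.data(name="x", shape=[-1, 128, 768], dtype="float32")
linear = nn.Linear(768, 768)
relu = nn.ReLU()
out = relu(linear(x))  # candidate for fuse_gemm_epilogue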
fuse_gemm_epilogue_pass.cc
@@ -16,7 +16,7 @@
 #include "paddle/fluid/framework/ir/fuse_gemm_epilogue_pass.h"
 #include <string>
+#include "paddle/fluid/framework/details/multi_devices_helper.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/enforce.h"
@@ -106,13 +106,13 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearFwd(ir::Graph *graph,
     IR_NODE_LINK_TO(ele_bias, gemm_epilogue_node);
     IR_NODE_LINK_TO(gemm_epilogue_node, ele_out);
-    GraphSafeRemoveNodes(g, {matmul_op, matmul_out, ele_add_op});
     VLOG(4) << "\n\t " << subgraph.at(x)->Name() << " and " << matmul_w->Name()
             << " -> " << matmul_op->Name() << " -> " << matmul_out->Name()
             << "\n\t " << matmul_out->Name() << " and " << ele_bias->Name()
             << " -> " << ele_add_op->Name() << " -> " << ele_out->Name()
             << "\n\t " << ele_out->Name();
+    GraphSafeRemoveNodes(g, {matmul_op, matmul_out, ele_add_op});
     found_linear_count++;
   };
@@ -218,15 +218,15 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActFwd(
       IR_NODE_LINK_TO(gemm_epilogue_node, reserve_space_node);
     }
-    GraphSafeRemoveNodes(g,
-                         {matmul_op, matmul_out, ele_add_op, ele_out, act_op});
     VLOG(4) << "\n\t " << subgraph.at(x)->Name() << " and " << matmul_w->Name()
             << " -> " << matmul_op->Name() << " -> " << matmul_out->Name()
             << "\n\t " << matmul_out->Name() << " and " << ele_bias->Name()
             << " -> " << ele_add_op->Name() << " -> " << ele_out->Name()
             << "\n\t " << ele_out->Name() << " -> " << act_op->Name() << " -> "
             << act_out->Name();
+    GraphSafeRemoveNodes(g,
+                         {matmul_op, matmul_out, ele_add_op, ele_out, act_op});
     found_linear_act_count++;
   };
@@ -318,6 +318,19 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph,
         "op_role", matmul_grad_op_desc->GetAttr("op_role"));
     fused_gemm_epilogue_grad_op_desc.SetAttr("trans_x", trans_x);
     fused_gemm_epilogue_grad_op_desc.SetAttr("trans_y", trans_y);
+    auto matmul_grad_op_role_val =
+        details::GetOpRoleVarsOrEmpty(*(matmul_grad_op->Op()));
+    auto ele_add_grad_op_role_val =
+        details::GetOpRoleVarsOrEmpty(*(ele_add_grad_op->Op()));
+    std::vector<std::string> fused_gemm_epilogue_grad_op_role_var;
+    for (auto i : matmul_grad_op_role_val) {
+      fused_gemm_epilogue_grad_op_role_var.push_back(i);
+    }
+    for (auto i : ele_add_grad_op_role_val) {
+      fused_gemm_epilogue_grad_op_role_var.push_back(i);
+    }
+    fused_gemm_epilogue_grad_op_desc.SetAttr(
+        "op_role_var", fused_gemm_epilogue_grad_op_role_var);
     auto gemm_epilogue_grad_node =
         g->CreateOpNode(&fused_gemm_epilogue_grad_op_desc);
@@ -325,14 +338,13 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph,
     IR_NODE_LINK_TO(subgraph.at(dout), gemm_epilogue_grad_node);
     IR_NODE_LINK_TO(matmul_grad_x, gemm_epilogue_grad_node);
     IR_NODE_LINK_TO(matmul_grad_w, gemm_epilogue_grad_node);
-    IR_NODE_LINK_TO(ele_grad_bias, gemm_epilogue_grad_node);
     IR_NODE_LINK_TO(gemm_epilogue_grad_node, matmul_grad_dw);
     IR_NODE_LINK_TO(gemm_epilogue_grad_node, ele_grad_dbias);
     if (matmul_grad_dx) {
       IR_NODE_LINK_TO(gemm_epilogue_grad_node, matmul_grad_dx);
     }
-    GraphSafeRemoveNodes(g, {ele_add_grad_op, ele_grad_dx, matmul_grad_op});
     std::string matmul_grad_dx_name =
         matmul_grad_dx != nullptr ? matmul_grad_dx->Name() : " ";
     VLOG(4) << "\n\t " << subgraph.at(dout)->Name() << " and "
@@ -342,6 +354,8 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearBwd(ir::Graph *graph,
             << matmul_grad_x->Name() << " and " << matmul_grad_w->Name()
             << " -> " << matmul_grad_op->Name() << " -> "
             << matmul_grad_w->Name() << " and " << matmul_grad_dx_name;
+    GraphSafeRemoveNodes(g, {ele_add_grad_op, ele_grad_dx, matmul_grad_op});
     found_ele_add_matmul_act_count++;
   };
@@ -442,6 +456,19 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd(
         "op_role", matmul_grad_op_desc->GetAttr("op_role"));
     fused_gemm_epilogue_grad_op_desc.SetAttr("trans_x", trans_x);
     fused_gemm_epilogue_grad_op_desc.SetAttr("trans_y", trans_y);
+    auto matmul_grad_op_role_val =
+        details::GetOpRoleVarsOrEmpty(*(matmul_grad_op->Op()));
+    auto ele_add_grad_op_role_val =
+        details::GetOpRoleVarsOrEmpty(*(ele_add_grad_op->Op()));
+    std::vector<std::string> fused_gemm_epilogue_grad_op_role_var;
+    for (auto i : matmul_grad_op_role_val) {
+      fused_gemm_epilogue_grad_op_role_var.push_back(i);
+    }
+    for (auto i : ele_add_grad_op_role_val) {
+      fused_gemm_epilogue_grad_op_role_var.push_back(i);
+    }
+    fused_gemm_epilogue_grad_op_desc.SetAttr(
+        "op_role_var", fused_gemm_epilogue_grad_op_role_var);
     auto gemm_epilogue_grad_node =
         g->CreateOpNode(&fused_gemm_epilogue_grad_op_desc);
@@ -449,18 +476,12 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd(
     IR_NODE_LINK_TO(subgraph.at(dout), gemm_epilogue_grad_node);
     IR_NODE_LINK_TO(matmul_grad_x, gemm_epilogue_grad_node);
     IR_NODE_LINK_TO(matmul_grad_w, gemm_epilogue_grad_node);
-    IR_NODE_LINK_TO(ele_grad_bias, gemm_epilogue_grad_node);
     IR_NODE_LINK_TO(gemm_epilogue_grad_node, act_grad_dx);
     IR_NODE_LINK_TO(gemm_epilogue_grad_node, matmul_grad_dw);
     IR_NODE_LINK_TO(gemm_epilogue_grad_node, ele_grad_dbias);
     IR_NODE_LINK_TO(reserve_space_node, gemm_epilogue_grad_node);
-    GraphSafeRemoveNodes(g,
-                         {ele_add_grad_op,
-                          ele_grad_dx,
-                          matmul_grad_op,
-                          matmul_grad_dx,
-                          act_grad_op});
     VLOG(4) << "\n\t " << subgraph.at(dout)->Name() << " and "
             << ele_grad_bias->Name() << " -> " << ele_add_grad_op->Name()
             << " -> " << ele_grad_dx->Name() << " and "
@@ -470,6 +491,13 @@ ir::Graph *FuseGemmEpiloguePass::FuseLinearActBwd(
             << matmul_grad_dx->Name() << " and " << matmul_grad_w->Name()
             << "\n\t " << matmul_grad_dx->Name() << " -> "
             << act_grad_op->Name() << " -> " << act_grad_dx->Name();
+    GraphSafeRemoveNodes(g,
+                         {ele_add_grad_op,
+                          ele_grad_dx,
+                          matmul_grad_op,
+                          matmul_grad_dx,
+                          act_grad_op});
     found_ele_add_matmul_act_count++;
   };
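Three things change in fuse_gemm_epilogue_pass.cc above. First, GraphSafeRemoveNodes is moved after the VLOG statements in all four fusion handlers, because the log lines call Name() on nodes that the old ordering had already removed from the graph. Second, the fused backward op now receives an op_role_var attribute assembled from the op_role_var lists of the original matmul grad and elementwise_add grad ops via details::GetOpRoleVarsOrEmpty (hence the new multi_devices_helper.h include); this keeps the parameter/gradient pairing that distributed passes rely on. Third, the IR_NODE_LINK_TO(ele_grad_bias, gemm_epilogue_grad_node) edges are dropped, presumably because the bias is not an input of fused_gemm_epilogue_grad. A hedged sketch of how the forwarded attribute could be inspected after the pass runs (main_prog is assumed to be a program processed as in the test below):

for op in main_prog.global_block().ops:
    if op.type == "fused_gemm_epilogue_grad":
        # op_role_var is a flat list of parameter and gradient variable names
        print(op.attr("op_role_var"))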
cpp_pass.py
@@ -71,6 +71,19 @@ class FuseReluDepthwiseConvPass(CPPPassWrapper):
         return PassType.FUSION_OPT
+
+
+@register_pass("fuse_gemm_epilogue")
+class FuseGemmEpiloguePass(CPPPassWrapper):
+    def __init__(self):
+        super().__init__()
+
+    @property
+    def cpp_name(self):
+        return "fuse_gemm_epilogue_pass"
+
+    def _type(self):
+        return PassType.FUSION_OPT
 
 
 @register_pass("fuse_optimizer")
 class FuseOptimizerPass(CPPPassWrapper):
     def __init__(self):
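The new FuseGemmEpiloguePass wrapper in cpp_pass.py registers the C++ fuse_gemm_epilogue_pass under the name "fuse_gemm_epilogue", so it can be applied through the distributed pass machinery. A short usage sketch, mirroring apply_passes in the new test (main_prog and startup_prog are assumed to be existing static programs):

from paddle.distributed.passes import PassManager, new_pass

pass_manager = PassManager([new_pass("fuse_gemm_epilogue")])
pass_manager.apply([main_prog], [startup_prog])
print(pass_manager.names)  # should list 'fuse_gemm_epilogue'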
CMakeLists.txt (distributed passes unit tests)
@@ -24,6 +24,10 @@ if((NOT WITH_GPU)
     "test_auto_parallel_data_parallel_optimization_pass")
 endif()
 
+if(NOT ((WITH_GPU) AND (CUDA_VERSION GREATER_EQUAL 11.6)))
+  list(REMOVE_ITEM TEST_OPS test_dist_fuse_gemm_epilogue_pass)
+endif()
+
 foreach(TEST_OP ${TEST_OPS})
   py_test_modules(${TEST_OP} MODULES ${TEST_OP})
   list(APPEND DIST_TEST_OPS ${TEST_OP})
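The CMake change drops test_dist_fuse_gemm_epilogue_pass from the test list when the build has no GPU or uses CUDA older than 11.6, i.e. when the fused_gemm_epilogue kernels are unavailable. A rough runtime equivalent of that guard (the commit enforces it at configure time instead; paddle.version.cuda() is assumed to return a string such as "11.6"):

import paddle

def cuda_supports_fused_gemm_epilogue():
    # Mirrors the CMake condition: WITH_GPU and CUDA_VERSION >= 11.6.
    if not paddle.is_compiled_with_cuda():
        return False
    major, minor = (int(v) for v in paddle.version.cuda().split(".")[:2])
    return (major, minor) >= (11, 6)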
dist_pass_test_base.py
@@ -105,9 +105,11 @@ class DistPassTestBase(unittest.TestCase):
             if out_var_no_pass is None:
                 self.assertIsNone(out_var_pass)
             else:
+                self.assertEqual(len(out_var_pass), len(out_var_no_pass))
+                for i in range(0, len(out_var_pass)):
                     np.testing.assert_allclose(
-                        out_var_no_pass,
-                        out_var_pass,
+                        out_var_no_pass[i],
+                        out_var_pass[i],
                         rtol=self.rtol,
                         atol=self.atol,
                         equal_nan=self.equal_nan,
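In dist_pass_test_base.py, check_main fetches a list of outputs (parameter gradients plus the loss) from the run with the pass and the run without it; the base class now asserts that both lists have the same length and compares them entry by entry instead of handing the whole lists to assert_allclose at once, which avoids implicitly stacking the fetched arrays and reports mismatches per output. A small illustration of the per-entry comparison (the shapes are made up for the example):

import numpy as np

no_pass = [np.ones((768, 3072)), np.ones((3072,)), np.float32(0.5)]
with_pass = [np.ones((768, 3072)), np.ones((3072,)), np.float32(0.5)]

assert len(no_pass) == len(with_pass)
for a, b in zip(no_pass, with_pass):
    np.testing.assert_allclose(a, b, rtol=1e-3, atol=1e-3)

The new test file, test_dist_fuse_gemm_epilogue_pass.py, follows in full.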
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from dist_pass_test_base import DistPassTestBase
import paddle
import paddle.distributed.fleet as fleet
import paddle.nn as nn
from paddle.distributed.passes import PassManager, new_pass
paddle.enable_static()
np.random.seed(12345)
paddle.seed(12345)
def verify_op_count(op_types, op_name, target_count):
    count = 0
    for op_type in op_types:
        if op_type == op_name:
            count += 1
    return count == target_count
class MultiFCLayer(nn.Layer):
    def __init__(self, hidden, Activation):
        super(MultiFCLayer, self).__init__()
        self.linear1 = paddle.nn.Linear(hidden, 4 * hidden)
        self.linear2 = paddle.nn.Linear(4 * hidden, hidden)
        self.linear3 = paddle.nn.Linear(hidden, hidden)
        self.relu1 = Activation()
        self.relu2 = Activation()
        self.relu3 = Activation()

    def forward(self, x, matmul_y, ele_y):
        output = self.linear1(x)
        output = self.relu1(output)
        output = self.linear2(output)
        output1 = paddle.matmul(output, matmul_y)
        output = self.linear3(output)
        output = self.relu2(output)
        output = paddle.matmul(output, matmul_y)
        output = paddle.add(output, ele_y)
        output = self.relu3(output)
        output = paddle.add(output, output1)
        return output
class TestFuseGemmEpiloguePassReluFP32(DistPassTestBase):
    def init(self):
        self.atol = 1e-3
        self.rtol = 1e-3
        self.activation = nn.ReLU
        self.act_fwd_name = 'relu'
        self.act_bwd_name = 'relu_grad'
        self.batch = 64
        self.seqlen = 128
        self.hidden = 768
        self.precision = 'FP32'  # FP32 or AMP

    def get_model(self, place):
        data = paddle.static.data(
            name="_data", shape=[-1, self.seqlen, self.hidden], dtype='float32'
        )
        matmul_y = paddle.static.data(
            name="_matmul_y",
            shape=[1, self.hidden, self.hidden],
            dtype='float32',
        )
        ele_y = paddle.static.data(
            name="_ele_y",
            shape=[
                self.hidden,
            ],
            dtype='float32',
        )

        model = MultiFCLayer(self.hidden, self.activation)
        out = model(data, matmul_y, ele_y)
        loss = paddle.mean(out)

        optimizer = paddle.optimizer.Adam(learning_rate=1e-3)
        dist_strategy = fleet.DistributedStrategy()
        dist_strategy.fuse_all_reduce_ops = False
        dist_strategy.without_graph_optimization = True
        if self.precision == 'AMP':
            dist_strategy.amp = True
            dist_strategy.amp_configs = {
                "init_loss_scaling": 32768,
                "use_dynamic_loss_scaling": True,
                "custom_white_list": ['gelu'],
            }
        fleet.init(is_collective=True, strategy=dist_strategy)
        optimizer = fleet.distributed_optimizer(optimizer)
        optimizer.minimize(loss)

        rank = paddle.distributed.get_rank()

        def reader():
            for _ in range(10):
                data_arr = (
                    np.random.random(
                        (self.batch, self.seqlen, self.hidden)
                    ).astype("float32")
                    - 0.5
                )
                matmul_y_arr = (
                    np.random.random((1, self.hidden, self.hidden)).astype(
                        "float32"
                    )
                    - 0.5
                )
                ele_y_arr = (
                    np.random.random((self.hidden,)).astype("float32") - 0.5
                )
                yield [data_arr, matmul_y_arr, ele_y_arr]

        main_program = paddle.static.default_main_program()
        startup_program = paddle.static.default_startup_program()

        fetch_list = []
        for p in model.parameters():
            grad_name = p.name + '@GRAD'
            fetch_list.append(grad_name)
        fetch_list.append(loss.name)

        return (
            main_program,
            startup_program,
            [data, matmul_y, ele_y],
            fetch_list,
            reader,
        )

    def apply_passes(self, main_prog, startup_prog):
        pass_manager = PassManager([new_pass("fuse_gemm_epilogue")])
        pass_manager.apply([main_prog], [startup_prog])
        print(pass_manager.names)

        op_type = []
        for op in main_prog.global_block().ops:
            op_type.append(op.type)
        print(op_type)
        self.assertTrue(verify_op_count(op_type, "fused_gemm_epilogue", 3))
        self.assertTrue(verify_op_count(op_type, "fused_gemm_epilogue_grad", 3))
        self.assertTrue(verify_op_count(op_type, self.act_fwd_name, 1))
        self.assertTrue(verify_op_count(op_type, self.act_bwd_name, 2))

    def test_fuse_gemm_epilogue(self):
        self.check_main()
class TestFuseGemmEpiloguePassReluFP16(TestFuseGemmEpiloguePassReluFP32):
    def init(self):
        self.atol = 1e-3
        self.rtol = 1e-3
        self.activation = nn.ReLU
        self.act_fwd_name = 'relu'
        self.act_bwd_name = 'relu_grad'
        self.batch = 64
        self.seqlen = 128
        self.hidden = 768
        self.precision = 'AMP'  # FP32 or AMP


class TestFuseGemmEpiloguePassGeluFP32(TestFuseGemmEpiloguePassReluFP32):
    def init(self):
        self.atol = 1e-3
        self.rtol = 1e-3
        self.activation = nn.GELU
        self.act_fwd_name = 'gelu'
        self.act_bwd_name = 'gelu_grad'
        self.batch = 64
        self.seqlen = 128
        self.hidden = 768
        self.precision = 'FP32'  # FP32 or AMP


class TestFuseGemmEpiloguePassGeluFP16(TestFuseGemmEpiloguePassReluFP32):
    def init(self):
        self.atol = 5e-3
        self.rtol = 1e-3
        self.activation = nn.GELU
        self.act_fwd_name = 'gelu'
        self.act_bwd_name = 'gelu_grad'
        self.batch = 64
        self.seqlen = 128
        self.hidden = 768
        self.precision = 'AMP'  # FP32 or AMP
if __name__ == "__main__":
    unittest.main()
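The AMP and GELU variants only override init(): they reuse the same model and pass pipeline while switching the activation, the tolerances, and whether fleet's mixed-precision (AMP) strategy is enabled. A hedged sketch of running one of the cases directly, outside the CMake-registered test runner:

import unittest

suite = unittest.TestLoader().loadTestsFromName(
    "test_dist_fuse_gemm_epilogue_pass.TestFuseGemmEpiloguePassGeluFP16"
)
unittest.TextTestRunner().run(suite)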