From 39a9abaa80883837c19e2daea5753ab63bef28cf Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Tue, 7 Mar 2023 19:34:43 +0800 Subject: [PATCH] [XPU] support shared weight; delete isolated node (#51108) --- paddle/fluid/framework/ir/CMakeLists.txt | 11 +- .../framework/ir/delete_op_device_pass.cc | 8 +- .../ir/delete_op_device_pass_test.cc | 3 +- paddle/fluid/framework/ir/pass.cc | 1 + .../ir/xpu/delete_isolated_node_pass.cc | 203 +++++++++ .../ir/xpu/delete_isolated_node_pass_test.cc | 181 ++++++++ .../framework/ir/xpu/fc_xpu_fuse_pass.cc | 75 ++-- .../ir/xpu/multi_encoder_xpu_fuse_pass.cc | 391 +++++++++++------- paddle/fluid/framework/ir/xpu/pass_utils.cc | 167 ++++++++ paddle/fluid/framework/ir/xpu/pass_utils.h | 21 + paddle/fluid/framework/ir/xpu/quant_utils.cc | 12 +- paddle/fluid/framework/ir/xpu/quant_utils.h | 6 +- .../inference/api/paddle_pass_builder.cc | 1 + .../test_xpu_link_xpu_op_max_pass.py | 6 +- 14 files changed, 863 insertions(+), 223 deletions(-) create mode 100644 paddle/fluid/framework/ir/xpu/delete_isolated_node_pass.cc create mode 100644 paddle/fluid/framework/ir/xpu/delete_isolated_node_pass_test.cc diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index a27780b0254..5dd1b4c6193 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -220,7 +220,7 @@ if(WITH_XPU) cc_library( xpu_pass_utils SRCS xpu/pass_utils.cc - DEPS pass) + DEPS pass xpu_quant_utils) set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils) pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) @@ -232,6 +232,8 @@ if(WITH_XPU) pass_library(generate_sequence_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(link_xpu_op_max_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) + pass_library(delete_isolated_node_pass inference DIR xpu DEPS + ${XPU_PASS_DEPS}) endif() cc_library( @@ -484,3 +486,10 @@ if(WITH_MKLDNN) SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass) endif() + +if(WITH_XPU) + cc_test( + test_delete_isolated_node_pass + SRCS xpu/delete_isolated_node_pass_test.cc + DEPS delete_isolated_node_pass) +endif() diff --git a/paddle/fluid/framework/ir/delete_op_device_pass.cc b/paddle/fluid/framework/ir/delete_op_device_pass.cc index dfd174a442a..cc5523abd8e 100644 --- a/paddle/fluid/framework/ir/delete_op_device_pass.cc +++ b/paddle/fluid/framework/ir/delete_op_device_pass.cc @@ -39,14 +39,14 @@ class DeleteOpDevicePass : public Pass { void DeleteOpDevicePass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); - int found_subgraph_count = 0; + int delete_counts = 0; for (auto* node : graph->Nodes()) { if (!node->IsOp() || !node->Op()->HasAttr("op_device")) continue; node->Op()->RemoveAttr("op_device"); - found_subgraph_count++; + delete_counts++; } - if (found_subgraph_count > 0) { - LOG(INFO) << "--- detected " << found_subgraph_count << " subgraphs"; + if (delete_counts > 0) { + LOG(INFO) << "--- delete " << delete_counts << " op_device attr"; } } diff --git a/paddle/fluid/framework/ir/delete_op_device_pass_test.cc b/paddle/fluid/framework/ir/delete_op_device_pass_test.cc index c88c3f4fa6a..2b0ac27782b 100644 --- a/paddle/fluid/framework/ir/delete_op_device_pass_test.cc +++ b/paddle/fluid/framework/ir/delete_op_device_pass_test.cc @@ -13,8 +13,7 @@ // limitations under the License. #include - -#include "paddle/fluid/framework/ir/delete_dropout_op_pass.h" +#include "paddle/fluid/framework/ir/pass.h" #include "paddle/fluid/framework/ir/pass_tester_helper.h" namespace paddle { diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index df15fd6d516..c064040cf42 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -49,6 +49,7 @@ static const std::vector support_subgraph_passes = { "fuse_multi_transformer_layer_pass", "delete_quant_dequant_linear_op_pass", "delete_weight_dequant_linear_op_pass", + "fc_xpu_fuse_pass", "delete_op_device_pass"}; Graph *Pass::Apply(Graph *graph) const { diff --git a/paddle/fluid/framework/ir/xpu/delete_isolated_node_pass.cc b/paddle/fluid/framework/ir/xpu/delete_isolated_node_pass.cc new file mode 100644 index 00000000000..41a822e3e2b --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/delete_isolated_node_pass.cc @@ -0,0 +1,203 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/xpu/pass_utils.h" +#include "paddle/fluid/framework/scope.h" + +namespace phi { +class DenseTensor; +} // namespace phi + +namespace paddle { +namespace framework { +class Scope; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace framework { +namespace ir { + +class DeleteIsolatedNodePass : public Pass { + protected: + void ApplyImpl(Graph* graph) const override; + + private: + void CollectReservedPersistableNodeNames( + Graph* graph, + std::unordered_set* reserved_persistable_node_names) const; + + int RemoveIsolatedNodes( + Graph* graph, + const std::unordered_set& reserved_persistable_node_names, + std::unordered_set* delete_node_names) const; + + int UpdateControlFlowOp( + Graph* graph, + const std::map& block_id_graph_map, + const std::unordered_set& delete_node_names) const; + + const std::map control_flow_op_input_map_{ + {"while", "X"}, + {"conditional_block", "Input"}, + }; +}; + +void DeleteIsolatedNodePass::ApplyImpl(Graph* graph) const { + PADDLE_ENFORCE_NOT_NULL( + graph, platform::errors::PreconditionNotMet("graph should not be null.")); + PADDLE_ENFORCE(graph->IsMainGraph(), + platform::errors::PreconditionNotMet( + "Pass(apply in main graph) will delete isolated nodes in " + "all subgraphs. Do not apply pass in subgraph.")); + + std::unordered_set reserved_persistable_node_names; + for (size_t i = 0; i < graph->SubGraphsSize(); i++) { + CollectReservedPersistableNodeNames(graph->GetSubGraph(i), + &reserved_persistable_node_names); + } + + int delete_counts = 0; + std::unordered_set delete_node_names; + for (size_t i = 0; i < graph->SubGraphsSize(); i++) { + delete_counts += RemoveIsolatedNodes(graph->GetSubGraph(i), + reserved_persistable_node_names, + &delete_node_names); + } + if (delete_counts > 0) { + LOG(INFO) << "--- delete " << delete_counts << " isolated nodes"; + } + + std::map block_id_graph_map; + for (size_t i = 0; i < graph->SubGraphsSize(); i++) { + auto* sub_graph = graph->GetSubGraph(i); + for (auto* node : sub_graph->Nodes()) { + if (node->IsVar()) { + block_id_graph_map[node->GetVarNodeBlockId()] = sub_graph; + break; + } + } + } + int update_counts = 0; + for (size_t i = 0; i < graph->SubGraphsSize(); i++) { + update_counts += UpdateControlFlowOp( + graph->GetSubGraph(i), block_id_graph_map, delete_node_names); + } + if (update_counts > 0) { + LOG(INFO) << "--- update " << update_counts << " control flow ops"; + } +} + +void DeleteIsolatedNodePass::CollectReservedPersistableNodeNames( + Graph* graph, + std::unordered_set* reserved_persistable_node_names) const { + for (auto* node : graph->Nodes()) { + if (!node->IsVar() || !node->Var()->Persistable()) continue; + for (auto* out_node : node->outputs) { + auto op_type = out_node->Op()->Type(); + if (control_flow_op_input_map_.count(op_type) == 0) { + reserved_persistable_node_names->insert(node->Var()->Name()); + break; + } + } + } +} + +int DeleteIsolatedNodePass::RemoveIsolatedNodes( + Graph* graph, + const std::unordered_set& reserved_persistable_node_names, + std::unordered_set* delete_node_names) const { + BlockDesc* block = nullptr; + for (auto* node : graph->Nodes()) { + if (node->IsOp()) { + block = node->Op()->Block(); + } + } + Scope& scope = graph->Get("__param_scope__"); + + // If graph has nodes to delete: + // 1. Clear var_desc in block + // 2. Clear tensor in variable + // 3. Clear variable in scope + int delete_node_counts = 0; + std::unordered_set delete_nodes; + const std::unordered_set nodes = graph->Nodes(); + for (auto* node : nodes) { + if (!node->IsVar() || !node->Var()->Persistable()) continue; + auto name = node->Var()->Name(); + if (reserved_persistable_node_names.count(name) > 0) continue; + delete_nodes.insert(node); + delete_node_names->insert(node->Name()); + block->RemoveVar(name); + auto* var = scope.FindVar(name); + if (var != nullptr) { + var->Clear(); + scope.EraseVars({name}); + } + delete_node_counts++; + } + + GraphSafeRemoveNodes(graph, delete_nodes); + return delete_node_counts; +} + +int DeleteIsolatedNodePass::UpdateControlFlowOp( + Graph* graph, + const std::map& block_id_graph_map, + const std::unordered_set& delete_node_names) const { + int update_counts = 0; + for (auto* node : graph->Nodes()) { + if (!node->IsOp()) continue; + auto op_type = node->Op()->Type(); + if (control_flow_op_input_map_.count(op_type) == 0) continue; + + auto in_arg_name = control_flow_op_input_map_.at(op_type); + auto in_name = node->Op()->Input(in_arg_name); + std::unordered_set in_names_set(in_name.begin(), + in_name.end()); + for (auto delete_node_name : delete_node_names) { + if (in_names_set.count(delete_node_name) > 0) { + in_names_set.erase(delete_node_name); + } + } + + auto* sub_block = PADDLE_GET_CONST(framework::BlockDesc*, + node->Op()->GetAttr("sub_block")); + auto* sub_graph = block_id_graph_map.at(sub_block->ID()); + std::unordered_set sub_persistable_node_names; + CollectReservedPersistableNodeNames(sub_graph, &sub_persistable_node_names); + for (auto sub_name : sub_persistable_node_names) { + if (in_names_set.count(sub_name) > 0) continue; + auto* in_node = FindNodeWithName(graph, sub_name); + if (in_node == nullptr) continue; + in_names_set.insert(sub_name); + IR_NODE_LINK_TO(in_node, node); + } + std::vector new_in_names(in_names_set.begin(), + in_names_set.end()); + node->Op()->SetInput(in_arg_name, new_in_names); + update_counts++; + } + return update_counts; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(delete_isolated_node_pass, + paddle::framework::ir::DeleteIsolatedNodePass); diff --git a/paddle/fluid/framework/ir/xpu/delete_isolated_node_pass_test.cc b/paddle/fluid/framework/ir/xpu/delete_isolated_node_pass_test.cc new file mode 100644 index 00000000000..1af69930d2e --- /dev/null +++ b/paddle/fluid/framework/ir/xpu/delete_isolated_node_pass_test.cc @@ -0,0 +1,181 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +VarDesc* Data(paddle::framework::BlockDesc* block, + std::string name, + std::vector shape = {}, + bool is_persistable = false, + proto::VarType::Type data_type = proto::VarType::FP32) { + auto* var = block->Var(name); + var->SetType(proto::VarType::LOD_TENSOR); + var->SetDataType(data_type); + var->SetShape(shape); + var->SetPersistable(is_persistable); + return var; +} + +void AddVarToScope(Scope* param_scope, + const std::string& name, + const DDim& dims) { + auto* tensor = param_scope->Var(name)->GetMutable(); + tensor->Resize(dims); + auto* cpu_ctx = static_cast( + platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); + auto* data = cpu_ctx->Alloc(tensor); + int64_t numel = tensor->numel(); + for (int64_t i = 0; i < numel; ++i) { + data[i] = 1; + } +} + +Scope* CreateParamScope() { + auto param_scope = new Scope(); + AddVarToScope(param_scope, "matmul0_w", {128, 128}); + return param_scope; +} + +int WeightNodeNum(ir::Graph* graph) { + int num = 0; + for (auto node : graph->Nodes()) { + if (node->IsVar() && node->Var()->Persistable()) { + num++; + } + } + return num; +} + +int WeightTensorNum(Scope* scope) { + int num = 0; + auto vars = scope->LocalVars(); + for (auto* var : vars) { + if (var->Get().numel() > 0) { + num++; + } + } + return num; +} + +TEST(delete_isolated_node_pass, basic) { + paddle::framework::ProgramDesc program; + auto* block0 = program.MutableBlock(0); + auto* block1 = program.AppendBlock(*block0); + + auto* matmul0_x = Data(block0, "matmul0_x", {1, 128}); + auto* matmul0_w = Data(block0, "matmul0_w", {128, 128}, true); + auto* matmul0_out = Data(block0, "matmul0_out", {1, 128}); + OpDesc* matmul_op = block0->AppendOp(); + matmul_op->SetType("matmul_v2"); + matmul_op->SetInput("X", {matmul0_x->Name()}); + matmul_op->SetInput("Y", {matmul0_w->Name()}); + matmul_op->SetAttr("trans_x", false); + matmul_op->SetAttr("trans_y", false); + matmul_op->SetOutput("Out", {matmul0_out->Name()}); + + auto* while_out = Data(block0, "while_out", {1, 128}); + auto* while_step_scopes = Data(block0, "while_step_scopes"); + auto* while_cond = Data(block0, "while_cond"); + OpDesc* while_op = block0->AppendOp(); + while_op->SetType("while"); + while_op->SetInput("X", {matmul0_w->Name(), matmul0_out->Name()}); + while_op->SetInput("Condition", {while_cond->Name()}); + while_op->SetOutput("Out", {while_out->Name()}); + while_op->SetOutput("StepScopes", {while_step_scopes->Name()}); + while_op->SetAttr("sub_block", {block1}); + while_op->SetAttr("is_test", true); + + auto* matmul1_x = Data(block1, matmul0_out->Name(), matmul0_out->GetShape()); + auto* matmul1_w = + Data(block1, matmul0_w->Name(), matmul0_w->GetShape(), true); + auto* matmul1_out = Data(block1, "matmul1_out", {1, 128}); + OpDesc* matmul1_op = block1->AppendOp(); + matmul1_op->SetType("matmul_v2"); + matmul1_op->SetInput("X", {matmul1_x->Name()}); + matmul1_op->SetInput("Y", {matmul1_w->Name()}); + matmul1_op->SetAttr("trans_x", false); + matmul1_op->SetAttr("trans_y", false); + matmul1_op->SetOutput("Out", {matmul1_out->Name()}); + + std::unique_ptr graph(new ir::Graph(program)); + auto* scope = CreateParamScope(); + graph->Set("__param_scope__", scope); + auto pass0 = PassRegistry::Instance().Get("fc_xpu_fuse_pass"); + pass0->Apply(graph.get()); + pass0->Apply(graph->GetSubGraph(1)); + int weight_node_num = + WeightNodeNum(graph.get()) + WeightNodeNum(graph->GetSubGraph(1)); + PADDLE_ENFORCE_EQ(weight_node_num, + 6, + platform::errors::PreconditionNotMet( + "Graph should have 6 weight node after " + "fc_xpu_fuse_pass, but actually has %d.", + weight_node_num)); + + auto pass1 = PassRegistry::Instance().Get("delete_isolated_node_pass"); + pass1->Apply(graph.get()); + weight_node_num = + WeightNodeNum(graph.get()) + WeightNodeNum(graph->GetSubGraph(1)); + PADDLE_ENFORCE_EQ(weight_node_num, + 4, + platform::errors::PreconditionNotMet( + "Graph should have 4 weight node after " + "delete_isolated_node_pass, but actually has %d.", + weight_node_num)); + int weight_tensor_num = WeightTensorNum(scope); + PADDLE_ENFORCE_EQ(weight_tensor_num, + 2, + platform::errors::PreconditionNotMet( + "Scope should have 2 weight tensor after " + "delete_isolated_node_pass, but actually has %d.", + weight_tensor_num)); + + for (auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op()->Type() == "while") { + auto while_in_names = node->Op()->Inputs().at("X"); + PADDLE_ENFORCE_EQ(while_in_names.size(), + 3, + platform::errors::PreconditionNotMet( + "While op should have 3 input after " + "delete_isolated_node_pass, but actually has %d.", + while_in_names.size())); + } + } + + Scope& scope0 = graph->Get("__param_scope__"); + Scope& scope1 = + graph->GetSubGraph(1)->Get("__param_scope__"); + std::vector shared_weight_names{matmul0_w->Name() + "_int16", + matmul0_w->Name() + "_max"}; + for (auto name : shared_weight_names) { + auto* var0 = scope0.FindVar(name); + auto* var1 = scope1.FindVar(name); + PADDLE_ENFORCE( + var0 == var1, + platform::errors::PreconditionNotMet( + "Variables with the same name in two scopes is different.")); + } +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +USE_PASS(delete_isolated_node_pass); diff --git a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc index c758d45622f..e37eaceb4f7 100644 --- a/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/fc_xpu_fuse_pass.cc @@ -76,7 +76,6 @@ FcXPUPattern::FcXPUPattern(PDPattern* pattern, auto* mul_w = pattern->NewNode(mul_w_repr()) ->assert_is_op_input(mul_type_, "Y") ->assert_is_persistable_var() - ->assert_has_n_outputs(1) ->assert_more([](Node* node) { return node->Var()->GetShape().size() == 2; }); @@ -169,10 +168,10 @@ class FcXPUFusePass : public FusePassBase { void ApplyImpl(ir::Graph* graph) const override; private: - void ApplyImpl(ir::Graph* graph, - const std::string& mul_type, - bool with_bias, - const std::string& act_type) const; + int ApplyImpl(ir::Graph* graph, + const std::string& mul_type, + bool with_bias, + const std::string& act_type) const; const std::string name_scope_{"fc_xpu_fuse_pass"}; }; @@ -181,6 +180,8 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const { PADDLE_ENFORCE_NOT_NULL( graph, platform::errors::PreconditionNotMet("graph should not be null.")); Init(name_scope_, graph); + + int found_subgraph_count = 0; for (auto mul_type : {"mul", "matmul", "matmul_v2"}) { for (auto with_bias : {true, false}) { for (auto act_type : { @@ -189,16 +190,17 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const { "tanh", "", }) { - ApplyImpl(graph, mul_type, with_bias, act_type); + found_subgraph_count += ApplyImpl(graph, mul_type, with_bias, act_type); } } } + AddStatis(found_subgraph_count); } -void FcXPUFusePass::ApplyImpl(ir::Graph* graph, - const std::string& mul_type, - bool with_bias, - const std::string& act_type) const { +int FcXPUFusePass::ApplyImpl(ir::Graph* graph, + const std::string& mul_type, + bool with_bias, + const std::string& act_type) const { GraphPatternDetector gpd; patterns::FcXPUPattern pattern( gpd.mutable_pattern(), name_scope_, mul_type, with_bias, act_type); @@ -219,37 +221,20 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph, auto* block = mul->Op()->Block(); auto* scope = param_scope(); - auto mul_w_name = mul_w->Name(); - auto mul_w_tensor = - scope->FindVar(mul_w_name)->GetMutable(); - // 1. Transform weight to int16/int31 - // 2. Avoid transform repeatly, because weight may be shared with other ops. - // TODO(zhupengyang): support int31 - std::string mul_w_max_name = mul_w_name + "_max"; - Node* mul_w_max = nullptr; - if (mul_w_tensor->dtype() != phi::DataType::INT16) { - // Create weight_max node - VarDesc mul_w_max_desc(mul_w_max_name); - mul_w_max_desc.SetPersistable(true); - mul_w_max = graph->CreateVarNode(&mul_w_max_desc); - // Create weight_max var/tensor - auto mul_w_max_var = block->Var(mul_w_max_name); - mul_w_max_var->SetPersistable(true); - auto mul_w_max_tensor = - scope->Var(mul_w_max_name)->GetMutable(); - bool transpose_w = false; - if (mul_type == "matmul") { - transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("transpose_Y")); - } else if (mul_type == "matmul_v2") { - transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("trans_y")); - } - QuantWeight(mul_w_tensor, mul_w_max_tensor, !transpose_w); + bool transpose_w = false; + if (mul_type == "matmul") { + transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("transpose_Y")); + } else if (mul_type == "matmul_v2") { + transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("trans_y")); } + Node* mul_w_int16 = nullptr; + Node* mul_w_max = nullptr; + PrepareWeight( + graph, scope, block, mul_w, &mul_w_int16, &mul_w_max, !transpose_w); + Node* bias_fp32 = nullptr; if (bias != nullptr) { - auto* bias_tensor = - scope->Var(bias->Name())->GetMutable(); - CastToFp32(bias_tensor); + PrepareBias(graph, scope, block, bias, &bias_fp32); } std::string fc_out_name; @@ -268,10 +253,10 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph, framework::OpDesc fc_xpu_op_desc(block); fc_xpu_op_desc.SetType("fc_xpu"); fc_xpu_op_desc.SetInput("x", {mul_x->Name()}); - fc_xpu_op_desc.SetInput("w", {mul_w->Name()}); - fc_xpu_op_desc.SetInput("w_max", {mul_w_max_name}); - if (bias) { - fc_xpu_op_desc.SetInput("bias", {bias->Name()}); + fc_xpu_op_desc.SetInput("w", {mul_w_int16->Name()}); + fc_xpu_op_desc.SetInput("w_max", {mul_w_max->Name()}); + if (bias_fp32) { + fc_xpu_op_desc.SetInput("bias", {bias_fp32->Name()}); } fc_xpu_op_desc.SetAttr( "in_num_col_dims", @@ -306,9 +291,9 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph, fc_xpu_op_desc.SetOutput("out_max", {fc_out_max_name}); auto* fc_xpu = graph->CreateOpNode(&fc_xpu_op_desc); IR_NODE_LINK_TO(mul_x, fc_xpu); - IR_NODE_LINK_TO(mul_w, fc_xpu); + IR_NODE_LINK_TO(mul_w_int16, fc_xpu); IR_NODE_LINK_TO(mul_w_max, fc_xpu); - SAFE_IR_NODE_LINK_TO(bias, fc_xpu); + SAFE_IR_NODE_LINK_TO(bias_fp32, fc_xpu); if (act_out) { IR_NODE_LINK_TO(fc_xpu, act_out); } else if (add_out) { @@ -334,7 +319,7 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph, }; gpd(graph, handler); - AddStatis(found_subgraph_count); + return found_subgraph_count; } } // namespace ir diff --git a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc index 1c31db9810b..1053e20150a 100644 --- a/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc +++ b/paddle/fluid/framework/ir/xpu/multi_encoder_xpu_fuse_pass.cc @@ -625,16 +625,24 @@ class MultiEncoderXPUFusePass : public FusePassBase { // 2. Concat q_w, k_w, v_w // 3. Generate qkv_w_max tensor // 4. Quant qkv_w to int16 - void PrepareQKVWeight(const phi::DenseTensor& q_w, - const phi::DenseTensor& k_w, - const phi::DenseTensor& v_w, - phi::DenseTensor* qkv_w, - phi::DenseTensor* qkv_w_max) const; - - void ConcatQKVBias(const phi::DenseTensor& q_bias, - const phi::DenseTensor& k_bias, - const phi::DenseTensor& v_bias, - phi::DenseTensor* qkv_bias) const; + void PrepareQKVWeight(Graph* graph, + Scope* scope, + BlockDesc* block, + Node* q_w, + Node* k_w, + Node* v_w, + Node** qkv_w, + Node** qkv_w_max) const; + + // 1. Cast bias to fp32 + // 2. Concat q/k/v bias + void PrepareQKVBias(Graph* graph, + Scope* scope, + BlockDesc* block, + Node* q_bias, + Node* k_bias, + Node* v_bias, + Node** qkv_bias) const; const std::string name_scope_{"multi_encoder_xpu_fuse_pass"}; }; @@ -685,55 +693,160 @@ void MultiEncoderXPUFusePass::ApplyImpl(ir::Graph* graph) const { AddStatis(cast_mask_counts); } -void MultiEncoderXPUFusePass::PrepareQKVWeight( - const phi::DenseTensor& q_w, - const phi::DenseTensor& k_w, - const phi::DenseTensor& v_w, - phi::DenseTensor* qkv_w, - phi::DenseTensor* qkv_w_max) const { - // Transpose - phi::DenseTensor q_w_t; - phi::DenseTensor k_w_t; - phi::DenseTensor v_w_t; - Assign(q_w, &q_w_t); - Assign(k_w, &k_w_t); - Assign(v_w, &v_w_t); - Transpose2D(&q_w_t); - Transpose2D(&k_w_t); - Transpose2D(&v_w_t); - - // Concat - qkv_w->Resize(DDim( - {q_w_t.dims()[0] + k_w_t.dims()[0] + v_w_t.dims()[0], q_w_t.dims()[1]})); - qkv_w->set_type(q_w.type()); - auto* dev_ctx = static_cast( +void MultiEncoderXPUFusePass::PrepareQKVWeight(Graph* graph, + Scope* scope, + BlockDesc* block, + Node* q_w, + Node* k_w, + Node* v_w, + Node** qkv_w_int16, + Node** qkv_w_max) const { + phi::DenseTensor q_w_fp32_t; + phi::DenseTensor k_w_fp32_t; + phi::DenseTensor v_w_fp32_t; + Assign(scope->Var(q_w->Name())->Get(), &q_w_fp32_t); + Assign(scope->Var(k_w->Name())->Get(), &k_w_fp32_t); + Assign(scope->Var(v_w->Name())->Get(), &v_w_fp32_t); + + CastToFp32(&q_w_fp32_t); + CastToFp32(&k_w_fp32_t); + CastToFp32(&v_w_fp32_t); + + Transpose2D(&q_w_fp32_t); + Transpose2D(&k_w_fp32_t); + Transpose2D(&v_w_fp32_t); + + phi::DenseTensor qkv_w_int16_t; + phi::DenseTensor qkv_w_max_t; + qkv_w_int16_t.Resize( + DDim({q_w_fp32_t.dims()[0] + k_w_fp32_t.dims()[0] + v_w_fp32_t.dims()[0], + q_w_fp32_t.dims()[1]})); + qkv_w_int16_t.set_type(q_w_fp32_t.type()); + auto* cpu_ctx = static_cast( platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); - std::vector in_tensors{&q_w_t, &k_w_t, &v_w_t}; - if (q_w.type() == phi::DataType::FLOAT16) { - phi::ConcatKernel(*dev_ctx, in_tensors, 0, qkv_w); + std::vector in_tensors{ + &q_w_fp32_t, &k_w_fp32_t, &v_w_fp32_t}; + phi::ConcatKernel(*cpu_ctx, in_tensors, 0, &qkv_w_int16_t); + + PrepareWeight(&qkv_w_int16_t, &qkv_w_max_t, false); + size_t qkv_w_int16_hash = HashTensor(qkv_w_int16_t); + size_t qkv_w_max_hash = HashTensor(qkv_w_max_t); + std::string qkv_w_int16_name = std::to_string(qkv_w_int16_hash); + std::string qkv_w_max_name = std::to_string(qkv_w_max_hash); + *qkv_w_int16 = FindNodeWithName(graph, qkv_w_int16_name); + if (*qkv_w_int16 == nullptr) { + // Create qkv_w_int16 node + // Update qkv_w_int16 var_desc in block + VarDesc qkv_w_int16_desc(qkv_w_int16_name); + qkv_w_int16_desc.SetPersistable(true); + qkv_w_int16_desc.SetShape(vectorize(qkv_w_int16_t.dims())); + qkv_w_int16_desc.SetDataType( + framework::TransToProtoVarType(qkv_w_int16_t.dtype())); + *qkv_w_int16 = graph->CreateVarNode(&qkv_w_int16_desc); + auto* block_qkv_w_int16_desc = block->Var(qkv_w_int16_name); + block_qkv_w_int16_desc->SetPersistable(qkv_w_int16_desc.Persistable()); + block_qkv_w_int16_desc->SetShape(qkv_w_int16_desc.GetShape()); + block_qkv_w_int16_desc->SetDataType(qkv_w_int16_desc.GetDataType()); + // Create qkv_w_max node + // Update qkv_w_max var_desc in block + VarDesc qkv_w_max_desc(qkv_w_max_name); + qkv_w_max_desc.SetPersistable(true); + qkv_w_max_desc.SetShape(vectorize(qkv_w_max_t.dims())); + qkv_w_max_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32); + *qkv_w_max = graph->CreateVarNode(&qkv_w_max_desc); + auto* block_qkv_w_max_desc = block->Var(qkv_w_max_name); + block_qkv_w_max_desc->SetPersistable(qkv_w_max_desc.Persistable()); + block_qkv_w_max_desc->SetShape(qkv_w_max_desc.GetShape()); + block_qkv_w_max_desc->SetDataType(qkv_w_max_desc.GetDataType()); + + // Find qkv_w_int16/qkv_w_max variable in scope + auto* qkv_w_int16_var = scope->FindVar(qkv_w_int16_name); + if (qkv_w_int16_var == nullptr) { + // Create qkv_w_int16/qkv_w_max variable/tensor + Assign(qkv_w_int16_t, + scope->Var(qkv_w_int16_name)->GetMutable()); + Assign(qkv_w_max_t, + scope->Var(qkv_w_max_name)->GetMutable()); + } else { + // Share the same variable + PADDLE_ENFORCE_NOT_NULL( + scope->FindVar(qkv_w_max_name), + platform::errors::Fatal( + "qkv_w_max(%s) variable should not be nullptr if qkv_w_int16(%s) " + "variable is exist.", + qkv_w_max_name, + qkv_w_int16_name)); + } } else { - phi::ConcatKernel(*dev_ctx, in_tensors, 0, qkv_w); + *qkv_w_max = FindNodeWithName(graph, qkv_w_max_name); + PADDLE_ENFORCE_NOT_NULL( + *qkv_w_max, + platform::errors::Fatal( + "qkv_w_max(%s) variable should not be nullptr if qkv_w_int16(%s) " + "variable is exist.", + qkv_w_max_name, + qkv_w_int16_name)); } - - // Quant to int16 - QuantWeight(qkv_w, qkv_w_max, false); } -void MultiEncoderXPUFusePass::ConcatQKVBias(const phi::DenseTensor& q_bias, - const phi::DenseTensor& k_bias, - const phi::DenseTensor& v_bias, - phi::DenseTensor* qkv_bias) const { - int q_bias_size = q_bias.numel(); - qkv_bias->Resize(DDim({q_bias_size * 3})); - qkv_bias->set_type(q_bias.type()); - auto* dev_ctx = static_cast( +void MultiEncoderXPUFusePass::PrepareQKVBias(Graph* graph, + Scope* scope, + BlockDesc* block, + Node* q_bias, + Node* k_bias, + Node* v_bias, + Node** qkv_bias) const { + auto* q_bias_tensor = + scope->Var(q_bias->Name())->GetMutable(); + auto* k_bias_tensor = + scope->Var(k_bias->Name())->GetMutable(); + auto* v_bias_tensor = + scope->Var(v_bias->Name())->GetMutable(); + phi::DenseTensor q_bias_fp32_tensor; + phi::DenseTensor k_bias_fp32_tensor; + phi::DenseTensor v_bias_fp32_tensor; + CastToFp32(q_bias_tensor, &q_bias_fp32_tensor); + CastToFp32(k_bias_tensor, &k_bias_fp32_tensor); + CastToFp32(v_bias_tensor, &v_bias_fp32_tensor); + + phi::DenseTensor qkv_bias_tensor; + int q_bias_fp32_size = q_bias_fp32_tensor.numel(); + qkv_bias_tensor.Resize(DDim({q_bias_fp32_size * 3})); + qkv_bias_tensor.set_type(phi::DataType::FLOAT32); + auto* cpu_ctx = static_cast( platform::DeviceContextPool::Instance().Get(phi::CPUPlace())); - auto* qkv_bias_data = dev_ctx->Alloc(qkv_bias); - memcpy(qkv_bias_data, q_bias.data(), q_bias_size * sizeof(float)); - qkv_bias_data += q_bias_size; - memcpy(qkv_bias_data, k_bias.data(), q_bias_size * sizeof(float)); - qkv_bias_data += q_bias_size; - memcpy(qkv_bias_data, v_bias.data(), q_bias_size * sizeof(float)); + auto* qkv_bias_data = cpu_ctx->Alloc(&qkv_bias_tensor); + memcpy(qkv_bias_data, + q_bias_fp32_tensor.data(), + q_bias_fp32_size * sizeof(float)); + qkv_bias_data += q_bias_fp32_size; + memcpy(qkv_bias_data, + k_bias_fp32_tensor.data(), + q_bias_fp32_size * sizeof(float)); + qkv_bias_data += q_bias_fp32_size; + memcpy(qkv_bias_data, + v_bias_fp32_tensor.data(), + q_bias_fp32_size * sizeof(float)); + + size_t qkv_bias_hash = HashTensor(qkv_bias_tensor); + std::string qkv_bias_name = std::to_string(qkv_bias_hash); + *qkv_bias = FindNodeWithName(graph, qkv_bias_name); + if (*qkv_bias == nullptr) { + // Create qkv_bias node + // Update qkv_bias var_desc in block + VarDesc qkv_bias_desc(qkv_bias_name); + qkv_bias_desc.SetPersistable(true); + qkv_bias_desc.SetShape(vectorize(qkv_bias_tensor.dims())); + qkv_bias_desc.SetDataType( + framework::TransToProtoVarType(qkv_bias_tensor.dtype())); + *qkv_bias = graph->CreateVarNode(&qkv_bias_desc); + auto* block_qkv_bias_desc = block->Var(qkv_bias_name); + block_qkv_bias_desc->SetPersistable(qkv_bias_desc.Persistable()); + block_qkv_bias_desc->SetShape(qkv_bias_desc.GetShape()); + block_qkv_bias_desc->SetDataType(qkv_bias_desc.GetDataType()); + Assign(qkv_bias_tensor, + scope->Var(qkv_bias_name)->GetMutable()); + } } int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( @@ -856,109 +969,67 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( scope->FindVar(q_matmul_w->Name())->Get().dtype() == phi::DataType::FLOAT16; - // Prepare q,k,v weight - std::string q_w_name = q_matmul_w->Name(); - std::string k_w_name = k_matmul_w->Name(); - std::string v_w_name = v_matmul_w->Name(); - std::string qkv_w_name = q_w_name + "_" + k_w_name + "_" + v_w_name; - VarDesc qkv_w_desc(qkv_w_name); - qkv_w_desc.SetPersistable(true); - auto* qkv_w = graph->CreateVarNode(&qkv_w_desc); - auto* qkv_w_var = block->Var(qkv_w_name); - qkv_w_var->SetPersistable(true); - std::string qkv_w_max_name = qkv_w_name + "_max"; - VarDesc qkv_w_max_desc(qkv_w_max_name); - qkv_w_max_desc.SetPersistable(true); - auto* qkv_w_max = graph->CreateVarNode(&qkv_w_max_desc); - auto* qkv_w_max_var = block->Var(qkv_w_max_name); - qkv_w_max_var->SetPersistable(true); - PrepareQKVWeight( - scope->FindVar(q_w_name)->Get(), - scope->FindVar(k_w_name)->Get(), - scope->FindVar(v_w_name)->Get(), - scope->Var(qkv_w_name)->GetMutable(), - scope->Var(qkv_w_max_name)->GetMutable()); - - // Prepare qkv_matmul_1_w, qkv_matmul_2_w, qkv_matmul_3_w -#define PREPARE_QKV_MATMUL_W(idx_) \ - std::string qkv_matmul_##idx_##_w_name = qkv_matmul_##idx_##_w->Name(); \ - std::string qkv_matmul_##idx_##_w_max_name = \ - qkv_matmul_##idx_##_w_name + "_max"; \ - VarDesc qkv_matmul_##idx_##_w_max_desc(qkv_matmul_##idx_##_w_max_name); \ - qkv_matmul_##idx_##_w_max_desc.SetPersistable(true); \ - auto qkv_matmul_##idx_##_w_max = \ - graph->CreateVarNode(&qkv_matmul_##idx_##_w_max_desc); \ - auto qkv_matmul_##idx_##_w_max_var = \ - block->Var(qkv_matmul_##idx_##_w_max_name); \ - qkv_matmul_##idx_##_w_max_var->SetPersistable(true); \ - auto qkv_matmul_##idx_##_w_max_tensor = \ - scope->Var(qkv_matmul_##idx_##_w_max_name) \ - ->GetMutable(); \ - auto qkv_matmul_##idx_##_w_tensor = \ - scope->Var(qkv_matmul_##idx_##_w_name)->GetMutable(); \ - QuantWeight( \ - qkv_matmul_##idx_##_w_tensor, qkv_matmul_##idx_##_w_max_tensor, true); + Node* qkv_w_int16 = nullptr; + Node* qkv_w_max = nullptr; + PrepareQKVWeight(graph, + scope, + block, + q_matmul_w, + k_matmul_w, + v_matmul_w, + &qkv_w_int16, + &qkv_w_max); + +#define PREPARE_QKV_MATMUL_W(idx_) \ + Node* qkv_matmul_##idx_##_w_int16 = nullptr; \ + Node* qkv_matmul_##idx_##_w_max = nullptr; \ + PrepareWeight(graph, \ + scope, \ + block, \ + qkv_matmul_##idx_##_w, \ + &qkv_matmul_##idx_##_w_int16, \ + &qkv_matmul_##idx_##_w_max, \ + true); PREPARE_QKV_MATMUL_W(1); PREPARE_QKV_MATMUL_W(2); PREPARE_QKV_MATMUL_W(3); #undef PREPARE_QKV_MATMUL_W - // Concat q_add_bias, k_add_bias, v_add_bias - std::string q_add_bias_name = q_add_bias->Name(); - std::string k_add_bias_name = k_add_bias->Name(); - std::string v_add_bias_name = v_add_bias->Name(); - std::string qkv_add_bias_name = - q_add_bias_name + "_" + k_add_bias_name + "_" + v_add_bias_name; - VarDesc qkv_add_bias_desc(qkv_add_bias_name); - qkv_add_bias_desc.SetPersistable(true); - auto* qkv_add_bias = graph->CreateVarNode(&qkv_add_bias_desc); - auto* qkv_add_bias_var = block->Var(qkv_add_bias_name); - qkv_add_bias_var->SetPersistable(true); - auto* q_add_bias_tensor = - scope->FindVar(q_add_bias_name)->GetMutable(); - auto* k_add_bias_tensor = - scope->FindVar(k_add_bias_name)->GetMutable(); - auto* v_add_bias_tensor = - scope->FindVar(v_add_bias_name)->GetMutable(); - CastToFp32(q_add_bias_tensor); - CastToFp32(k_add_bias_tensor); - CastToFp32(v_add_bias_tensor); - ConcatQKVBias( - *q_add_bias_tensor, - *k_add_bias_tensor, - *v_add_bias_tensor, - scope->Var(qkv_add_bias_name)->GetMutable()); - - // Prepare qkv_add_0_bias, qkv_add_2_bias, qkv_add_3_bias - auto qkv_add_0_bias_name = qkv_add_0_bias->Name(); - CastToFp32( - scope->FindVar(qkv_add_0_bias_name)->GetMutable()); - auto qkv_add_2_bias_name = qkv_add_2_bias->Name(); - CastToFp32( - scope->FindVar(qkv_add_2_bias_name)->GetMutable()); - auto qkv_add_3_bias_name = qkv_add_3_bias->Name(); - CastToFp32( - scope->FindVar(qkv_add_3_bias_name)->GetMutable()); + Node* qkv_add_bias_fp32 = nullptr; + PrepareQKVBias(graph, + scope, + block, + q_add_bias, + k_add_bias, + v_add_bias, + &qkv_add_bias_fp32); + + Node* qkv_add_0_bias_fp32 = nullptr; + Node* qkv_add_2_bias_fp32 = nullptr; + Node* qkv_add_3_bias_fp32 = nullptr; + PrepareBias(graph, scope, block, qkv_add_0_bias, &qkv_add_0_bias_fp32); + PrepareBias(graph, scope, block, qkv_add_2_bias, &qkv_add_2_bias_fp32); + PrepareBias(graph, scope, block, qkv_add_3_bias, &qkv_add_3_bias_fp32); // Generate single_encoder_xpu op framework::OpDesc op_desc(block); op_desc.SetType("single_encoder_xpu"); op_desc.SetInput("x", {ln_0_x->Name()}); op_desc.SetInput("fc_weight", - {qkv_w_name, - qkv_matmul_1_w_name, - qkv_matmul_2_w_name, - qkv_matmul_3_w_name}); + {qkv_w_int16->Name(), + qkv_matmul_1_w_int16->Name(), + qkv_matmul_2_w_int16->Name(), + qkv_matmul_3_w_int16->Name()}); op_desc.SetInput("fc_weight_max", - {qkv_w_max_name, - qkv_matmul_1_w_max_name, - qkv_matmul_2_w_max_name, - qkv_matmul_3_w_max_name}); + {qkv_w_max->Name(), + qkv_matmul_1_w_max->Name(), + qkv_matmul_2_w_max->Name(), + qkv_matmul_3_w_max->Name()}); op_desc.SetInput("fc_bias", - {qkv_add_bias_name, - qkv_add_0_bias_name, - qkv_add_2_bias_name, - qkv_add_3_bias_name}); + {qkv_add_bias_fp32->Name(), + qkv_add_0_bias_fp32->Name(), + qkv_add_2_bias_fp32->Name(), + qkv_add_3_bias_fp32->Name()}); if (norm_before) { op_desc.SetInput("ln_scale", {ln_0_scale->Name(), ln_1_scale->Name()}); op_desc.SetInput("ln_bias", {ln_0_bias->Name(), ln_1_bias->Name()}); @@ -990,30 +1061,30 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse( } auto* single_encoder_xpu = graph->CreateOpNode(&op_desc); // Link nodes - SAFE_IR_NODE_LINK_TO(ln_0_x, single_encoder_xpu); - SAFE_IR_NODE_LINK_TO(qkv_w, single_encoder_xpu); - SAFE_IR_NODE_LINK_TO(qkv_w_max, single_encoder_xpu); - SAFE_IR_NODE_LINK_TO(qkv_matmul_1_w, single_encoder_xpu); - SAFE_IR_NODE_LINK_TO(qkv_matmul_1_w_max, single_encoder_xpu); - SAFE_IR_NODE_LINK_TO(qkv_matmul_2_w, single_encoder_xpu); - SAFE_IR_NODE_LINK_TO(qkv_matmul_2_w_max, single_encoder_xpu); - SAFE_IR_NODE_LINK_TO(qkv_matmul_3_w, single_encoder_xpu); - SAFE_IR_NODE_LINK_TO(qkv_matmul_3_w_max, single_encoder_xpu); - SAFE_IR_NODE_LINK_TO(qkv_add_bias, single_encoder_xpu); - SAFE_IR_NODE_LINK_TO(qkv_add_0_bias, single_encoder_xpu); - SAFE_IR_NODE_LINK_TO(qkv_add_2_bias, single_encoder_xpu); - SAFE_IR_NODE_LINK_TO(qkv_add_3_bias, single_encoder_xpu); + IR_NODE_LINK_TO(ln_0_x, single_encoder_xpu); + IR_NODE_LINK_TO(qkv_w_int16, single_encoder_xpu); + IR_NODE_LINK_TO(qkv_w_max, single_encoder_xpu); + IR_NODE_LINK_TO(qkv_matmul_1_w_int16, single_encoder_xpu); + IR_NODE_LINK_TO(qkv_matmul_1_w_max, single_encoder_xpu); + IR_NODE_LINK_TO(qkv_matmul_2_w_int16, single_encoder_xpu); + IR_NODE_LINK_TO(qkv_matmul_2_w_max, single_encoder_xpu); + IR_NODE_LINK_TO(qkv_matmul_3_w_int16, single_encoder_xpu); + IR_NODE_LINK_TO(qkv_matmul_3_w_max, single_encoder_xpu); + IR_NODE_LINK_TO(qkv_add_bias_fp32, single_encoder_xpu); + IR_NODE_LINK_TO(qkv_add_0_bias_fp32, single_encoder_xpu); + IR_NODE_LINK_TO(qkv_add_2_bias_fp32, single_encoder_xpu); + IR_NODE_LINK_TO(qkv_add_3_bias_fp32, single_encoder_xpu); SAFE_IR_NODE_LINK_TO(ln_0_scale, single_encoder_xpu); SAFE_IR_NODE_LINK_TO(ln_0_bias, single_encoder_xpu); - SAFE_IR_NODE_LINK_TO(ln_1_scale, single_encoder_xpu); - SAFE_IR_NODE_LINK_TO(ln_1_bias, single_encoder_xpu); + IR_NODE_LINK_TO(ln_1_scale, single_encoder_xpu); + IR_NODE_LINK_TO(ln_1_bias, single_encoder_xpu); SAFE_IR_NODE_LINK_TO(ln_2_scale, single_encoder_xpu); SAFE_IR_NODE_LINK_TO(ln_2_bias, single_encoder_xpu); SAFE_IR_NODE_LINK_TO(qk_add_mask, single_encoder_xpu); if (norm_before) { - SAFE_IR_NODE_LINK_TO(single_encoder_xpu, qkv_add_4_out); + IR_NODE_LINK_TO(single_encoder_xpu, qkv_add_4_out); } else { - SAFE_IR_NODE_LINK_TO(single_encoder_xpu, ln_2_out); + IR_NODE_LINK_TO(single_encoder_xpu, ln_2_out); } // Delete nodes diff --git a/paddle/fluid/framework/ir/xpu/pass_utils.cc b/paddle/fluid/framework/ir/xpu/pass_utils.cc index 262af5805b8..d23049eb92c 100644 --- a/paddle/fluid/framework/ir/xpu/pass_utils.cc +++ b/paddle/fluid/framework/ir/xpu/pass_utils.cc @@ -20,6 +20,18 @@ namespace paddle { namespace framework { namespace ir { +static void HashCombine(std::size_t* seed) {} + +// combine hash value +// https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x +template +static void HashCombine(std::size_t* seed, const T& v, Rest... rest) { + std::hash hasher; + *seed ^= hasher(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); + *seed *= 0x00000100000001B3; + HashCombine(seed, rest...); +} + int ConvertActivationType(std::string act_type) { if (act_type == "") { return static_cast(xpu::Activation_t::LINEAR); @@ -50,6 +62,161 @@ int ConvertActivationType(std::string act_type) { return -1; } +Node* FindNodeWithName(Graph* graph, std::string name) { + for (auto* node : graph->Nodes()) { + if (node->IsVar() && node->Var()->Name() == name) { + return node; + } + } + return nullptr; +} + +template +std::string IntTypeToString() { + LOG(FATAL) << "Not support type."; + return ""; +} + +template <> +std::string IntTypeToString() { + return "int16"; +} + +template +size_t HashTensor(const phi::DenseTensor& in) { + size_t ret = 0; + auto in_dims = in.dims(); + HashCombine(&ret, + phi::DataTypeToString(in.dtype()), + phi::DataLayoutToString(in.layout()), + in_dims.size()); + for (int i = 0; i < in_dims.size(); i++) { + HashCombine(&ret, in_dims[i]); + } + + auto* data = in.data(); + int64_t size = in.numel(); + for (int64_t i = 0; i < size; i++) { + HashCombine(&ret, data[i]); + } + return ret; +} + +template size_t HashTensor(const phi::DenseTensor& in); +template size_t HashTensor(const phi::DenseTensor& in); + +template +void PrepareWeight(Graph* graph, + Scope* scope, + BlockDesc* block, + Node* src, + Node** dst, + Node** dst_max, + bool transpose) { + auto src_name = src->Name(); + auto* src_tensor = scope->Var(src_name)->GetMutable(); + phi::DenseTensor dst_tensor; + Assign(*src_tensor, &dst_tensor); + phi::DenseTensor dst_max_tensor; + PrepareWeight(&dst_tensor, &dst_max_tensor, transpose); + + size_t dst_hash = HashTensor(dst_tensor); + size_t dst_max_hash = HashTensor(dst_max_tensor); + std::string dst_name = src_name + "_" + std::to_string(dst_hash); + std::string dst_max_name = src_name + "_max_" + std::to_string(dst_max_hash); + *dst = FindNodeWithName(graph, dst_name); + if (*dst == nullptr) { + // Create dst node + // Update dst var_desc in block + VarDesc dst_desc(dst_name); + dst_desc.SetPersistable(true); + dst_desc.SetShape(vectorize(dst_tensor.dims())); + dst_desc.SetDataType(framework::TransToProtoVarType(dst_tensor.dtype())); + *dst = graph->CreateVarNode(&dst_desc); + auto* block_dst_desc = block->Var(dst_name); + block_dst_desc->SetPersistable(dst_desc.Persistable()); + block_dst_desc->SetShape(dst_desc.GetShape()); + block_dst_desc->SetDataType(dst_desc.GetDataType()); + // Create dst_max node + // Update dst_max var_desc in block + VarDesc dst_max_desc(dst_max_name); + dst_max_desc.SetPersistable(true); + dst_max_desc.SetShape(vectorize(dst_max_tensor.dims())); + dst_max_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32); + *dst_max = graph->CreateVarNode(&dst_max_desc); + auto* block_dst_max_desc = block->Var(dst_max_name); + block_dst_max_desc->SetPersistable(dst_max_desc.Persistable()); + block_dst_max_desc->SetShape(dst_max_desc.GetShape()); + block_dst_max_desc->SetDataType(dst_max_desc.GetDataType()); + + // Find dst/dst_max variable in scope + auto* dst_var = scope->FindVar(dst_name); + if (dst_var == nullptr) { + // Create dst/dst_max variable/tensor + Assign(dst_tensor, scope->Var(dst_name)->GetMutable()); + Assign(dst_max_tensor, + scope->Var(dst_max_name)->GetMutable()); + } else { + // Share the same variable + PADDLE_ENFORCE_NOT_NULL( + scope->FindVar(dst_max_name), + platform::errors::Fatal( + "dst_max(%s) variable should not be nullptr if dst(%s) " + "variable is exist. (src_name is %s)", + dst_max_name, + dst_name, + src_name)); + } + } else { + *dst_max = FindNodeWithName(graph, dst_max_name); + PADDLE_ENFORCE_NOT_NULL( + *dst_max, + platform::errors::Fatal( + "dst_max(%s) variable should not be nullptr if dst(%s) " + "variable is exist. (src_name is %s)", + dst_max_name, + dst_name, + src_name)); + } +} + +template void PrepareWeight(Graph* graph, + Scope* scope, + BlockDesc* block, + Node* src, + Node** dst, + Node** dst_max, + bool transpose); + +void PrepareBias( + Graph* graph, Scope* scope, BlockDesc* block, Node* src, Node** dst) { + auto src_name = src->Name(); + auto* src_tensor = scope->Var(src_name)->GetMutable(); + if (src_tensor->dtype() == phi::DataType::FLOAT32) { + *dst = src; + } + + phi::DenseTensor dst_tensor; + CastToFp32(src_tensor, &dst_tensor); + size_t dst_hash = HashTensor(dst_tensor); + std::string dst_name = src_name + "_" + std::to_string(dst_hash); + *dst = FindNodeWithName(graph, dst_name); + if (*dst == nullptr) { + // Create dst node + // Update dst var_desc in block + VarDesc dst_desc(dst_name); + dst_desc.SetPersistable(true); + dst_desc.SetShape(vectorize(dst_tensor.dims())); + dst_desc.SetDataType(framework::TransToProtoVarType(dst_tensor.dtype())); + *dst = graph->CreateVarNode(&dst_desc); + auto* block_dst_desc = block->Var(dst_name); + block_dst_desc->SetPersistable(dst_desc.Persistable()); + block_dst_desc->SetShape(dst_desc.GetShape()); + block_dst_desc->SetDataType(dst_desc.GetDataType()); + Assign(dst_tensor, scope->Var(dst_name)->GetMutable()); + } +} + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/xpu/pass_utils.h b/paddle/fluid/framework/ir/xpu/pass_utils.h index f4593823803..68cfb2953e1 100644 --- a/paddle/fluid/framework/ir/xpu/pass_utils.h +++ b/paddle/fluid/framework/ir/xpu/pass_utils.h @@ -14,6 +14,10 @@ #pragma once #include +#include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/framework/ir/xpu/quant_utils.h" +#include "paddle/fluid/framework/scope.h" namespace paddle { namespace framework { @@ -45,6 +49,23 @@ namespace ir { int ConvertActivationType(std::string act_type); +Node* FindNodeWithName(Graph* graph, std::string name); + +template +size_t HashTensor(const phi::DenseTensor& in); + +template +void PrepareWeight(Graph* graph, + Scope* scope, + BlockDesc* block, + Node* src, + Node** dst, + Node** dst_max, + bool transpose); + +void PrepareBias( + Graph* graph, Scope* scope, BlockDesc* block, Node* src, Node** dst); + } // namespace ir } // namespace framework } // namespace paddle diff --git a/paddle/fluid/framework/ir/xpu/quant_utils.cc b/paddle/fluid/framework/ir/xpu/quant_utils.cc index de365f71c63..fd807384a0b 100644 --- a/paddle/fluid/framework/ir/xpu/quant_utils.cc +++ b/paddle/fluid/framework/ir/xpu/quant_utils.cc @@ -207,9 +207,9 @@ void QuantFP32ToIntX(const float* src_ptr, } template -void QuantWeight(phi::DenseTensor* weight, - phi::DenseTensor* weight_max, - bool transpose) { +void PrepareWeight(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose) { // Convert fp16 to fp32 phi::DenseTensor weight_fp32; CastToFp32(weight, &weight_fp32); @@ -249,9 +249,9 @@ void QuantWeight(phi::DenseTensor* weight, QuantFP32ToIntX(weight_data, cpu_ctx->Alloc(weight), max_val, size); } -template void QuantWeight(phi::DenseTensor* weight, - phi::DenseTensor* weight_max, - bool transpose); +template void PrepareWeight(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose); } // namespace ir } // namespace framework diff --git a/paddle/fluid/framework/ir/xpu/quant_utils.h b/paddle/fluid/framework/ir/xpu/quant_utils.h index 57519a58432..85e9ddb1182 100644 --- a/paddle/fluid/framework/ir/xpu/quant_utils.h +++ b/paddle/fluid/framework/ir/xpu/quant_utils.h @@ -29,9 +29,9 @@ void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr); // 2. Weight data is in-place update. // 3. Generate weight max tensor template -void QuantWeight(phi::DenseTensor* weight, - phi::DenseTensor* weight_max, - bool transpose); +void PrepareWeight(phi::DenseTensor* weight, + phi::DenseTensor* weight_max, + bool transpose); } // namespace ir } // namespace framework diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 23fdaf3ddff..d43770e0ddb 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -524,6 +524,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { "fc_xpu_fuse_pass", "link_xpu_op_max_pass", "delete_op_device_pass", + "delete_isolated_node_pass", }); use_xpu_ = true; } diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_link_xpu_op_max_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_link_xpu_op_max_pass.py index b1a19b7b716..f05b93dcce2 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_link_xpu_op_max_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_link_xpu_op_max_pass.py @@ -15,6 +15,7 @@ import unittest import hypothesis.strategies as st +import numpy as np from auto_scan_test import PassAutoScanTest from program_config import OpConfig, ProgramConfig, TensorConfig @@ -33,7 +34,7 @@ class TestFcXPUFusePass(PassAutoScanTest): ) matmul0_y_shape = draw( st.lists( - st.integers(min_value=1, max_value=8), min_size=2, max_size=2 + st.integers(min_value=2, max_value=8), min_size=2, max_size=2 ) ) matmul0_y_shape[0] = matmul0_x_shape[-1] @@ -42,7 +43,7 @@ class TestFcXPUFusePass(PassAutoScanTest): # 3. matmul1 matmul1_y_shape = draw( st.lists( - st.integers(min_value=1, max_value=8), min_size=2, max_size=2 + st.integers(min_value=2, max_value=8), min_size=2, max_size=2 ) ) matmul1_y_shape[0] = matmul0_y_shape[-1] @@ -101,4 +102,5 @@ class TestFcXPUFusePass(PassAutoScanTest): if __name__ == "__main__": + np.random.seed(200) unittest.main() -- GitLab