Unverified commit 39a9abaa, authored by zhupengyang, committed by GitHub

[XPU] support shared weight; delete isolated node (#51108)

Parent 50ad760c
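Summary: this change (1) adds delete_isolated_node_pass, which removes persistable var nodes left without real consumers after the XPU fuse passes run and patches the input lists of while/conditional_block ops across subgraphs, and (2) reworks weight/bias preparation (PrepareWeight/PrepareBias) so converted tensors are named by content hash and shared between ops and subgraphs instead of being converted once per consumer.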
@@ -220,7 +220,7 @@ if(WITH_XPU)
   cc_library(
     xpu_pass_utils
     SRCS xpu/pass_utils.cc
-    DEPS pass)
+    DEPS pass xpu_quant_utils)
   set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils)
   pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
@@ -232,6 +232,8 @@ if(WITH_XPU)
   pass_library(generate_sequence_xpu_fuse_pass inference DIR xpu DEPS
                ${XPU_PASS_DEPS})
   pass_library(link_xpu_op_max_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
+  pass_library(delete_isolated_node_pass inference DIR xpu DEPS
+               ${XPU_PASS_DEPS})
 endif()
 cc_library(
@@ -484,3 +486,10 @@ if(WITH_MKLDNN)
     SRCS mkldnn/cpu_bfloat16_pass_tester.cc
     DEPS cpu_bfloat16_pass)
 endif()
+if(WITH_XPU)
+  cc_test(
+    test_delete_isolated_node_pass
+    SRCS xpu/delete_isolated_node_pass_test.cc
+    DEPS delete_isolated_node_pass)
+endif()
@@ -39,14 +39,14 @@ class DeleteOpDevicePass : public Pass {
 void DeleteOpDevicePass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::PreconditionNotMet("graph should not be null."));
-  int found_subgraph_count = 0;
+  int delete_counts = 0;
   for (auto* node : graph->Nodes()) {
     if (!node->IsOp() || !node->Op()->HasAttr("op_device")) continue;
     node->Op()->RemoveAttr("op_device");
-    found_subgraph_count++;
+    delete_counts++;
   }
-  if (found_subgraph_count > 0) {
-    LOG(INFO) << "--- detected " << found_subgraph_count << " subgraphs";
+  if (delete_counts > 0) {
+    LOG(INFO) << "--- delete " << delete_counts << " op_device attr";
   }
 }
...
@@ -13,8 +13,7 @@
 // limitations under the License.
 #include <gtest/gtest.h>
-#include "paddle/fluid/framework/ir/delete_dropout_op_pass.h"
+#include "paddle/fluid/framework/ir/pass.h"
 #include "paddle/fluid/framework/ir/pass_tester_helper.h"
 namespace paddle {
...
@@ -49,6 +49,7 @@ static const std::vector<std::string> support_subgraph_passes = {
     "fuse_multi_transformer_layer_pass",
     "delete_quant_dequant_linear_op_pass",
     "delete_weight_dequant_linear_op_pass",
+    "fc_xpu_fuse_pass",
     "delete_op_device_pass"};
 Graph *Pass::Apply(Graph *graph) const {
...
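fc_xpu_fuse_pass joins the subgraph whitelist so the fusion also runs inside while/conditional_block bodies. A hedged sketch of the guard this list presumably feeds (illustrative, not the verbatim Pass::Apply body):

    // Illustrative only: a whitelisted pass applied to the main graph is
    // assumed to be re-applied to each owned subgraph; others touch the
    // main graph only.
    if (graph->IsMainGraph() &&
        std::count(support_subgraph_passes.begin(),
                   support_subgraph_passes.end(),
                   Type()) > 0) {
      for (size_t i = 1; i < graph->SubGraphsSize(); i++) {
        ApplyImpl(graph->GetSubGraph(i));
      }
    }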
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/scope.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
class DeleteIsolatedNodePass : public Pass {
protected:
void ApplyImpl(Graph* graph) const override;
private:
void CollectReservedPersistableNodeNames(
Graph* graph,
std::unordered_set<std::string>* reserved_persistable_node_names) const;
int RemoveIsolatedNodes(
Graph* graph,
const std::unordered_set<std::string>& reserved_persistable_node_names,
std::unordered_set<std::string>* delete_node_names) const;
int UpdateControlFlowOp(
Graph* graph,
const std::map<int, Graph*>& block_id_graph_map,
const std::unordered_set<std::string>& delete_node_names) const;
const std::map<std::string, std::string> control_flow_op_input_map_{
{"while", "X"},
{"conditional_block", "Input"},
};
};
void DeleteIsolatedNodePass::ApplyImpl(Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
PADDLE_ENFORCE(graph->IsMainGraph(),
platform::errors::PreconditionNotMet(
"Pass(apply in main graph) will delete isolated nodes in "
"all subgraphs. Do not apply pass in subgraph."));
std::unordered_set<std::string> reserved_persistable_node_names;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
CollectReservedPersistableNodeNames(graph->GetSubGraph(i),
&reserved_persistable_node_names);
}
int delete_counts = 0;
std::unordered_set<std::string> delete_node_names;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
delete_counts += RemoveIsolatedNodes(graph->GetSubGraph(i),
reserved_persistable_node_names,
&delete_node_names);
}
if (delete_counts > 0) {
LOG(INFO) << "--- delete " << delete_counts << " isolated nodes";
}
std::map<int, Graph*> block_id_graph_map;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
auto* sub_graph = graph->GetSubGraph(i);
for (auto* node : sub_graph->Nodes()) {
if (node->IsVar()) {
block_id_graph_map[node->GetVarNodeBlockId()] = sub_graph;
break;
}
}
}
int update_counts = 0;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
update_counts += UpdateControlFlowOp(
graph->GetSubGraph(i), block_id_graph_map, delete_node_names);
}
if (update_counts > 0) {
LOG(INFO) << "--- update " << update_counts << " control flow ops";
}
}
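// A persistable var node is "reserved" (kept) when at least one consumer is a
// regular op. A weight whose only remaining consumers are while /
// conditional_block inputs is presumed stale: a fuse pass has already
// redirected its real consumers to new hashed weight nodes.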
void DeleteIsolatedNodePass::CollectReservedPersistableNodeNames(
Graph* graph,
std::unordered_set<std::string>* reserved_persistable_node_names) const {
for (auto* node : graph->Nodes()) {
if (!node->IsVar() || !node->Var()->Persistable()) continue;
for (auto* out_node : node->outputs) {
auto op_type = out_node->Op()->Type();
if (control_flow_op_input_map_.count(op_type) == 0) {
reserved_persistable_node_names->insert(node->Var()->Name());
break;
}
}
}
}
int DeleteIsolatedNodePass::RemoveIsolatedNodes(
Graph* graph,
const std::unordered_set<std::string>& reserved_persistable_node_names,
std::unordered_set<std::string>* delete_node_names) const {
BlockDesc* block = nullptr;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
block = node->Op()->Block();
}
}
Scope& scope = graph->Get<framework::Scope>("__param_scope__");
// If graph has nodes to delete:
// 1. Clear var_desc in block
// 2. Clear tensor in variable
// 3. Clear variable in scope
int delete_node_counts = 0;
std::unordered_set<const Node*> delete_nodes;
const std::unordered_set<ir::Node*> nodes = graph->Nodes();
for (auto* node : nodes) {
if (!node->IsVar() || !node->Var()->Persistable()) continue;
auto name = node->Var()->Name();
if (reserved_persistable_node_names.count(name) > 0) continue;
delete_nodes.insert(node);
delete_node_names->insert(node->Name());
block->RemoveVar(name);
auto* var = scope.FindVar(name);
if (var != nullptr) {
var->Clear();
scope.EraseVars({name});
}
delete_node_counts++;
}
GraphSafeRemoveNodes(graph, delete_nodes);
return delete_node_counts;
}
int DeleteIsolatedNodePass::UpdateControlFlowOp(
Graph* graph,
const std::map<int, Graph*>& block_id_graph_map,
const std::unordered_set<std::string>& delete_node_names) const {
int update_counts = 0;
for (auto* node : graph->Nodes()) {
if (!node->IsOp()) continue;
auto op_type = node->Op()->Type();
if (control_flow_op_input_map_.count(op_type) == 0) continue;
auto in_arg_name = control_flow_op_input_map_.at(op_type);
auto in_name = node->Op()->Input(in_arg_name);
std::unordered_set<std::string> in_names_set(in_name.begin(),
in_name.end());
for (auto delete_node_name : delete_node_names) {
if (in_names_set.count(delete_node_name) > 0) {
in_names_set.erase(delete_node_name);
}
}
auto* sub_block = PADDLE_GET_CONST(framework::BlockDesc*,
node->Op()->GetAttr("sub_block"));
auto* sub_graph = block_id_graph_map.at(sub_block->ID());
std::unordered_set<std::string> sub_persistable_node_names;
CollectReservedPersistableNodeNames(sub_graph, &sub_persistable_node_names);
for (auto sub_name : sub_persistable_node_names) {
if (in_names_set.count(sub_name) > 0) continue;
auto* in_node = FindNodeWithName(graph, sub_name);
if (in_node == nullptr) continue;
in_names_set.insert(sub_name);
IR_NODE_LINK_TO(in_node, node);
}
std::vector<std::string> new_in_names(in_names_set.begin(),
in_names_set.end());
node->Op()->SetInput(in_arg_name, new_in_names);
update_counts++;
}
return update_counts;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_isolated_node_pass,
paddle::framework::ir::DeleteIsolatedNodePass);
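A minimal driver for the new pass, mirroring the unit test below:

    // Minimal usage sketch; assumes `graph` wraps the whole ProgramDesc and
    // its parameter Scope is attached as "__param_scope__". Apply to the
    // main graph only; the pass visits all subgraphs itself.
    auto pass = PassRegistry::Instance().Get("delete_isolated_node_pass");
    pass->Apply(graph.get());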
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
VarDesc* Data(paddle::framework::BlockDesc* block,
std::string name,
std::vector<int64_t> shape = {},
bool is_persistable = false,
proto::VarType::Type data_type = proto::VarType::FP32) {
auto* var = block->Var(name);
var->SetType(proto::VarType::LOD_TENSOR);
var->SetDataType(data_type);
var->SetShape(shape);
var->SetPersistable(is_persistable);
return var;
}
void AddVarToScope(Scope* param_scope,
const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<phi::DenseTensor>();
tensor->Resize(dims);
auto* cpu_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
auto* data = cpu_ctx->Alloc<float>(tensor);
int64_t numel = tensor->numel();
for (int64_t i = 0; i < numel; ++i) {
data[i] = 1;
}
}
Scope* CreateParamScope() {
auto param_scope = new Scope();
AddVarToScope(param_scope, "matmul0_w", {128, 128});
return param_scope;
}
int WeightNodeNum(ir::Graph* graph) {
int num = 0;
for (auto node : graph->Nodes()) {
if (node->IsVar() && node->Var()->Persistable()) {
num++;
}
}
return num;
}
int WeightTensorNum(Scope* scope) {
int num = 0;
auto vars = scope->LocalVars();
for (auto* var : vars) {
if (var->Get<phi::DenseTensor>().numel() > 0) {
num++;
}
}
return num;
}
TEST(delete_isolated_node_pass, basic) {
paddle::framework::ProgramDesc program;
auto* block0 = program.MutableBlock(0);
auto* block1 = program.AppendBlock(*block0);
auto* matmul0_x = Data(block0, "matmul0_x", {1, 128});
auto* matmul0_w = Data(block0, "matmul0_w", {128, 128}, true);
auto* matmul0_out = Data(block0, "matmul0_out", {1, 128});
OpDesc* matmul_op = block0->AppendOp();
matmul_op->SetType("matmul_v2");
matmul_op->SetInput("X", {matmul0_x->Name()});
matmul_op->SetInput("Y", {matmul0_w->Name()});
matmul_op->SetAttr("trans_x", false);
matmul_op->SetAttr("trans_y", false);
matmul_op->SetOutput("Out", {matmul0_out->Name()});
auto* while_out = Data(block0, "while_out", {1, 128});
auto* while_step_scopes = Data(block0, "while_step_scopes");
auto* while_cond = Data(block0, "while_cond");
OpDesc* while_op = block0->AppendOp();
while_op->SetType("while");
while_op->SetInput("X", {matmul0_w->Name(), matmul0_out->Name()});
while_op->SetInput("Condition", {while_cond->Name()});
while_op->SetOutput("Out", {while_out->Name()});
while_op->SetOutput("StepScopes", {while_step_scopes->Name()});
while_op->SetAttr("sub_block", {block1});
while_op->SetAttr("is_test", true);
auto* matmul1_x = Data(block1, matmul0_out->Name(), matmul0_out->GetShape());
auto* matmul1_w =
Data(block1, matmul0_w->Name(), matmul0_w->GetShape(), true);
auto* matmul1_out = Data(block1, "matmul1_out", {1, 128});
OpDesc* matmul1_op = block1->AppendOp();
matmul1_op->SetType("matmul_v2");
matmul1_op->SetInput("X", {matmul1_x->Name()});
matmul1_op->SetInput("Y", {matmul1_w->Name()});
matmul1_op->SetAttr("trans_x", false);
matmul1_op->SetAttr("trans_y", false);
matmul1_op->SetOutput("Out", {matmul1_out->Name()});
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
auto* scope = CreateParamScope();
graph->Set("__param_scope__", scope);
auto pass0 = PassRegistry::Instance().Get("fc_xpu_fuse_pass");
pass0->Apply(graph.get());
pass0->Apply(graph->GetSubGraph(1));
int weight_node_num =
WeightNodeNum(graph.get()) + WeightNodeNum(graph->GetSubGraph(1));
PADDLE_ENFORCE_EQ(weight_node_num,
6,
platform::errors::PreconditionNotMet(
"Graph should have 6 weight node after "
"fc_xpu_fuse_pass, but actually has %d.",
weight_node_num));
auto pass1 = PassRegistry::Instance().Get("delete_isolated_node_pass");
pass1->Apply(graph.get());
weight_node_num =
WeightNodeNum(graph.get()) + WeightNodeNum(graph->GetSubGraph(1));
PADDLE_ENFORCE_EQ(weight_node_num,
4,
platform::errors::PreconditionNotMet(
"Graph should have 4 weight node after "
"delete_isolated_node_pass, but actually has %d.",
weight_node_num));
int weight_tensor_num = WeightTensorNum(scope);
PADDLE_ENFORCE_EQ(weight_tensor_num,
2,
platform::errors::PreconditionNotMet(
"Scope should have 2 weight tensor after "
"delete_isolated_node_pass, but actually has %d.",
weight_tensor_num));
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "while") {
auto while_in_names = node->Op()->Inputs().at("X");
PADDLE_ENFORCE_EQ(while_in_names.size(),
3,
platform::errors::PreconditionNotMet(
"While op should have 3 input after "
"delete_isolated_node_pass, but actually has %d.",
while_in_names.size()));
}
}
Scope& scope0 = graph->Get<framework::Scope>("__param_scope__");
Scope& scope1 =
graph->GetSubGraph(1)->Get<framework::Scope>("__param_scope__");
std::vector<std::string> shared_weight_names{matmul0_w->Name() + "_int16",
matmul0_w->Name() + "_max"};
for (auto name : shared_weight_names) {
auto* var0 = scope0.FindVar(name);
auto* var1 = scope1.FindVar(name);
    PADDLE_ENFORCE(
        var0 == var1,
        platform::errors::PreconditionNotMet(
            "Variables with the same name in two scopes are different."));
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(delete_isolated_node_pass);
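The test builds a main block and a while sub-block that share matmul0_w, runs fc_xpu_fuse_pass on both graphs, then checks that delete_isolated_node_pass drops the two stale fp32 weight nodes and their tensors, shrinks the while op's X input list to three entries, and that the shared weight variables resolve to the same Variable object in both scopes. Since it is registered through cc_test, something like `ctest -R test_delete_isolated_node_pass` should run it from the build directory.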
@@ -76,7 +76,6 @@ FcXPUPattern::FcXPUPattern(PDPattern* pattern,
   auto* mul_w = pattern->NewNode(mul_w_repr())
                     ->assert_is_op_input(mul_type_, "Y")
                     ->assert_is_persistable_var()
-                    ->assert_has_n_outputs(1)
                     ->assert_more([](Node* node) {
                       return node->Var()->GetShape().size() == 2;
                     });
@@ -169,7 +168,7 @@ class FcXPUFusePass : public FusePassBase {
   void ApplyImpl(ir::Graph* graph) const override;
 private:
-  void ApplyImpl(ir::Graph* graph,
+  int ApplyImpl(ir::Graph* graph,
                 const std::string& mul_type,
                 bool with_bias,
                 const std::string& act_type) const;
@@ -181,6 +180,8 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(
       graph, platform::errors::PreconditionNotMet("graph should not be null."));
   Init(name_scope_, graph);
+  int found_subgraph_count = 0;
   for (auto mul_type : {"mul", "matmul", "matmul_v2"}) {
     for (auto with_bias : {true, false}) {
       for (auto act_type : {
@@ -189,13 +190,14 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
               "tanh",
               "",
           }) {
-        ApplyImpl(graph, mul_type, with_bias, act_type);
+        found_subgraph_count += ApplyImpl(graph, mul_type, with_bias, act_type);
       }
     }
   }
+  AddStatis(found_subgraph_count);
 }
-void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
+int FcXPUFusePass::ApplyImpl(ir::Graph* graph,
                               const std::string& mul_type,
                               bool with_bias,
                               const std::string& act_type) const {
@@ -219,37 +221,20 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
     auto* block = mul->Op()->Block();
     auto* scope = param_scope();
-    auto mul_w_name = mul_w->Name();
-    auto mul_w_tensor =
-        scope->FindVar(mul_w_name)->GetMutable<phi::DenseTensor>();
-    // 1. Transform weight to int16/int31
-    // 2. Avoid transform repeatly, because weight may be shared with other ops.
-    // TODO(zhupengyang): support int31
-    std::string mul_w_max_name = mul_w_name + "_max";
-    Node* mul_w_max = nullptr;
-    if (mul_w_tensor->dtype() != phi::DataType::INT16) {
-      // Create weight_max node
-      VarDesc mul_w_max_desc(mul_w_max_name);
-      mul_w_max_desc.SetPersistable(true);
-      mul_w_max = graph->CreateVarNode(&mul_w_max_desc);
-      // Create weight_max var/tensor
-      auto mul_w_max_var = block->Var(mul_w_max_name);
-      mul_w_max_var->SetPersistable(true);
-      auto mul_w_max_tensor =
-          scope->Var(mul_w_max_name)->GetMutable<phi::DenseTensor>();
     bool transpose_w = false;
     if (mul_type == "matmul") {
       transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("transpose_Y"));
     } else if (mul_type == "matmul_v2") {
       transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("trans_y"));
     }
-      QuantWeight<int16_t>(mul_w_tensor, mul_w_max_tensor, !transpose_w);
-    }
+    Node* mul_w_int16 = nullptr;
+    Node* mul_w_max = nullptr;
+    PrepareWeight<int16_t>(
+        graph, scope, block, mul_w, &mul_w_int16, &mul_w_max, !transpose_w);
+    Node* bias_fp32 = nullptr;
     if (bias != nullptr) {
-      auto* bias_tensor =
-          scope->Var(bias->Name())->GetMutable<phi::DenseTensor>();
-      CastToFp32(bias_tensor);
+      PrepareBias(graph, scope, block, bias, &bias_fp32);
     }
     std::string fc_out_name;
@@ -268,10 +253,10 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
     framework::OpDesc fc_xpu_op_desc(block);
     fc_xpu_op_desc.SetType("fc_xpu");
     fc_xpu_op_desc.SetInput("x", {mul_x->Name()});
-    fc_xpu_op_desc.SetInput("w", {mul_w->Name()});
-    fc_xpu_op_desc.SetInput("w_max", {mul_w_max_name});
-    if (bias) {
-      fc_xpu_op_desc.SetInput("bias", {bias->Name()});
+    fc_xpu_op_desc.SetInput("w", {mul_w_int16->Name()});
+    fc_xpu_op_desc.SetInput("w_max", {mul_w_max->Name()});
+    if (bias_fp32) {
+      fc_xpu_op_desc.SetInput("bias", {bias_fp32->Name()});
     }
     fc_xpu_op_desc.SetAttr(
         "in_num_col_dims",
@@ -306,9 +291,9 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
     fc_xpu_op_desc.SetOutput("out_max", {fc_out_max_name});
     auto* fc_xpu = graph->CreateOpNode(&fc_xpu_op_desc);
     IR_NODE_LINK_TO(mul_x, fc_xpu);
-    IR_NODE_LINK_TO(mul_w, fc_xpu);
+    IR_NODE_LINK_TO(mul_w_int16, fc_xpu);
     IR_NODE_LINK_TO(mul_w_max, fc_xpu);
-    SAFE_IR_NODE_LINK_TO(bias, fc_xpu);
+    SAFE_IR_NODE_LINK_TO(bias_fp32, fc_xpu);
     if (act_out) {
       IR_NODE_LINK_TO(fc_xpu, act_out);
     } else if (add_out) {
@@ -334,7 +319,7 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
   };
   gpd(graph, handler);
-  AddStatis(found_subgraph_count);
+  return found_subgraph_count;
 }
 } // namespace ir
...
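Two things changed here: mul_w no longer asserts a single consumer (assert_has_n_outputs(1) was dropped), so a weight shared by several matmuls can still match; and the weight/bias conversion moved into the node-level PrepareWeight/PrepareBias helpers (see pass_utils below), which create the converted node once and reuse it afterwards. The inner ApplyImpl now returns its match count so the outer loop reports a single aggregate AddStatis total.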
@@ -625,16 +625,24 @@ class MultiEncoderXPUFusePass : public FusePassBase {
   // 2. Concat q_w, k_w, v_w
   // 3. Generate qkv_w_max tensor
   // 4. Quant qkv_w to int16
-  void PrepareQKVWeight(const phi::DenseTensor& q_w,
-                        const phi::DenseTensor& k_w,
-                        const phi::DenseTensor& v_w,
-                        phi::DenseTensor* qkv_w,
-                        phi::DenseTensor* qkv_w_max) const;
-  void ConcatQKVBias(const phi::DenseTensor& q_bias,
-                     const phi::DenseTensor& k_bias,
-                     const phi::DenseTensor& v_bias,
-                     phi::DenseTensor* qkv_bias) const;
+  void PrepareQKVWeight(Graph* graph,
+                        Scope* scope,
+                        BlockDesc* block,
+                        Node* q_w,
+                        Node* k_w,
+                        Node* v_w,
+                        Node** qkv_w,
+                        Node** qkv_w_max) const;
+  // 1. Cast bias to fp32
+  // 2. Concat q/k/v bias
+  void PrepareQKVBias(Graph* graph,
+                      Scope* scope,
+                      BlockDesc* block,
+                      Node* q_bias,
+                      Node* k_bias,
+                      Node* v_bias,
+                      Node** qkv_bias) const;
   const std::string name_scope_{"multi_encoder_xpu_fuse_pass"};
 };
@@ -685,55 +693,160 @@ void MultiEncoderXPUFusePass::ApplyImpl(ir::Graph* graph) const {
   AddStatis(cast_mask_counts);
 }
-void MultiEncoderXPUFusePass::PrepareQKVWeight(
-    const phi::DenseTensor& q_w,
-    const phi::DenseTensor& k_w,
-    const phi::DenseTensor& v_w,
-    phi::DenseTensor* qkv_w,
-    phi::DenseTensor* qkv_w_max) const {
-  // Transpose
-  phi::DenseTensor q_w_t;
-  phi::DenseTensor k_w_t;
-  phi::DenseTensor v_w_t;
-  Assign(q_w, &q_w_t);
-  Assign(k_w, &k_w_t);
-  Assign(v_w, &v_w_t);
-  Transpose2D(&q_w_t);
-  Transpose2D(&k_w_t);
-  Transpose2D(&v_w_t);
-  // Concat
-  qkv_w->Resize(DDim(
-      {q_w_t.dims()[0] + k_w_t.dims()[0] + v_w_t.dims()[0], q_w_t.dims()[1]}));
-  qkv_w->set_type(q_w.type());
-  auto* dev_ctx = static_cast<phi::CPUContext*>(
-      platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
-  std::vector<const phi::DenseTensor*> in_tensors{&q_w_t, &k_w_t, &v_w_t};
-  if (q_w.type() == phi::DataType::FLOAT16) {
-    phi::ConcatKernel<float16>(*dev_ctx, in_tensors, 0, qkv_w);
-  } else {
-    phi::ConcatKernel<float>(*dev_ctx, in_tensors, 0, qkv_w);
-  }
-  // Quant to int16
-  QuantWeight<int16_t>(qkv_w, qkv_w_max, false);
-}
+void MultiEncoderXPUFusePass::PrepareQKVWeight(Graph* graph,
+                                               Scope* scope,
+                                               BlockDesc* block,
+                                               Node* q_w,
+                                               Node* k_w,
+                                               Node* v_w,
+                                               Node** qkv_w_int16,
+                                               Node** qkv_w_max) const {
+  phi::DenseTensor q_w_fp32_t;
+  phi::DenseTensor k_w_fp32_t;
+  phi::DenseTensor v_w_fp32_t;
+  Assign(scope->Var(q_w->Name())->Get<phi::DenseTensor>(), &q_w_fp32_t);
+  Assign(scope->Var(k_w->Name())->Get<phi::DenseTensor>(), &k_w_fp32_t);
+  Assign(scope->Var(v_w->Name())->Get<phi::DenseTensor>(), &v_w_fp32_t);
+  CastToFp32(&q_w_fp32_t);
+  CastToFp32(&k_w_fp32_t);
+  CastToFp32(&v_w_fp32_t);
+  Transpose2D(&q_w_fp32_t);
+  Transpose2D(&k_w_fp32_t);
+  Transpose2D(&v_w_fp32_t);
+  phi::DenseTensor qkv_w_int16_t;
+  phi::DenseTensor qkv_w_max_t;
+  qkv_w_int16_t.Resize(
+      DDim({q_w_fp32_t.dims()[0] + k_w_fp32_t.dims()[0] + v_w_fp32_t.dims()[0],
+            q_w_fp32_t.dims()[1]}));
+  qkv_w_int16_t.set_type(q_w_fp32_t.type());
+  auto* cpu_ctx = static_cast<phi::CPUContext*>(
+      platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
+  std::vector<const phi::DenseTensor*> in_tensors{
+      &q_w_fp32_t, &k_w_fp32_t, &v_w_fp32_t};
+  phi::ConcatKernel<float>(*cpu_ctx, in_tensors, 0, &qkv_w_int16_t);
+  PrepareWeight<int16_t>(&qkv_w_int16_t, &qkv_w_max_t, false);
+  size_t qkv_w_int16_hash = HashTensor<int16_t>(qkv_w_int16_t);
+  size_t qkv_w_max_hash = HashTensor<float>(qkv_w_max_t);
+  std::string qkv_w_int16_name = std::to_string(qkv_w_int16_hash);
+  std::string qkv_w_max_name = std::to_string(qkv_w_max_hash);
+  *qkv_w_int16 = FindNodeWithName(graph, qkv_w_int16_name);
+  if (*qkv_w_int16 == nullptr) {
+    // Create qkv_w_int16 node
+    // Update qkv_w_int16 var_desc in block
+    VarDesc qkv_w_int16_desc(qkv_w_int16_name);
+    qkv_w_int16_desc.SetPersistable(true);
+    qkv_w_int16_desc.SetShape(vectorize(qkv_w_int16_t.dims()));
+    qkv_w_int16_desc.SetDataType(
+        framework::TransToProtoVarType(qkv_w_int16_t.dtype()));
+    *qkv_w_int16 = graph->CreateVarNode(&qkv_w_int16_desc);
+    auto* block_qkv_w_int16_desc = block->Var(qkv_w_int16_name);
+    block_qkv_w_int16_desc->SetPersistable(qkv_w_int16_desc.Persistable());
+    block_qkv_w_int16_desc->SetShape(qkv_w_int16_desc.GetShape());
+    block_qkv_w_int16_desc->SetDataType(qkv_w_int16_desc.GetDataType());
+    // Create qkv_w_max node
+    // Update qkv_w_max var_desc in block
+    VarDesc qkv_w_max_desc(qkv_w_max_name);
+    qkv_w_max_desc.SetPersistable(true);
+    qkv_w_max_desc.SetShape(vectorize(qkv_w_max_t.dims()));
+    qkv_w_max_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32);
+    *qkv_w_max = graph->CreateVarNode(&qkv_w_max_desc);
+    auto* block_qkv_w_max_desc = block->Var(qkv_w_max_name);
+    block_qkv_w_max_desc->SetPersistable(qkv_w_max_desc.Persistable());
+    block_qkv_w_max_desc->SetShape(qkv_w_max_desc.GetShape());
+    block_qkv_w_max_desc->SetDataType(qkv_w_max_desc.GetDataType());
+    // Find qkv_w_int16/qkv_w_max variable in scope
+    auto* qkv_w_int16_var = scope->FindVar(qkv_w_int16_name);
+    if (qkv_w_int16_var == nullptr) {
+      // Create qkv_w_int16/qkv_w_max variable/tensor
+      Assign(qkv_w_int16_t,
+             scope->Var(qkv_w_int16_name)->GetMutable<phi::DenseTensor>());
+      Assign(qkv_w_max_t,
+             scope->Var(qkv_w_max_name)->GetMutable<phi::DenseTensor>());
+    } else {
+      // Share the same variable
+      PADDLE_ENFORCE_NOT_NULL(
+          scope->FindVar(qkv_w_max_name),
+          platform::errors::Fatal(
+              "qkv_w_max(%s) variable should not be nullptr if qkv_w_int16(%s) "
+              "variable exists.",
+              qkv_w_max_name,
+              qkv_w_int16_name));
+    }
+  } else {
+    *qkv_w_max = FindNodeWithName(graph, qkv_w_max_name);
+    PADDLE_ENFORCE_NOT_NULL(
+        *qkv_w_max,
+        platform::errors::Fatal(
+            "qkv_w_max(%s) variable should not be nullptr if qkv_w_int16(%s) "
+            "variable exists.",
+            qkv_w_max_name,
+            qkv_w_int16_name));
+  }
+}
-void MultiEncoderXPUFusePass::ConcatQKVBias(const phi::DenseTensor& q_bias,
-                                            const phi::DenseTensor& k_bias,
-                                            const phi::DenseTensor& v_bias,
-                                            phi::DenseTensor* qkv_bias) const {
-  int q_bias_size = q_bias.numel();
-  qkv_bias->Resize(DDim({q_bias_size * 3}));
-  qkv_bias->set_type(q_bias.type());
-  auto* dev_ctx = static_cast<phi::CPUContext*>(
-      platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
-  auto* qkv_bias_data = dev_ctx->Alloc<float>(qkv_bias);
-  memcpy(qkv_bias_data, q_bias.data(), q_bias_size * sizeof(float));
-  qkv_bias_data += q_bias_size;
-  memcpy(qkv_bias_data, k_bias.data(), q_bias_size * sizeof(float));
-  qkv_bias_data += q_bias_size;
-  memcpy(qkv_bias_data, v_bias.data(), q_bias_size * sizeof(float));
-}
+void MultiEncoderXPUFusePass::PrepareQKVBias(Graph* graph,
+                                             Scope* scope,
+                                             BlockDesc* block,
+                                             Node* q_bias,
+                                             Node* k_bias,
+                                             Node* v_bias,
+                                             Node** qkv_bias) const {
+  auto* q_bias_tensor =
+      scope->Var(q_bias->Name())->GetMutable<phi::DenseTensor>();
+  auto* k_bias_tensor =
+      scope->Var(k_bias->Name())->GetMutable<phi::DenseTensor>();
+  auto* v_bias_tensor =
+      scope->Var(v_bias->Name())->GetMutable<phi::DenseTensor>();
+  phi::DenseTensor q_bias_fp32_tensor;
+  phi::DenseTensor k_bias_fp32_tensor;
+  phi::DenseTensor v_bias_fp32_tensor;
+  CastToFp32(q_bias_tensor, &q_bias_fp32_tensor);
+  CastToFp32(k_bias_tensor, &k_bias_fp32_tensor);
+  CastToFp32(v_bias_tensor, &v_bias_fp32_tensor);
+  phi::DenseTensor qkv_bias_tensor;
+  int q_bias_fp32_size = q_bias_fp32_tensor.numel();
+  qkv_bias_tensor.Resize(DDim({q_bias_fp32_size * 3}));
+  qkv_bias_tensor.set_type(phi::DataType::FLOAT32);
+  auto* cpu_ctx = static_cast<phi::CPUContext*>(
+      platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
+  auto* qkv_bias_data = cpu_ctx->Alloc<float>(&qkv_bias_tensor);
+  memcpy(qkv_bias_data,
+         q_bias_fp32_tensor.data(),
+         q_bias_fp32_size * sizeof(float));
+  qkv_bias_data += q_bias_fp32_size;
+  memcpy(qkv_bias_data,
+         k_bias_fp32_tensor.data(),
+         q_bias_fp32_size * sizeof(float));
+  qkv_bias_data += q_bias_fp32_size;
+  memcpy(qkv_bias_data,
+         v_bias_fp32_tensor.data(),
+         q_bias_fp32_size * sizeof(float));
+  size_t qkv_bias_hash = HashTensor<float>(qkv_bias_tensor);
+  std::string qkv_bias_name = std::to_string(qkv_bias_hash);
+  *qkv_bias = FindNodeWithName(graph, qkv_bias_name);
+  if (*qkv_bias == nullptr) {
+    // Create qkv_bias node
+    // Update qkv_bias var_desc in block
+    VarDesc qkv_bias_desc(qkv_bias_name);
+    qkv_bias_desc.SetPersistable(true);
+    qkv_bias_desc.SetShape(vectorize(qkv_bias_tensor.dims()));
+    qkv_bias_desc.SetDataType(
+        framework::TransToProtoVarType(qkv_bias_tensor.dtype()));
+    *qkv_bias = graph->CreateVarNode(&qkv_bias_desc);
+    auto* block_qkv_bias_desc = block->Var(qkv_bias_name);
+    block_qkv_bias_desc->SetPersistable(qkv_bias_desc.Persistable());
+    block_qkv_bias_desc->SetShape(qkv_bias_desc.GetShape());
+    block_qkv_bias_desc->SetDataType(qkv_bias_desc.GetDataType());
+    Assign(qkv_bias_tensor,
+           scope->Var(qkv_bias_name)->GetMutable<phi::DenseTensor>());
+  }
+}
 int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
@@ -856,109 +969,67 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
       scope->FindVar(q_matmul_w->Name())->Get<phi::DenseTensor>().dtype() ==
           phi::DataType::FLOAT16;
-  // Prepare q,k,v weight
-  std::string q_w_name = q_matmul_w->Name();
-  std::string k_w_name = k_matmul_w->Name();
-  std::string v_w_name = v_matmul_w->Name();
-  std::string qkv_w_name = q_w_name + "_" + k_w_name + "_" + v_w_name;
-  VarDesc qkv_w_desc(qkv_w_name);
-  qkv_w_desc.SetPersistable(true);
-  auto* qkv_w = graph->CreateVarNode(&qkv_w_desc);
-  auto* qkv_w_var = block->Var(qkv_w_name);
-  qkv_w_var->SetPersistable(true);
-  std::string qkv_w_max_name = qkv_w_name + "_max";
-  VarDesc qkv_w_max_desc(qkv_w_max_name);
-  qkv_w_max_desc.SetPersistable(true);
-  auto* qkv_w_max = graph->CreateVarNode(&qkv_w_max_desc);
-  auto* qkv_w_max_var = block->Var(qkv_w_max_name);
-  qkv_w_max_var->SetPersistable(true);
-  PrepareQKVWeight(
-      scope->FindVar(q_w_name)->Get<phi::DenseTensor>(),
-      scope->FindVar(k_w_name)->Get<phi::DenseTensor>(),
-      scope->FindVar(v_w_name)->Get<phi::DenseTensor>(),
-      scope->Var(qkv_w_name)->GetMutable<phi::DenseTensor>(),
-      scope->Var(qkv_w_max_name)->GetMutable<phi::DenseTensor>());
+  Node* qkv_w_int16 = nullptr;
+  Node* qkv_w_max = nullptr;
+  PrepareQKVWeight(graph,
+                   scope,
+                   block,
+                   q_matmul_w,
+                   k_matmul_w,
+                   v_matmul_w,
+                   &qkv_w_int16,
+                   &qkv_w_max);
-  // Prepare qkv_matmul_1_w, qkv_matmul_2_w, qkv_matmul_3_w
 #define PREPARE_QKV_MATMUL_W(idx_)                                            \
-  std::string qkv_matmul_##idx_##_w_name = qkv_matmul_##idx_##_w->Name();     \
-  std::string qkv_matmul_##idx_##_w_max_name =                                \
-      qkv_matmul_##idx_##_w_name + "_max";                                    \
-  VarDesc qkv_matmul_##idx_##_w_max_desc(qkv_matmul_##idx_##_w_max_name);     \
-  qkv_matmul_##idx_##_w_max_desc.SetPersistable(true);                        \
-  auto qkv_matmul_##idx_##_w_max =                                            \
-      graph->CreateVarNode(&qkv_matmul_##idx_##_w_max_desc);                  \
-  auto qkv_matmul_##idx_##_w_max_var =                                        \
-      block->Var(qkv_matmul_##idx_##_w_max_name);                             \
-  qkv_matmul_##idx_##_w_max_var->SetPersistable(true);                        \
-  auto qkv_matmul_##idx_##_w_max_tensor =                                     \
-      scope->Var(qkv_matmul_##idx_##_w_max_name)                              \
-          ->GetMutable<phi::DenseTensor>();                                   \
-  auto qkv_matmul_##idx_##_w_tensor =                                         \
-      scope->Var(qkv_matmul_##idx_##_w_name)->GetMutable<phi::DenseTensor>(); \
-  QuantWeight<int16_t>(                                                       \
-      qkv_matmul_##idx_##_w_tensor, qkv_matmul_##idx_##_w_max_tensor, true);
+  Node* qkv_matmul_##idx_##_w_int16 = nullptr;                                \
+  Node* qkv_matmul_##idx_##_w_max = nullptr;                                  \
+  PrepareWeight<int16_t>(graph,                                               \
+                         scope,                                               \
+                         block,                                               \
+                         qkv_matmul_##idx_##_w,                               \
+                         &qkv_matmul_##idx_##_w_int16,                        \
+                         &qkv_matmul_##idx_##_w_max,                          \
+                         true);
   PREPARE_QKV_MATMUL_W(1);
   PREPARE_QKV_MATMUL_W(2);
   PREPARE_QKV_MATMUL_W(3);
 #undef PREPARE_QKV_MATMUL_W
-  // Concat q_add_bias, k_add_bias, v_add_bias
-  std::string q_add_bias_name = q_add_bias->Name();
-  std::string k_add_bias_name = k_add_bias->Name();
-  std::string v_add_bias_name = v_add_bias->Name();
-  std::string qkv_add_bias_name =
-      q_add_bias_name + "_" + k_add_bias_name + "_" + v_add_bias_name;
-  VarDesc qkv_add_bias_desc(qkv_add_bias_name);
-  qkv_add_bias_desc.SetPersistable(true);
-  auto* qkv_add_bias = graph->CreateVarNode(&qkv_add_bias_desc);
-  auto* qkv_add_bias_var = block->Var(qkv_add_bias_name);
-  qkv_add_bias_var->SetPersistable(true);
-  auto* q_add_bias_tensor =
-      scope->FindVar(q_add_bias_name)->GetMutable<phi::DenseTensor>();
-  auto* k_add_bias_tensor =
-      scope->FindVar(k_add_bias_name)->GetMutable<phi::DenseTensor>();
-  auto* v_add_bias_tensor =
-      scope->FindVar(v_add_bias_name)->GetMutable<phi::DenseTensor>();
-  CastToFp32(q_add_bias_tensor);
-  CastToFp32(k_add_bias_tensor);
-  CastToFp32(v_add_bias_tensor);
-  ConcatQKVBias(
-      *q_add_bias_tensor,
-      *k_add_bias_tensor,
-      *v_add_bias_tensor,
-      scope->Var(qkv_add_bias_name)->GetMutable<phi::DenseTensor>());
+  Node* qkv_add_bias_fp32 = nullptr;
+  PrepareQKVBias(graph,
+                 scope,
+                 block,
+                 q_add_bias,
+                 k_add_bias,
+                 v_add_bias,
+                 &qkv_add_bias_fp32);
-  // Prepare qkv_add_0_bias, qkv_add_2_bias, qkv_add_3_bias
-  auto qkv_add_0_bias_name = qkv_add_0_bias->Name();
-  CastToFp32(
-      scope->FindVar(qkv_add_0_bias_name)->GetMutable<phi::DenseTensor>());
-  auto qkv_add_2_bias_name = qkv_add_2_bias->Name();
-  CastToFp32(
-      scope->FindVar(qkv_add_2_bias_name)->GetMutable<phi::DenseTensor>());
-  auto qkv_add_3_bias_name = qkv_add_3_bias->Name();
-  CastToFp32(
-      scope->FindVar(qkv_add_3_bias_name)->GetMutable<phi::DenseTensor>());
+  Node* qkv_add_0_bias_fp32 = nullptr;
+  Node* qkv_add_2_bias_fp32 = nullptr;
+  Node* qkv_add_3_bias_fp32 = nullptr;
+  PrepareBias(graph, scope, block, qkv_add_0_bias, &qkv_add_0_bias_fp32);
+  PrepareBias(graph, scope, block, qkv_add_2_bias, &qkv_add_2_bias_fp32);
+  PrepareBias(graph, scope, block, qkv_add_3_bias, &qkv_add_3_bias_fp32);
   // Generate single_encoder_xpu op
   framework::OpDesc op_desc(block);
   op_desc.SetType("single_encoder_xpu");
   op_desc.SetInput("x", {ln_0_x->Name()});
   op_desc.SetInput("fc_weight",
-                   {qkv_w_name,
-                    qkv_matmul_1_w_name,
-                    qkv_matmul_2_w_name,
-                    qkv_matmul_3_w_name});
+                   {qkv_w_int16->Name(),
+                    qkv_matmul_1_w_int16->Name(),
+                    qkv_matmul_2_w_int16->Name(),
+                    qkv_matmul_3_w_int16->Name()});
   op_desc.SetInput("fc_weight_max",
-                   {qkv_w_max_name,
-                    qkv_matmul_1_w_max_name,
-                    qkv_matmul_2_w_max_name,
-                    qkv_matmul_3_w_max_name});
+                   {qkv_w_max->Name(),
+                    qkv_matmul_1_w_max->Name(),
+                    qkv_matmul_2_w_max->Name(),
+                    qkv_matmul_3_w_max->Name()});
   op_desc.SetInput("fc_bias",
-                   {qkv_add_bias_name,
-                    qkv_add_0_bias_name,
-                    qkv_add_2_bias_name,
-                    qkv_add_3_bias_name});
+                   {qkv_add_bias_fp32->Name(),
+                    qkv_add_0_bias_fp32->Name(),
+                    qkv_add_2_bias_fp32->Name(),
+                    qkv_add_3_bias_fp32->Name()});
   if (norm_before) {
     op_desc.SetInput("ln_scale", {ln_0_scale->Name(), ln_1_scale->Name()});
     op_desc.SetInput("ln_bias", {ln_0_bias->Name(), ln_1_bias->Name()});
@@ -990,30 +1061,30 @@ int MultiEncoderXPUFusePass::ApplySingleEncoderXPUFuse(
   }
   auto* single_encoder_xpu = graph->CreateOpNode(&op_desc);
   // Link nodes
-  SAFE_IR_NODE_LINK_TO(ln_0_x, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_w, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_w_max, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_matmul_1_w, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_matmul_1_w_max, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_matmul_2_w, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_matmul_2_w_max, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_matmul_3_w, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_matmul_3_w_max, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_add_bias, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_add_0_bias, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_add_2_bias, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(qkv_add_3_bias, single_encoder_xpu);
+  IR_NODE_LINK_TO(ln_0_x, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_w_int16, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_w_max, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_matmul_1_w_int16, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_matmul_1_w_max, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_matmul_2_w_int16, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_matmul_2_w_max, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_matmul_3_w_int16, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_matmul_3_w_max, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_add_bias_fp32, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_add_0_bias_fp32, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_add_2_bias_fp32, single_encoder_xpu);
+  IR_NODE_LINK_TO(qkv_add_3_bias_fp32, single_encoder_xpu);
   SAFE_IR_NODE_LINK_TO(ln_0_scale, single_encoder_xpu);
   SAFE_IR_NODE_LINK_TO(ln_0_bias, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(ln_1_scale, single_encoder_xpu);
-  SAFE_IR_NODE_LINK_TO(ln_1_bias, single_encoder_xpu);
+  IR_NODE_LINK_TO(ln_1_scale, single_encoder_xpu);
+  IR_NODE_LINK_TO(ln_1_bias, single_encoder_xpu);
   SAFE_IR_NODE_LINK_TO(ln_2_scale, single_encoder_xpu);
   SAFE_IR_NODE_LINK_TO(ln_2_bias, single_encoder_xpu);
   SAFE_IR_NODE_LINK_TO(qk_add_mask, single_encoder_xpu);
   if (norm_before) {
-    SAFE_IR_NODE_LINK_TO(single_encoder_xpu, qkv_add_4_out);
+    IR_NODE_LINK_TO(single_encoder_xpu, qkv_add_4_out);
   } else {
-    SAFE_IR_NODE_LINK_TO(single_encoder_xpu, ln_2_out);
+    IR_NODE_LINK_TO(single_encoder_xpu, ln_2_out);
   }
   // Delete nodes
...
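The same refactor applied to the encoder fusion: QKV weights are cast, transposed, concatenated and quantized into fresh tensors, then content-hashed and either created or looked up as shared nodes; biases go through PrepareQKVBias/PrepareBias analogously. The SAFE_IR_NODE_LINK_TO guards were downgraded to plain IR_NODE_LINK_TO exactly where the new Prepare* helpers guarantee a non-null node.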
@@ -20,6 +20,18 @@ namespace paddle {
 namespace framework {
 namespace ir {
static void HashCombine(std::size_t* seed) {}
// combine hash value
// https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x
template <typename T, typename... Rest>
static void HashCombine(std::size_t* seed, const T& v, Rest... rest) {
std::hash<T> hasher;
*seed ^= hasher(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
*seed *= 0x00000100000001B3;
HashCombine(seed, rest...);
}
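The combiner mixes a boost-style hash_combine step (the 0x9e3779b9 golden-ratio constant) with a multiply by the 64-bit FNV prime 0x00000100000001B3; the zero-argument overload terminates the recursion. An illustrative fold, in the spirit of HashTensor below:

    // Illustrative only: fold dtype string, rank and dims into one seed.
    std::size_t seed = 0;
    HashCombine(&seed, std::string("float32"), 2, int64_t(128), int64_t(128));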
 int ConvertActivationType(std::string act_type) {
   if (act_type == "") {
     return static_cast<int>(xpu::Activation_t::LINEAR);
@@ -50,6 +62,161 @@ int ConvertActivationType(std::string act_type) {
   return -1;
 }
Node* FindNodeWithName(Graph* graph, std::string name) {
for (auto* node : graph->Nodes()) {
if (node->IsVar() && node->Var()->Name() == name) {
return node;
}
}
return nullptr;
}
template <typename T>
std::string IntTypeToString() {
  LOG(FATAL) << "Unsupported type.";
return "";
}
template <>
std::string IntTypeToString<int16_t>() {
return "int16";
}
template <typename T>
size_t HashTensor(const phi::DenseTensor& in) {
size_t ret = 0;
auto in_dims = in.dims();
HashCombine(&ret,
phi::DataTypeToString(in.dtype()),
phi::DataLayoutToString(in.layout()),
in_dims.size());
for (int i = 0; i < in_dims.size(); i++) {
HashCombine(&ret, in_dims[i]);
}
auto* data = in.data<T>();
int64_t size = in.numel();
for (int64_t i = 0; i < size; i++) {
HashCombine(&ret, data[i]);
}
return ret;
}
template size_t HashTensor<int16_t>(const phi::DenseTensor& in);
template size_t HashTensor<float>(const phi::DenseTensor& in);
template <typename T>
void PrepareWeight(Graph* graph,
Scope* scope,
BlockDesc* block,
Node* src,
Node** dst,
Node** dst_max,
bool transpose) {
auto src_name = src->Name();
auto* src_tensor = scope->Var(src_name)->GetMutable<phi::DenseTensor>();
phi::DenseTensor dst_tensor;
Assign(*src_tensor, &dst_tensor);
phi::DenseTensor dst_max_tensor;
PrepareWeight<T>(&dst_tensor, &dst_max_tensor, transpose);
size_t dst_hash = HashTensor<T>(dst_tensor);
size_t dst_max_hash = HashTensor<float>(dst_max_tensor);
std::string dst_name = src_name + "_" + std::to_string(dst_hash);
std::string dst_max_name = src_name + "_max_" + std::to_string(dst_max_hash);
*dst = FindNodeWithName(graph, dst_name);
if (*dst == nullptr) {
// Create dst node
// Update dst var_desc in block
VarDesc dst_desc(dst_name);
dst_desc.SetPersistable(true);
dst_desc.SetShape(vectorize(dst_tensor.dims()));
dst_desc.SetDataType(framework::TransToProtoVarType(dst_tensor.dtype()));
*dst = graph->CreateVarNode(&dst_desc);
auto* block_dst_desc = block->Var(dst_name);
block_dst_desc->SetPersistable(dst_desc.Persistable());
block_dst_desc->SetShape(dst_desc.GetShape());
block_dst_desc->SetDataType(dst_desc.GetDataType());
// Create dst_max node
// Update dst_max var_desc in block
VarDesc dst_max_desc(dst_max_name);
dst_max_desc.SetPersistable(true);
dst_max_desc.SetShape(vectorize(dst_max_tensor.dims()));
dst_max_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32);
*dst_max = graph->CreateVarNode(&dst_max_desc);
auto* block_dst_max_desc = block->Var(dst_max_name);
block_dst_max_desc->SetPersistable(dst_max_desc.Persistable());
block_dst_max_desc->SetShape(dst_max_desc.GetShape());
block_dst_max_desc->SetDataType(dst_max_desc.GetDataType());
// Find dst/dst_max variable in scope
auto* dst_var = scope->FindVar(dst_name);
if (dst_var == nullptr) {
// Create dst/dst_max variable/tensor
Assign(dst_tensor, scope->Var(dst_name)->GetMutable<phi::DenseTensor>());
Assign(dst_max_tensor,
scope->Var(dst_max_name)->GetMutable<phi::DenseTensor>());
} else {
// Share the same variable
      PADDLE_ENFORCE_NOT_NULL(
          scope->FindVar(dst_max_name),
          platform::errors::Fatal(
              "dst_max(%s) variable should not be nullptr if dst(%s) "
              "variable exists. (src_name is %s)",
              dst_max_name,
              dst_name,
              src_name));
}
} else {
*dst_max = FindNodeWithName(graph, dst_max_name);
    PADDLE_ENFORCE_NOT_NULL(
        *dst_max,
        platform::errors::Fatal(
            "dst_max(%s) variable should not be nullptr if dst(%s) "
            "variable exists. (src_name is %s)",
            dst_max_name,
            dst_name,
            src_name));
}
}
template void PrepareWeight<int16_t>(Graph* graph,
Scope* scope,
BlockDesc* block,
Node* src,
Node** dst,
Node** dst_max,
bool transpose);
void PrepareBias(
Graph* graph, Scope* scope, BlockDesc* block, Node* src, Node** dst) {
auto src_name = src->Name();
auto* src_tensor = scope->Var(src_name)->GetMutable<phi::DenseTensor>();
  if (src_tensor->dtype() == phi::DataType::FLOAT32) {
    *dst = src;
    return;
  }
phi::DenseTensor dst_tensor;
CastToFp32(src_tensor, &dst_tensor);
size_t dst_hash = HashTensor<float>(dst_tensor);
std::string dst_name = src_name + "_" + std::to_string(dst_hash);
*dst = FindNodeWithName(graph, dst_name);
if (*dst == nullptr) {
// Create dst node
// Update dst var_desc in block
VarDesc dst_desc(dst_name);
dst_desc.SetPersistable(true);
dst_desc.SetShape(vectorize(dst_tensor.dims()));
dst_desc.SetDataType(framework::TransToProtoVarType(dst_tensor.dtype()));
*dst = graph->CreateVarNode(&dst_desc);
auto* block_dst_desc = block->Var(dst_name);
block_dst_desc->SetPersistable(dst_desc.Persistable());
block_dst_desc->SetShape(dst_desc.GetShape());
block_dst_desc->SetDataType(dst_desc.GetDataType());
Assign(dst_tensor, scope->Var(dst_name)->GetMutable<phi::DenseTensor>());
}
}
 } // namespace ir
 } // namespace framework
 } // namespace paddle
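This helper is the core of shared-weight support: the converted tensor's content hash becomes part of the new node name, so every op, in any subgraph, that starts from the same source weight resolves to one var node and one scope variable. The PADDLE_ENFORCE_NOT_NULL branches maintain the invariant that a converted weight node and its *_max companion always exist together.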
@@ -14,6 +14,10 @@
 #pragma once
 #include <string>
+#include "paddle/fluid/framework/convert_utils.h"
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
+#include "paddle/fluid/framework/scope.h"
 namespace paddle {
 namespace framework {
@@ -45,6 +49,23 @@ namespace ir {
 int ConvertActivationType(std::string act_type);
Node* FindNodeWithName(Graph* graph, std::string name);
template <typename T>
size_t HashTensor(const phi::DenseTensor& in);
template <typename T>
void PrepareWeight(Graph* graph,
Scope* scope,
BlockDesc* block,
Node* src,
Node** dst,
Node** dst_max,
bool transpose);
void PrepareBias(
Graph* graph, Scope* scope, BlockDesc* block, Node* src, Node** dst);
 } // namespace ir
 } // namespace framework
 } // namespace paddle
@@ -207,7 +207,7 @@ void QuantFP32ToIntX<int16_t>(const float* src_ptr,
 }
 template <typename T>
-void QuantWeight(phi::DenseTensor* weight,
-                 phi::DenseTensor* weight_max,
-                 bool transpose) {
+void PrepareWeight(phi::DenseTensor* weight,
+                   phi::DenseTensor* weight_max,
+                   bool transpose) {
   // Convert fp16 to fp32
@@ -249,7 +249,7 @@ void QuantWeight(phi::DenseTensor* weight,
   QuantFP32ToIntX(weight_data, cpu_ctx->Alloc<T>(weight), max_val, size);
 }
-template void QuantWeight<int16_t>(phi::DenseTensor* weight,
-                                   phi::DenseTensor* weight_max,
-                                   bool transpose);
+template void PrepareWeight<int16_t>(phi::DenseTensor* weight,
+                                     phi::DenseTensor* weight_max,
+                                     bool transpose);
...
@@ -29,7 +29,7 @@ void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
 // 2. Weight data is updated in place.
 // 3. Generate weight max tensor
 template <typename T>
-void QuantWeight(phi::DenseTensor* weight,
-                 phi::DenseTensor* weight_max,
-                 bool transpose);
+void PrepareWeight(phi::DenseTensor* weight,
+                   phi::DenseTensor* weight_max,
+                   bool transpose);
...
@@ -524,6 +524,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "fc_xpu_fuse_pass",
       "link_xpu_op_max_pass",
       "delete_op_device_pass",
+      "delete_isolated_node_pass",
   });
   use_xpu_ = true;
 }
...
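Note the ordering: delete_isolated_node_pass is appended after the fuse passes on purpose, because only once fc_xpu_fuse_pass (now subgraph-aware) has redirected consumers to the converted weight nodes do the original fp32 weights become isolated and safe to drop.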
@@ -15,6 +15,7 @@
 import unittest
 import hypothesis.strategies as st
+import numpy as np
 from auto_scan_test import PassAutoScanTest
 from program_config import OpConfig, ProgramConfig, TensorConfig
@@ -33,7 +34,7 @@ class TestFcXPUFusePass(PassAutoScanTest):
         )
         matmul0_y_shape = draw(
             st.lists(
-                st.integers(min_value=1, max_value=8), min_size=2, max_size=2
+                st.integers(min_value=2, max_value=8), min_size=2, max_size=2
             )
         )
         matmul0_y_shape[0] = matmul0_x_shape[-1]
@@ -42,7 +43,7 @@ class TestFcXPUFusePass(PassAutoScanTest):
         # 3. matmul1
         matmul1_y_shape = draw(
             st.lists(
-                st.integers(min_value=1, max_value=8), min_size=2, max_size=2
+                st.integers(min_value=2, max_value=8), min_size=2, max_size=2
             )
         )
         matmul1_y_shape[0] = matmul0_y_shape[-1]
@@ -101,4 +102,5 @@ class TestFcXPUFusePass(PassAutoScanTest):
 if __name__ == "__main__":
+    np.random.seed(200)
     unittest.main()
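Seeding NumPy pins the randomly generated inputs and weights across runs; raising min_value from 1 to 2 presumably keeps the drawn matmul shapes from degenerating to size-1 dimensions.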