未验证 提交 39a9abaa 编写于 作者: Z zhupengyang 提交者: GitHub

[XPU] support shared weight; delete isolated node (#51108)

上级 50ad760c
...@@ -220,7 +220,7 @@ if(WITH_XPU) ...@@ -220,7 +220,7 @@ if(WITH_XPU)
cc_library( cc_library(
xpu_pass_utils xpu_pass_utils
SRCS xpu/pass_utils.cc SRCS xpu/pass_utils.cc
DEPS pass) DEPS pass xpu_quant_utils)
set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils) set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils)
pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS}) ${XPU_PASS_DEPS})
...@@ -232,6 +232,8 @@ if(WITH_XPU) ...@@ -232,6 +232,8 @@ if(WITH_XPU)
pass_library(generate_sequence_xpu_fuse_pass inference DIR xpu DEPS pass_library(generate_sequence_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS}) ${XPU_PASS_DEPS})
pass_library(link_xpu_op_max_pass inference DIR xpu DEPS ${XPU_PASS_DEPS}) pass_library(link_xpu_op_max_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(delete_isolated_node_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
endif() endif()
cc_library( cc_library(
...@@ -484,3 +486,10 @@ if(WITH_MKLDNN) ...@@ -484,3 +486,10 @@ if(WITH_MKLDNN)
SRCS mkldnn/cpu_bfloat16_pass_tester.cc SRCS mkldnn/cpu_bfloat16_pass_tester.cc
DEPS cpu_bfloat16_pass) DEPS cpu_bfloat16_pass)
endif() endif()
if(WITH_XPU)
cc_test(
test_delete_isolated_node_pass
SRCS xpu/delete_isolated_node_pass_test.cc
DEPS delete_isolated_node_pass)
endif()
...@@ -39,14 +39,14 @@ class DeleteOpDevicePass : public Pass { ...@@ -39,14 +39,14 @@ class DeleteOpDevicePass : public Pass {
void DeleteOpDevicePass::ApplyImpl(ir::Graph* graph) const { void DeleteOpDevicePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null.")); graph, platform::errors::PreconditionNotMet("graph should not be null."));
int found_subgraph_count = 0; int delete_counts = 0;
for (auto* node : graph->Nodes()) { for (auto* node : graph->Nodes()) {
if (!node->IsOp() || !node->Op()->HasAttr("op_device")) continue; if (!node->IsOp() || !node->Op()->HasAttr("op_device")) continue;
node->Op()->RemoveAttr("op_device"); node->Op()->RemoveAttr("op_device");
found_subgraph_count++; delete_counts++;
} }
if (found_subgraph_count > 0) { if (delete_counts > 0) {
LOG(INFO) << "--- detected " << found_subgraph_count << " subgraphs"; LOG(INFO) << "--- delete " << delete_counts << " op_device attr";
} }
} }
......
...@@ -13,8 +13,7 @@ ...@@ -13,8 +13,7 @@
// limitations under the License. // limitations under the License.
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/delete_dropout_op_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h" #include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle { namespace paddle {
......
...@@ -49,6 +49,7 @@ static const std::vector<std::string> support_subgraph_passes = { ...@@ -49,6 +49,7 @@ static const std::vector<std::string> support_subgraph_passes = {
"fuse_multi_transformer_layer_pass", "fuse_multi_transformer_layer_pass",
"delete_quant_dequant_linear_op_pass", "delete_quant_dequant_linear_op_pass",
"delete_weight_dequant_linear_op_pass", "delete_weight_dequant_linear_op_pass",
"fc_xpu_fuse_pass",
"delete_op_device_pass"}; "delete_op_device_pass"};
Graph *Pass::Apply(Graph *graph) const { Graph *Pass::Apply(Graph *graph) const {
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/scope.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
class DeleteIsolatedNodePass : public Pass {
protected:
void ApplyImpl(Graph* graph) const override;
private:
void CollectReservedPersistableNodeNames(
Graph* graph,
std::unordered_set<std::string>* reserved_persistable_node_names) const;
int RemoveIsolatedNodes(
Graph* graph,
const std::unordered_set<std::string>& reserved_persistable_node_names,
std::unordered_set<std::string>* delete_node_names) const;
int UpdateControlFlowOp(
Graph* graph,
const std::map<int, Graph*>& block_id_graph_map,
const std::unordered_set<std::string>& delete_node_names) const;
const std::map<std::string, std::string> control_flow_op_input_map_{
{"while", "X"},
{"conditional_block", "Input"},
};
};
void DeleteIsolatedNodePass::ApplyImpl(Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
PADDLE_ENFORCE(graph->IsMainGraph(),
platform::errors::PreconditionNotMet(
"Pass(apply in main graph) will delete isolated nodes in "
"all subgraphs. Do not apply pass in subgraph."));
std::unordered_set<std::string> reserved_persistable_node_names;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
CollectReservedPersistableNodeNames(graph->GetSubGraph(i),
&reserved_persistable_node_names);
}
int delete_counts = 0;
std::unordered_set<std::string> delete_node_names;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
delete_counts += RemoveIsolatedNodes(graph->GetSubGraph(i),
reserved_persistable_node_names,
&delete_node_names);
}
if (delete_counts > 0) {
LOG(INFO) << "--- delete " << delete_counts << " isolated nodes";
}
std::map<int, Graph*> block_id_graph_map;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
auto* sub_graph = graph->GetSubGraph(i);
for (auto* node : sub_graph->Nodes()) {
if (node->IsVar()) {
block_id_graph_map[node->GetVarNodeBlockId()] = sub_graph;
break;
}
}
}
int update_counts = 0;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
update_counts += UpdateControlFlowOp(
graph->GetSubGraph(i), block_id_graph_map, delete_node_names);
}
if (update_counts > 0) {
LOG(INFO) << "--- update " << update_counts << " control flow ops";
}
}
void DeleteIsolatedNodePass::CollectReservedPersistableNodeNames(
Graph* graph,
std::unordered_set<std::string>* reserved_persistable_node_names) const {
for (auto* node : graph->Nodes()) {
if (!node->IsVar() || !node->Var()->Persistable()) continue;
for (auto* out_node : node->outputs) {
auto op_type = out_node->Op()->Type();
if (control_flow_op_input_map_.count(op_type) == 0) {
reserved_persistable_node_names->insert(node->Var()->Name());
break;
}
}
}
}
int DeleteIsolatedNodePass::RemoveIsolatedNodes(
Graph* graph,
const std::unordered_set<std::string>& reserved_persistable_node_names,
std::unordered_set<std::string>* delete_node_names) const {
BlockDesc* block = nullptr;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
block = node->Op()->Block();
}
}
Scope& scope = graph->Get<framework::Scope>("__param_scope__");
// If graph has nodes to delete:
// 1. Clear var_desc in block
// 2. Clear tensor in variable
// 3. Clear variable in scope
int delete_node_counts = 0;
std::unordered_set<const Node*> delete_nodes;
const std::unordered_set<ir::Node*> nodes = graph->Nodes();
for (auto* node : nodes) {
if (!node->IsVar() || !node->Var()->Persistable()) continue;
auto name = node->Var()->Name();
if (reserved_persistable_node_names.count(name) > 0) continue;
delete_nodes.insert(node);
delete_node_names->insert(node->Name());
block->RemoveVar(name);
auto* var = scope.FindVar(name);
if (var != nullptr) {
var->Clear();
scope.EraseVars({name});
}
delete_node_counts++;
}
GraphSafeRemoveNodes(graph, delete_nodes);
return delete_node_counts;
}
int DeleteIsolatedNodePass::UpdateControlFlowOp(
Graph* graph,
const std::map<int, Graph*>& block_id_graph_map,
const std::unordered_set<std::string>& delete_node_names) const {
int update_counts = 0;
for (auto* node : graph->Nodes()) {
if (!node->IsOp()) continue;
auto op_type = node->Op()->Type();
if (control_flow_op_input_map_.count(op_type) == 0) continue;
auto in_arg_name = control_flow_op_input_map_.at(op_type);
auto in_name = node->Op()->Input(in_arg_name);
std::unordered_set<std::string> in_names_set(in_name.begin(),
in_name.end());
for (auto delete_node_name : delete_node_names) {
if (in_names_set.count(delete_node_name) > 0) {
in_names_set.erase(delete_node_name);
}
}
auto* sub_block = PADDLE_GET_CONST(framework::BlockDesc*,
node->Op()->GetAttr("sub_block"));
auto* sub_graph = block_id_graph_map.at(sub_block->ID());
std::unordered_set<std::string> sub_persistable_node_names;
CollectReservedPersistableNodeNames(sub_graph, &sub_persistable_node_names);
for (auto sub_name : sub_persistable_node_names) {
if (in_names_set.count(sub_name) > 0) continue;
auto* in_node = FindNodeWithName(graph, sub_name);
if (in_node == nullptr) continue;
in_names_set.insert(sub_name);
IR_NODE_LINK_TO(in_node, node);
}
std::vector<std::string> new_in_names(in_names_set.begin(),
in_names_set.end());
node->Op()->SetInput(in_arg_name, new_in_names);
update_counts++;
}
return update_counts;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_isolated_node_pass,
paddle::framework::ir::DeleteIsolatedNodePass);
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
VarDesc* Data(paddle::framework::BlockDesc* block,
std::string name,
std::vector<int64_t> shape = {},
bool is_persistable = false,
proto::VarType::Type data_type = proto::VarType::FP32) {
auto* var = block->Var(name);
var->SetType(proto::VarType::LOD_TENSOR);
var->SetDataType(data_type);
var->SetShape(shape);
var->SetPersistable(is_persistable);
return var;
}
void AddVarToScope(Scope* param_scope,
const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<phi::DenseTensor>();
tensor->Resize(dims);
auto* cpu_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
auto* data = cpu_ctx->Alloc<float>(tensor);
int64_t numel = tensor->numel();
for (int64_t i = 0; i < numel; ++i) {
data[i] = 1;
}
}
Scope* CreateParamScope() {
auto param_scope = new Scope();
AddVarToScope(param_scope, "matmul0_w", {128, 128});
return param_scope;
}
int WeightNodeNum(ir::Graph* graph) {
int num = 0;
for (auto node : graph->Nodes()) {
if (node->IsVar() && node->Var()->Persistable()) {
num++;
}
}
return num;
}
int WeightTensorNum(Scope* scope) {
int num = 0;
auto vars = scope->LocalVars();
for (auto* var : vars) {
if (var->Get<phi::DenseTensor>().numel() > 0) {
num++;
}
}
return num;
}
TEST(delete_isolated_node_pass, basic) {
paddle::framework::ProgramDesc program;
auto* block0 = program.MutableBlock(0);
auto* block1 = program.AppendBlock(*block0);
auto* matmul0_x = Data(block0, "matmul0_x", {1, 128});
auto* matmul0_w = Data(block0, "matmul0_w", {128, 128}, true);
auto* matmul0_out = Data(block0, "matmul0_out", {1, 128});
OpDesc* matmul_op = block0->AppendOp();
matmul_op->SetType("matmul_v2");
matmul_op->SetInput("X", {matmul0_x->Name()});
matmul_op->SetInput("Y", {matmul0_w->Name()});
matmul_op->SetAttr("trans_x", false);
matmul_op->SetAttr("trans_y", false);
matmul_op->SetOutput("Out", {matmul0_out->Name()});
auto* while_out = Data(block0, "while_out", {1, 128});
auto* while_step_scopes = Data(block0, "while_step_scopes");
auto* while_cond = Data(block0, "while_cond");
OpDesc* while_op = block0->AppendOp();
while_op->SetType("while");
while_op->SetInput("X", {matmul0_w->Name(), matmul0_out->Name()});
while_op->SetInput("Condition", {while_cond->Name()});
while_op->SetOutput("Out", {while_out->Name()});
while_op->SetOutput("StepScopes", {while_step_scopes->Name()});
while_op->SetAttr("sub_block", {block1});
while_op->SetAttr("is_test", true);
auto* matmul1_x = Data(block1, matmul0_out->Name(), matmul0_out->GetShape());
auto* matmul1_w =
Data(block1, matmul0_w->Name(), matmul0_w->GetShape(), true);
auto* matmul1_out = Data(block1, "matmul1_out", {1, 128});
OpDesc* matmul1_op = block1->AppendOp();
matmul1_op->SetType("matmul_v2");
matmul1_op->SetInput("X", {matmul1_x->Name()});
matmul1_op->SetInput("Y", {matmul1_w->Name()});
matmul1_op->SetAttr("trans_x", false);
matmul1_op->SetAttr("trans_y", false);
matmul1_op->SetOutput("Out", {matmul1_out->Name()});
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
auto* scope = CreateParamScope();
graph->Set("__param_scope__", scope);
auto pass0 = PassRegistry::Instance().Get("fc_xpu_fuse_pass");
pass0->Apply(graph.get());
pass0->Apply(graph->GetSubGraph(1));
int weight_node_num =
WeightNodeNum(graph.get()) + WeightNodeNum(graph->GetSubGraph(1));
PADDLE_ENFORCE_EQ(weight_node_num,
6,
platform::errors::PreconditionNotMet(
"Graph should have 6 weight node after "
"fc_xpu_fuse_pass, but actually has %d.",
weight_node_num));
auto pass1 = PassRegistry::Instance().Get("delete_isolated_node_pass");
pass1->Apply(graph.get());
weight_node_num =
WeightNodeNum(graph.get()) + WeightNodeNum(graph->GetSubGraph(1));
PADDLE_ENFORCE_EQ(weight_node_num,
4,
platform::errors::PreconditionNotMet(
"Graph should have 4 weight node after "
"delete_isolated_node_pass, but actually has %d.",
weight_node_num));
int weight_tensor_num = WeightTensorNum(scope);
PADDLE_ENFORCE_EQ(weight_tensor_num,
2,
platform::errors::PreconditionNotMet(
"Scope should have 2 weight tensor after "
"delete_isolated_node_pass, but actually has %d.",
weight_tensor_num));
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()->Type() == "while") {
auto while_in_names = node->Op()->Inputs().at("X");
PADDLE_ENFORCE_EQ(while_in_names.size(),
3,
platform::errors::PreconditionNotMet(
"While op should have 3 input after "
"delete_isolated_node_pass, but actually has %d.",
while_in_names.size()));
}
}
Scope& scope0 = graph->Get<framework::Scope>("__param_scope__");
Scope& scope1 =
graph->GetSubGraph(1)->Get<framework::Scope>("__param_scope__");
std::vector<std::string> shared_weight_names{matmul0_w->Name() + "_int16",
matmul0_w->Name() + "_max"};
for (auto name : shared_weight_names) {
auto* var0 = scope0.FindVar(name);
auto* var1 = scope1.FindVar(name);
PADDLE_ENFORCE(
var0 == var1,
platform::errors::PreconditionNotMet(
"Variables with the same name in two scopes is different."));
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(delete_isolated_node_pass);
...@@ -76,7 +76,6 @@ FcXPUPattern::FcXPUPattern(PDPattern* pattern, ...@@ -76,7 +76,6 @@ FcXPUPattern::FcXPUPattern(PDPattern* pattern,
auto* mul_w = pattern->NewNode(mul_w_repr()) auto* mul_w = pattern->NewNode(mul_w_repr())
->assert_is_op_input(mul_type_, "Y") ->assert_is_op_input(mul_type_, "Y")
->assert_is_persistable_var() ->assert_is_persistable_var()
->assert_has_n_outputs(1)
->assert_more([](Node* node) { ->assert_more([](Node* node) {
return node->Var()->GetShape().size() == 2; return node->Var()->GetShape().size() == 2;
}); });
...@@ -169,10 +168,10 @@ class FcXPUFusePass : public FusePassBase { ...@@ -169,10 +168,10 @@ class FcXPUFusePass : public FusePassBase {
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
private: private:
void ApplyImpl(ir::Graph* graph, int ApplyImpl(ir::Graph* graph,
const std::string& mul_type, const std::string& mul_type,
bool with_bias, bool with_bias,
const std::string& act_type) const; const std::string& act_type) const;
const std::string name_scope_{"fc_xpu_fuse_pass"}; const std::string name_scope_{"fc_xpu_fuse_pass"};
}; };
...@@ -181,6 +180,8 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -181,6 +180,8 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null.")); graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph); Init(name_scope_, graph);
int found_subgraph_count = 0;
for (auto mul_type : {"mul", "matmul", "matmul_v2"}) { for (auto mul_type : {"mul", "matmul", "matmul_v2"}) {
for (auto with_bias : {true, false}) { for (auto with_bias : {true, false}) {
for (auto act_type : { for (auto act_type : {
...@@ -189,16 +190,17 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -189,16 +190,17 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph) const {
"tanh", "tanh",
"", "",
}) { }) {
ApplyImpl(graph, mul_type, with_bias, act_type); found_subgraph_count += ApplyImpl(graph, mul_type, with_bias, act_type);
} }
} }
} }
AddStatis(found_subgraph_count);
} }
void FcXPUFusePass::ApplyImpl(ir::Graph* graph, int FcXPUFusePass::ApplyImpl(ir::Graph* graph,
const std::string& mul_type, const std::string& mul_type,
bool with_bias, bool with_bias,
const std::string& act_type) const { const std::string& act_type) const {
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::FcXPUPattern pattern( patterns::FcXPUPattern pattern(
gpd.mutable_pattern(), name_scope_, mul_type, with_bias, act_type); gpd.mutable_pattern(), name_scope_, mul_type, with_bias, act_type);
...@@ -219,37 +221,20 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph, ...@@ -219,37 +221,20 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
auto* block = mul->Op()->Block(); auto* block = mul->Op()->Block();
auto* scope = param_scope(); auto* scope = param_scope();
auto mul_w_name = mul_w->Name(); bool transpose_w = false;
auto mul_w_tensor = if (mul_type == "matmul") {
scope->FindVar(mul_w_name)->GetMutable<phi::DenseTensor>(); transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("transpose_Y"));
// 1. Transform weight to int16/int31 } else if (mul_type == "matmul_v2") {
// 2. Avoid transform repeatly, because weight may be shared with other ops. transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("trans_y"));
// TODO(zhupengyang): support int31
std::string mul_w_max_name = mul_w_name + "_max";
Node* mul_w_max = nullptr;
if (mul_w_tensor->dtype() != phi::DataType::INT16) {
// Create weight_max node
VarDesc mul_w_max_desc(mul_w_max_name);
mul_w_max_desc.SetPersistable(true);
mul_w_max = graph->CreateVarNode(&mul_w_max_desc);
// Create weight_max var/tensor
auto mul_w_max_var = block->Var(mul_w_max_name);
mul_w_max_var->SetPersistable(true);
auto mul_w_max_tensor =
scope->Var(mul_w_max_name)->GetMutable<phi::DenseTensor>();
bool transpose_w = false;
if (mul_type == "matmul") {
transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("transpose_Y"));
} else if (mul_type == "matmul_v2") {
transpose_w = PADDLE_GET_CONST(bool, mul->Op()->GetAttr("trans_y"));
}
QuantWeight<int16_t>(mul_w_tensor, mul_w_max_tensor, !transpose_w);
} }
Node* mul_w_int16 = nullptr;
Node* mul_w_max = nullptr;
PrepareWeight<int16_t>(
graph, scope, block, mul_w, &mul_w_int16, &mul_w_max, !transpose_w);
Node* bias_fp32 = nullptr;
if (bias != nullptr) { if (bias != nullptr) {
auto* bias_tensor = PrepareBias(graph, scope, block, bias, &bias_fp32);
scope->Var(bias->Name())->GetMutable<phi::DenseTensor>();
CastToFp32(bias_tensor);
} }
std::string fc_out_name; std::string fc_out_name;
...@@ -268,10 +253,10 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph, ...@@ -268,10 +253,10 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
framework::OpDesc fc_xpu_op_desc(block); framework::OpDesc fc_xpu_op_desc(block);
fc_xpu_op_desc.SetType("fc_xpu"); fc_xpu_op_desc.SetType("fc_xpu");
fc_xpu_op_desc.SetInput("x", {mul_x->Name()}); fc_xpu_op_desc.SetInput("x", {mul_x->Name()});
fc_xpu_op_desc.SetInput("w", {mul_w->Name()}); fc_xpu_op_desc.SetInput("w", {mul_w_int16->Name()});
fc_xpu_op_desc.SetInput("w_max", {mul_w_max_name}); fc_xpu_op_desc.SetInput("w_max", {mul_w_max->Name()});
if (bias) { if (bias_fp32) {
fc_xpu_op_desc.SetInput("bias", {bias->Name()}); fc_xpu_op_desc.SetInput("bias", {bias_fp32->Name()});
} }
fc_xpu_op_desc.SetAttr( fc_xpu_op_desc.SetAttr(
"in_num_col_dims", "in_num_col_dims",
...@@ -306,9 +291,9 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph, ...@@ -306,9 +291,9 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
fc_xpu_op_desc.SetOutput("out_max", {fc_out_max_name}); fc_xpu_op_desc.SetOutput("out_max", {fc_out_max_name});
auto* fc_xpu = graph->CreateOpNode(&fc_xpu_op_desc); auto* fc_xpu = graph->CreateOpNode(&fc_xpu_op_desc);
IR_NODE_LINK_TO(mul_x, fc_xpu); IR_NODE_LINK_TO(mul_x, fc_xpu);
IR_NODE_LINK_TO(mul_w, fc_xpu); IR_NODE_LINK_TO(mul_w_int16, fc_xpu);
IR_NODE_LINK_TO(mul_w_max, fc_xpu); IR_NODE_LINK_TO(mul_w_max, fc_xpu);
SAFE_IR_NODE_LINK_TO(bias, fc_xpu); SAFE_IR_NODE_LINK_TO(bias_fp32, fc_xpu);
if (act_out) { if (act_out) {
IR_NODE_LINK_TO(fc_xpu, act_out); IR_NODE_LINK_TO(fc_xpu, act_out);
} else if (add_out) { } else if (add_out) {
...@@ -334,7 +319,7 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph, ...@@ -334,7 +319,7 @@ void FcXPUFusePass::ApplyImpl(ir::Graph* graph,
}; };
gpd(graph, handler); gpd(graph, handler);
AddStatis(found_subgraph_count); return found_subgraph_count;
} }
} // namespace ir } // namespace ir
......
...@@ -20,6 +20,18 @@ namespace paddle { ...@@ -20,6 +20,18 @@ namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
static void HashCombine(std::size_t* seed) {}
// combine hash value
// https://stackoverflow.com/questions/2590677/how-do-i-combine-hash-values-in-c0x
template <typename T, typename... Rest>
static void HashCombine(std::size_t* seed, const T& v, Rest... rest) {
std::hash<T> hasher;
*seed ^= hasher(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
*seed *= 0x00000100000001B3;
HashCombine(seed, rest...);
}
int ConvertActivationType(std::string act_type) { int ConvertActivationType(std::string act_type) {
if (act_type == "") { if (act_type == "") {
return static_cast<int>(xpu::Activation_t::LINEAR); return static_cast<int>(xpu::Activation_t::LINEAR);
...@@ -50,6 +62,161 @@ int ConvertActivationType(std::string act_type) { ...@@ -50,6 +62,161 @@ int ConvertActivationType(std::string act_type) {
return -1; return -1;
} }
Node* FindNodeWithName(Graph* graph, std::string name) {
for (auto* node : graph->Nodes()) {
if (node->IsVar() && node->Var()->Name() == name) {
return node;
}
}
return nullptr;
}
template <typename T>
std::string IntTypeToString() {
LOG(FATAL) << "Not support type.";
return "";
}
template <>
std::string IntTypeToString<int16_t>() {
return "int16";
}
template <typename T>
size_t HashTensor(const phi::DenseTensor& in) {
size_t ret = 0;
auto in_dims = in.dims();
HashCombine(&ret,
phi::DataTypeToString(in.dtype()),
phi::DataLayoutToString(in.layout()),
in_dims.size());
for (int i = 0; i < in_dims.size(); i++) {
HashCombine(&ret, in_dims[i]);
}
auto* data = in.data<T>();
int64_t size = in.numel();
for (int64_t i = 0; i < size; i++) {
HashCombine(&ret, data[i]);
}
return ret;
}
template size_t HashTensor<int16_t>(const phi::DenseTensor& in);
template size_t HashTensor<float>(const phi::DenseTensor& in);
template <typename T>
void PrepareWeight(Graph* graph,
Scope* scope,
BlockDesc* block,
Node* src,
Node** dst,
Node** dst_max,
bool transpose) {
auto src_name = src->Name();
auto* src_tensor = scope->Var(src_name)->GetMutable<phi::DenseTensor>();
phi::DenseTensor dst_tensor;
Assign(*src_tensor, &dst_tensor);
phi::DenseTensor dst_max_tensor;
PrepareWeight<T>(&dst_tensor, &dst_max_tensor, transpose);
size_t dst_hash = HashTensor<T>(dst_tensor);
size_t dst_max_hash = HashTensor<float>(dst_max_tensor);
std::string dst_name = src_name + "_" + std::to_string(dst_hash);
std::string dst_max_name = src_name + "_max_" + std::to_string(dst_max_hash);
*dst = FindNodeWithName(graph, dst_name);
if (*dst == nullptr) {
// Create dst node
// Update dst var_desc in block
VarDesc dst_desc(dst_name);
dst_desc.SetPersistable(true);
dst_desc.SetShape(vectorize(dst_tensor.dims()));
dst_desc.SetDataType(framework::TransToProtoVarType(dst_tensor.dtype()));
*dst = graph->CreateVarNode(&dst_desc);
auto* block_dst_desc = block->Var(dst_name);
block_dst_desc->SetPersistable(dst_desc.Persistable());
block_dst_desc->SetShape(dst_desc.GetShape());
block_dst_desc->SetDataType(dst_desc.GetDataType());
// Create dst_max node
// Update dst_max var_desc in block
VarDesc dst_max_desc(dst_max_name);
dst_max_desc.SetPersistable(true);
dst_max_desc.SetShape(vectorize(dst_max_tensor.dims()));
dst_max_desc.SetDataType(proto::VarType::Type::VarType_Type_FP32);
*dst_max = graph->CreateVarNode(&dst_max_desc);
auto* block_dst_max_desc = block->Var(dst_max_name);
block_dst_max_desc->SetPersistable(dst_max_desc.Persistable());
block_dst_max_desc->SetShape(dst_max_desc.GetShape());
block_dst_max_desc->SetDataType(dst_max_desc.GetDataType());
// Find dst/dst_max variable in scope
auto* dst_var = scope->FindVar(dst_name);
if (dst_var == nullptr) {
// Create dst/dst_max variable/tensor
Assign(dst_tensor, scope->Var(dst_name)->GetMutable<phi::DenseTensor>());
Assign(dst_max_tensor,
scope->Var(dst_max_name)->GetMutable<phi::DenseTensor>());
} else {
// Share the same variable
PADDLE_ENFORCE_NOT_NULL(
scope->FindVar(dst_max_name),
platform::errors::Fatal(
"dst_max(%s) variable should not be nullptr if dst(%s) "
"variable is exist. (src_name is %s)",
dst_max_name,
dst_name,
src_name));
}
} else {
*dst_max = FindNodeWithName(graph, dst_max_name);
PADDLE_ENFORCE_NOT_NULL(
*dst_max,
platform::errors::Fatal(
"dst_max(%s) variable should not be nullptr if dst(%s) "
"variable is exist. (src_name is %s)",
dst_max_name,
dst_name,
src_name));
}
}
template void PrepareWeight<int16_t>(Graph* graph,
Scope* scope,
BlockDesc* block,
Node* src,
Node** dst,
Node** dst_max,
bool transpose);
void PrepareBias(
Graph* graph, Scope* scope, BlockDesc* block, Node* src, Node** dst) {
auto src_name = src->Name();
auto* src_tensor = scope->Var(src_name)->GetMutable<phi::DenseTensor>();
if (src_tensor->dtype() == phi::DataType::FLOAT32) {
*dst = src;
}
phi::DenseTensor dst_tensor;
CastToFp32(src_tensor, &dst_tensor);
size_t dst_hash = HashTensor<float>(dst_tensor);
std::string dst_name = src_name + "_" + std::to_string(dst_hash);
*dst = FindNodeWithName(graph, dst_name);
if (*dst == nullptr) {
// Create dst node
// Update dst var_desc in block
VarDesc dst_desc(dst_name);
dst_desc.SetPersistable(true);
dst_desc.SetShape(vectorize(dst_tensor.dims()));
dst_desc.SetDataType(framework::TransToProtoVarType(dst_tensor.dtype()));
*dst = graph->CreateVarNode(&dst_desc);
auto* block_dst_desc = block->Var(dst_name);
block_dst_desc->SetPersistable(dst_desc.Persistable());
block_dst_desc->SetShape(dst_desc.GetShape());
block_dst_desc->SetDataType(dst_desc.GetDataType());
Assign(dst_tensor, scope->Var(dst_name)->GetMutable<phi::DenseTensor>());
}
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -14,6 +14,10 @@ ...@@ -14,6 +14,10 @@
#pragma once #pragma once
#include <string> #include <string>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -45,6 +49,23 @@ namespace ir { ...@@ -45,6 +49,23 @@ namespace ir {
int ConvertActivationType(std::string act_type); int ConvertActivationType(std::string act_type);
Node* FindNodeWithName(Graph* graph, std::string name);
template <typename T>
size_t HashTensor(const phi::DenseTensor& in);
template <typename T>
void PrepareWeight(Graph* graph,
Scope* scope,
BlockDesc* block,
Node* src,
Node** dst,
Node** dst_max,
bool transpose);
void PrepareBias(
Graph* graph, Scope* scope, BlockDesc* block, Node* src, Node** dst);
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -207,9 +207,9 @@ void QuantFP32ToIntX<int16_t>(const float* src_ptr, ...@@ -207,9 +207,9 @@ void QuantFP32ToIntX<int16_t>(const float* src_ptr,
} }
template <typename T> template <typename T>
void QuantWeight(phi::DenseTensor* weight, void PrepareWeight(phi::DenseTensor* weight,
phi::DenseTensor* weight_max, phi::DenseTensor* weight_max,
bool transpose) { bool transpose) {
// Convert fp16 to fp32 // Convert fp16 to fp32
phi::DenseTensor weight_fp32; phi::DenseTensor weight_fp32;
CastToFp32(weight, &weight_fp32); CastToFp32(weight, &weight_fp32);
...@@ -249,9 +249,9 @@ void QuantWeight(phi::DenseTensor* weight, ...@@ -249,9 +249,9 @@ void QuantWeight(phi::DenseTensor* weight,
QuantFP32ToIntX(weight_data, cpu_ctx->Alloc<T>(weight), max_val, size); QuantFP32ToIntX(weight_data, cpu_ctx->Alloc<T>(weight), max_val, size);
} }
template void QuantWeight<int16_t>(phi::DenseTensor* weight, template void PrepareWeight<int16_t>(phi::DenseTensor* weight,
phi::DenseTensor* weight_max, phi::DenseTensor* weight_max,
bool transpose); bool transpose);
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
...@@ -29,9 +29,9 @@ void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr); ...@@ -29,9 +29,9 @@ void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
// 2. Weight data is in-place update. // 2. Weight data is in-place update.
// 3. Generate weight max tensor // 3. Generate weight max tensor
template <typename T> template <typename T>
void QuantWeight(phi::DenseTensor* weight, void PrepareWeight(phi::DenseTensor* weight,
phi::DenseTensor* weight_max, phi::DenseTensor* weight_max,
bool transpose); bool transpose);
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
......
...@@ -524,6 +524,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { ...@@ -524,6 +524,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"fc_xpu_fuse_pass", "fc_xpu_fuse_pass",
"link_xpu_op_max_pass", "link_xpu_op_max_pass",
"delete_op_device_pass", "delete_op_device_pass",
"delete_isolated_node_pass",
}); });
use_xpu_ = true; use_xpu_ = true;
} }
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
import unittest import unittest
import hypothesis.strategies as st import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig from program_config import OpConfig, ProgramConfig, TensorConfig
...@@ -33,7 +34,7 @@ class TestFcXPUFusePass(PassAutoScanTest): ...@@ -33,7 +34,7 @@ class TestFcXPUFusePass(PassAutoScanTest):
) )
matmul0_y_shape = draw( matmul0_y_shape = draw(
st.lists( st.lists(
st.integers(min_value=1, max_value=8), min_size=2, max_size=2 st.integers(min_value=2, max_value=8), min_size=2, max_size=2
) )
) )
matmul0_y_shape[0] = matmul0_x_shape[-1] matmul0_y_shape[0] = matmul0_x_shape[-1]
...@@ -42,7 +43,7 @@ class TestFcXPUFusePass(PassAutoScanTest): ...@@ -42,7 +43,7 @@ class TestFcXPUFusePass(PassAutoScanTest):
# 3. matmul1 # 3. matmul1
matmul1_y_shape = draw( matmul1_y_shape = draw(
st.lists( st.lists(
st.integers(min_value=1, max_value=8), min_size=2, max_size=2 st.integers(min_value=2, max_value=8), min_size=2, max_size=2
) )
) )
matmul1_y_shape[0] = matmul0_y_shape[-1] matmul1_y_shape[0] = matmul0_y_shape[-1]
...@@ -101,4 +102,5 @@ class TestFcXPUFusePass(PassAutoScanTest): ...@@ -101,4 +102,5 @@ class TestFcXPUFusePass(PassAutoScanTest):
if __name__ == "__main__": if __name__ == "__main__":
np.random.seed(200)
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册