Unverified commit 98a165bf authored by xinxinZi, committed by GitHub

add delete_xpu_unnecessary_cast_op_pass (#54663)
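The new xpu_delete_cast_op_pass collapses redundant cast(fp16->fp32) -> softmax -> cast(fp32->fp16) and cast(fp16->fp32) -> layer_norm -> cast(fp32->fp16) subgraphs down to the bare softmax/layer_norm op (see the pattern comments in xpu_delete_cast_op_pass.h below).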

Parent 127e9f4c
......@@ -258,6 +258,7 @@ if(WITH_XPU)
xpu DEPS ${XPU_PASS_DEPS})
pass_library(add_activation_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(xpu_delete_cast_op_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(fold_interp_outsize_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(fold_two_squeeze2_fuse_pass inference DIR xpu DEPS
......@@ -548,6 +549,10 @@ if(WITH_XPU)
test_multi_encoder_xpu_adaptive_seqlen_fuse_pass
SRCS xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass_test.cc
DEPS multi_encoder_xpu_adaptive_seqlen_fuse_pass)
cc_test(
test_xpu_delete_cast_op_pass
SRCS xpu/xpu_delete_cast_op_pass_test.cc
DEPS xpu_delete_cast_op_pass)
cc_test(
test_fold_interp_outsize_fuse_pass
SRCS xpu/fold_interp_outsize_fuse_pass_test.cc
......
......@@ -67,6 +67,7 @@ static const std::vector<std::string> xpu_support_subgraph_passes = {
"one_beam_size_fuse_pass",
"stack_fuse_pass",
"fused_multi_transformer_xpu_pass",
"xpu_delete_cast_op_pass",
"fc_xpu_fuse_pass",
"link_xpu_op_max_pass",
};
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/xpu/xpu_delete_cast_op_pass.h"
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/kernels/cast_kernel.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct CastSoftmaxPattern : public PatternBase {
CastSoftmaxPattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(cast0);
PATTERN_DECL_NODE(softmax);
PATTERN_DECL_NODE(cast1);
// declare variable node's name
PATTERN_DECL_NODE(cast0_in);
PATTERN_DECL_NODE(cast0_out);
PATTERN_DECL_NODE(softmax_out);
PATTERN_DECL_NODE(cast1_out);
};
CastSoftmaxPattern::CastSoftmaxPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* cast0_in =
pattern->NewNode(cast0_in_repr())->assert_is_op_input("cast", "X");
auto* cast0 =
pattern->NewNode(cast0_repr())
->assert_is_op("cast")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto in_dtype = op_desc->GetAttrIfExists<int>("in_dtype");
auto out_dtype = op_desc->GetAttrIfExists<int>("out_dtype");
return in_dtype == static_cast<int>(proto::VarType::FP16) &&
out_dtype == static_cast<int>(proto::VarType::FP32);
});
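// The intermediate vars below must each have exactly one consumer
// (assert_has_n_outputs(1)), so dropping the casts cannot break other users.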
auto* cast0_out = pattern->NewNode(cast0_out_repr())
->assert_is_op_output("cast", "Out")
->assert_is_op_input("softmax", "X")
->assert_has_n_outputs(1);
auto* softmax = pattern->NewNode(softmax_repr())->assert_is_op("softmax");
auto* softmax_out = pattern->NewNode(softmax_out_repr())
->assert_is_op_output("softmax", "Out")
->assert_is_op_input("cast", "X")
->assert_has_n_outputs(1);
auto* cast1 =
pattern->NewNode(cast1_repr())
->assert_is_op("cast")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto in_dtype = op_desc->GetAttrIfExists<int>("in_dtype");
auto out_dtype = op_desc->GetAttrIfExists<int>("out_dtype");
return in_dtype == static_cast<int>(proto::VarType::FP32) &&
out_dtype == static_cast<int>(proto::VarType::FP16);
});
auto* cast1_out =
pattern->NewNode(cast1_out_repr())->assert_is_op_output("cast", "Out");
cast0->LinksFrom({cast0_in}).LinksTo({cast0_out});
softmax->LinksFrom({cast0_out}).LinksTo({softmax_out});
cast1->LinksFrom({softmax_out}).LinksTo({cast1_out});
}
} // namespace patterns
int XpuDeleteCastOpPass::ApplyCastSoftmaxPass(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::CastSoftmaxPattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle ApplyCastSoftmaxPass fuse";
GET_IR_NODE_FROM_SUBGRAPH(cast0, cast0, pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax, softmax, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast1, cast1, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast0_in, cast0_in, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast0_out, cast0_out, pattern);
GET_IR_NODE_FROM_SUBGRAPH(softmax_out, softmax_out, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast1_out, cast1_out, pattern);
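// Bypass the two casts: softmax reads cast0's input and writes directly to
// cast1's output, then the casts and their intermediate vars are removed.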
softmax->Op()->RenameInput(cast0_out->Name(), cast0_in->Name());
softmax->Op()->RenameOutput(softmax_out->Name(), cast1_out->Name());
IR_NODE_LINK_TO(cast0_in, softmax);
IR_NODE_LINK_TO(softmax, cast1_out);
std::unordered_set<const Node*> delete_nodes{
cast0, cast1, cast0_out, softmax_out};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
namespace patterns {
struct CastLayerNormPattern : public PatternBase {
CastLayerNormPattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(cast0);
PATTERN_DECL_NODE(layer_norm);
PATTERN_DECL_NODE(cast1);
// declare variable node's name
PATTERN_DECL_NODE(cast0_in);
PATTERN_DECL_NODE(cast0_out);
PATTERN_DECL_NODE(layer_norm_out);
PATTERN_DECL_NODE(cast1_out);
};
CastLayerNormPattern::CastLayerNormPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* cast0_in =
pattern->NewNode(cast0_in_repr())->assert_is_op_input("cast", "X");
auto* cast0 =
pattern->NewNode(cast0_repr())
->assert_is_op("cast")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto in_dtype = op_desc->GetAttrIfExists<int>("in_dtype");
auto out_dtype = op_desc->GetAttrIfExists<int>("out_dtype");
return in_dtype == static_cast<int>(proto::VarType::FP16) &&
out_dtype == static_cast<int>(proto::VarType::FP32);
});
auto* cast0_out = pattern->NewNode(cast0_out_repr())
->assert_is_op_output("cast", "Out")
->assert_is_op_input("layer_norm", "X")
->assert_has_n_outputs(1);
auto* layer_norm =
pattern->NewNode(layer_norm_repr())->assert_is_op("layer_norm");
auto* layer_norm_out = pattern->NewNode(layer_norm_out_repr())
->assert_is_op_output("layer_norm", "Y")
->assert_is_op_input("cast", "X")
->assert_has_n_outputs(1);
auto* cast1 =
pattern->NewNode(cast1_repr())
->assert_is_op("cast")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto in_dtype = op_desc->GetAttrIfExists<int>("in_dtype");
auto out_dtype = op_desc->GetAttrIfExists<int>("out_dtype");
return in_dtype == static_cast<int>(proto::VarType::FP32) &&
out_dtype == static_cast<int>(proto::VarType::FP16);
});
auto* cast1_out =
pattern->NewNode(cast1_out_repr())->assert_is_op_output("cast", "Out");
cast0->LinksFrom({cast0_in}).LinksTo({cast0_out});
layer_norm->LinksFrom({cast0_out}).LinksTo({layer_norm_out});
cast1->LinksFrom({layer_norm_out}).LinksTo({cast1_out});
}
} // namespace patterns
int XpuDeleteCastOpPass::ApplyCastLayerNormPass(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::CastLayerNormPattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle ApplyCastLayerNormPass fuse";
GET_IR_NODE_FROM_SUBGRAPH(cast0, cast0, pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm, layer_norm, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast1, cast1, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast0_in, cast0_in, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast0_out, cast0_out, pattern);
GET_IR_NODE_FROM_SUBGRAPH(layer_norm_out, layer_norm_out, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast1_out, cast1_out, pattern);
layer_norm->Op()->RenameInput(cast0_out->Name(), cast0_in->Name());
layer_norm->Op()->RenameOutput(layer_norm_out->Name(), cast1_out->Name());
IR_NODE_LINK_TO(cast0_in, layer_norm);
IR_NODE_LINK_TO(layer_norm, cast1_out);
std::unordered_set<const Node*> delete_nodes{
cast0, cast1, cast0_out, layer_norm_out};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
void XpuDeleteCastOpPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
if (!graph->IsMainGraph()) {
VLOG(3) << "'xpu_delete_cast_op_pass' needs info in all "
"graphs, so it "
"should be applied in the main graph.";
return;
}
Init(name_scope_, graph);
int found_subgraph_count = ApplyCastSoftmaxPass(graph);
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
found_subgraph_count += ApplyCastSoftmaxPass(graph->GetSubGraph(i));
}
if (found_subgraph_count > 0) {
LOG(INFO) << "--- delete " << found_subgraph_count
<< " cast_softmax_cast subgraph";
}
found_subgraph_count = ApplyCastLayerNormPass(graph);
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
found_subgraph_count += ApplyCastLayerNormPass(graph->GetSubGraph(i));
}
if (found_subgraph_count > 0) {
LOG(INFO) << "--- delete " << found_subgraph_count
<< " cast_layer_norm_cast subgraph";
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(xpu_delete_cast_op_pass,
paddle::framework::ir::XpuDeleteCastOpPass);
REGISTER_PASS_CAPABILITY(xpu_delete_cast_op_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"cast", 0));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
class XpuDeleteCastOpPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
/*
Origin subgraph:
cast(fp16->fp32)
|
softmax
|
cast(fp32->fp16)
Optimized subgraph:
softmax
*/
int ApplyCastSoftmaxPass(ir::Graph* graph) const;
/*
Origin subgraph:
cast(fp16->fp32)
|
layer_norm
|
cast(fp32->fp16)
Optimized subgraph:
layer_norm
*/
int ApplyCastLayerNormPass(ir::Graph* graph) const;
const std::string name_scope_{"xpu_delete_cast_op_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
VarDesc* Data(paddle::framework::BlockDesc* block,
std::string name,
std::vector<int64_t> shape = {},
bool is_persistable = false,
proto::VarType::Type data_type = proto::VarType::FP32) {
auto* var = block->Var(name);
var->SetType(proto::VarType::LOD_TENSOR);
var->SetDataType(data_type);
var->SetShape(shape);
var->SetPersistable(is_persistable);
return var;
}
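// Appends a cast op to the block; in_dtype/out_dtype are proto::VarType codes
// (4 == FP16, 5 == FP32), the dtypes the pass matches on.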
VarDesc* AddCast(BlockDesc* block,
VarDesc* input,
int in_dtype = 5,
int out_dtype = 5) {
VarDesc* out = Data(block, input->Name() + "_out");
OpDesc* op = block->AppendOp();
op->SetType("cast");
op->SetInput("X", {input->Name()});
op->SetOutput("Out", {out->Name()});
op->SetAttr("in_dtype", in_dtype);
op->SetAttr("out_dtype", out_dtype);
return out;
}
int GetOpNum(Graph* graph, std::string op_type = "") {
int num_nodes = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op() &&
(node->Op()->Type() == op_type || op_type.empty())) {
num_nodes++;
}
}
return num_nodes;
}
TEST(ApplyCastSoftmaxPass, basic) {
paddle::framework::ProgramDesc program;
auto* block = program.MutableBlock(0);
auto* cast0_in = Data(block, "cast0_in", {1});
auto* cast0_out = AddCast(block, cast0_in, 4, 5);
auto* softmax_out = Data(block, "softmax_out", {1});
OpDesc* softmax = block->AppendOp();
softmax->SetType("softmax");
softmax->SetInput("X", {cast0_out->Name()});
softmax->SetOutput("Out", {softmax_out->Name()});
AddCast(block, softmax_out, 5, 4);
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
auto scope = new Scope();
graph->Set("__param_scope__", scope);
auto pass = PassRegistry::Instance().Get("xpu_delete_cast_op_pass");
pass->Apply(graph.get());
int cast_num_in_graph = GetOpNum(graph->GetSubGraph(0), "cast");
PADDLE_ENFORCE_EQ(
cast_num_in_graph,
0,
platform::errors::PreconditionNotMet(
"graph should have 0 cast after xpu_delete_cast_op_pass, "
"but actually has %d.",
cast_num_in_graph));
}
TEST(ApplyCastLayerNormPass, basic) {
paddle::framework::ProgramDesc program;
auto* block = program.MutableBlock(0);
auto* cast0_in = Data(block, "cast0_in", {1});
auto* cast0_out = AddCast(block, cast0_in, 4, 5);
auto* layer_norm_out = Data(block, "layer_norm_out", {1});
OpDesc* layer_norm = block->AppendOp();
layer_norm->SetType("layer_norm");
layer_norm->SetInput("X", {cast0_out->Name()});
layer_norm->SetOutput("Y", {layer_norm_out->Name()});
AddCast(block, layer_norm_out, 5, 4);
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
auto scope = new Scope();
graph->Set("__param_scope__", scope);
auto pass = PassRegistry::Instance().Get("xpu_delete_cast_op_pass");
pass->Apply(graph.get());
int cast_num_in_graph = GetOpNum(graph->GetSubGraph(0), "cast");
PADDLE_ENFORCE_EQ(
cast_num_in_graph,
0,
platform::errors::PreconditionNotMet(
"graph should have 0 cast after xpu_delete_cast_op_pass, "
"but actually has %d.",
cast_num_in_graph));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(xpu_delete_cast_op_pass);
......@@ -538,6 +538,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"link_xpu_op_max_pass",
"inplace_op_var_pass",
"delete_isolated_node_pass",
"xpu_delete_cast_op_pass",
});
use_xpu_ = true;
}
......
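Since the pass is appended to XpuPassStrategy, it runs automatically during inference analysis once XPU is enabled. A minimal usage sketch, assuming the public paddle_infer C++ API (the model directory name is hypothetical):

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.SetModel("./my_xpu_model");  // hypothetical model directory
  // Enabling XPU selects XpuPassStrategy, which now includes
  // xpu_delete_cast_op_pass among its analysis passes.
  config.EnableXpu();
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}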