未验证 提交 599388e3 编写于 作者: Z zhupengyang 提交者: GitHub

[XPU] optimize pass (#52099)

上级 ad76d37e
......@@ -238,6 +238,7 @@ if(WITH_XPU)
${XPU_PASS_DEPS})
pass_library(fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(stack_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
endif()
cc_library(
......@@ -512,4 +513,8 @@ if(WITH_XPU)
test_one_beam_size_fuse_pass
SRCS xpu/one_beam_size_fuse_pass_test.cc
DEPS one_beam_size_fuse_pass)
cc_test(
test_stack_fuse_pass
SRCS xpu/stack_fuse_pass_test.cc
DEPS stack_fuse_pass)
endif()
......@@ -25,62 +25,38 @@ namespace ir {
class Graph;
void InplaceOpVarPass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init("inplace_op_var", graph);
int found_subgraph_count = 0;
auto nodes = graph->Nodes();
auto is_valid_reshape = [](Node* node) {
// Some cases need to consider, please refer to
// https://github.com/PaddlePaddle/Paddle/pull/49146
if (node->IsOp() && node->Op()->Type() == "reshape2") {
auto x_name = node->Op()->Input("X").front();
for (auto* var_node : node->inputs) {
if (var_node->Name() == x_name) {
if (!var_node->Var()->Persistable() && var_node->outputs.size() == 1)
return true;
}
}
}
bool InplaceOpVarPass::IsValidInplaceOp(
Node* node, const std::unordered_set<std::string>& deny_var_names) const {
if (!node->IsOp() || inplace_ops_.count(node->Op()->Type()) == 0)
return false;
};
// Record all reshape2 op's input name and output name in block 0.
// If the name used in other block, we can not inplace reshape op.
std::unordered_set<std::string> var_names, deny_var_names;
for (auto* node : nodes) {
if (is_valid_reshape(node)) {
for (auto n : node->inputs) var_names.insert(n->Name());
for (auto n : node->outputs) var_names.insert(n->Name());
}
}
for (size_t i = 1; i < graph->SubGraphsSize(); ++i) {
auto sub_graph = graph->GetSubGraph(i);
for (auto* node : sub_graph->Nodes()) {
if (node->IsOp()) {
for (auto var_node : node->inputs) {
if (var_names.count(var_node->Name()))
deny_var_names.insert(var_node->Name());
}
for (auto var_node : node->outputs) {
if (var_names.count(var_node->Name()))
deny_var_names.insert(var_node->Name());
}
}
}
// in_var_node should only has one out_op_node
auto x_name = node->Op()->Input("X").front();
for (auto* var_node : node->inputs) {
if (var_node->Name() != x_name) continue;
if (var_node->Var()->Persistable() || var_node->outputs.size() != 1)
return false;
}
// in/out_var_node should be not used in multi graphs.
auto out_name = node->Op()->Output("Out").front();
if (deny_var_names.count(x_name) > 0 || deny_var_names.count(out_name) > 0)
return false;
return true;
}
int InplaceOpVarPass::ApplyImpl(
ir::Graph* graph,
const std::unordered_set<std::string>& deny_var_names) const {
int found_subgraph_count = 0;
// inplace all reshape op.
auto topo_nodes = TopologySortOperations(*graph);
for (auto* node : topo_nodes) {
if (!is_valid_reshape(node)) continue;
if (!IsValidInplaceOp(node, deny_var_names)) continue;
auto* op_node = node->Op();
auto input_name = op_node->Input("X")[0];
auto output_name = op_node->Output("Out")[0];
if (deny_var_names.count(input_name) || deny_var_names.count(output_name)) {
continue;
}
++found_subgraph_count;
for (auto* out_var : node->outputs) {
if (out_var->Name() == output_name) {
out_var->RenameVar(input_name);
......@@ -90,9 +66,50 @@ void InplaceOpVarPass::ApplyImpl(ir::Graph* graph) const {
}
}
}
op_node->RenameOutput(output_name, input_name);
op_node->Flush();
found_subgraph_count++;
}
return found_subgraph_count;
}
std::vector<std::string> InplaceOpVarPass::GetControlFlowVarNames(
ir::Graph* graph) const {
std::vector<std::string> control_flow_var_names;
for (auto* node : graph->Nodes()) {
if (!node->IsOp() || control_flow_ops_.count(node->Op()->Type()) == 0)
continue;
for (auto in_names : node->Op()->Inputs()) {
auto var_names = in_names.second;
control_flow_var_names.insert(
control_flow_var_names.end(), var_names.begin(), var_names.end());
}
for (auto out_names : node->Op()->Outputs()) {
auto var_names = out_names.second;
control_flow_var_names.insert(
control_flow_var_names.end(), var_names.begin(), var_names.end());
}
}
return control_flow_var_names;
}
void InplaceOpVarPass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init("inplace_op_var", graph);
if (!graph->IsMainGraph()) {
VLOG(3) << "Pass(apply in main graph) will work on all subgraphs.";
return;
}
std::unordered_set<std::string> deny_var_names;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
auto control_flow_var_names = GetControlFlowVarNames(graph->GetSubGraph(i));
deny_var_names.insert(control_flow_var_names.begin(),
control_flow_var_names.end());
}
int found_subgraph_count = 0;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
found_subgraph_count += ApplyImpl(graph->GetSubGraph(i), deny_var_names);
}
AddStatis(found_subgraph_count);
}
......@@ -105,4 +122,16 @@ REGISTER_PASS(inplace_op_var_pass, paddle::framework::ir::InplaceOpVarPass);
REGISTER_PASS_CAPABILITY(inplace_op_var_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"reshape2", 0));
"reshape2", 0))
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"unsqueeze2", 0))
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"unsqueeze", 0))
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"squeeze2", 0))
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"squeeze", 0));
......@@ -28,6 +28,18 @@ class InplaceOpVarPass : public FusePassBase {
private:
virtual ~InplaceOpVarPass() = default;
int ApplyImpl(ir::Graph* graph,
const std::unordered_set<std::string>& deny_var_names) const;
bool IsValidInplaceOp(
Node* node, const std::unordered_set<std::string>& deny_var_names) const;
std::vector<std::string> GetControlFlowVarNames(ir::Graph* graph) const;
std::set<std::string> inplace_ops_{
"reshape", "unsqueeze", "unsqueeze2", "squeeze", "squeeze2"};
std::set<std::string> control_flow_ops_{"while", "conditional_block"};
};
} // namespace ir
......
......@@ -49,10 +49,22 @@ static const std::vector<std::string> support_subgraph_passes = {
"fuse_multi_transformer_layer_pass",
"delete_quant_dequant_linear_op_pass",
"delete_weight_dequant_linear_op_pass",
};
static const std::vector<std::string> xpu_support_subgraph_passes = {
"delete_dropout_op_pass",
"identity_scale_op_clean_pass",
"delete_op_device_pass",
"constant_folding_pass",
"generate_sequence_xpu_fuse_pass",
"embedding_with_eltwise_add_xpu_fuse_pass",
"multi_encoder_xpu_fuse_pass",
"multi_encoder_xpu_slice_fuse_pass",
"one_beam_size_fuse_pass",
"stack_fuse_pass",
"fused_multi_transformer_xpu_quant_pass",
"fc_xpu_fuse_pass",
"delete_op_device_pass",
"link_xpu_op_max_pass",
};
Graph *Pass::Apply(Graph *graph) const {
......@@ -90,9 +102,15 @@ Graph *Pass::Apply(Graph *graph) const {
}
graph->Get<PassRecorder>(kPassRecorder).insert(Type());
if (graph->IsMainGraph() && std::count(support_subgraph_passes.begin(),
support_subgraph_passes.end(),
Type())) {
std::vector<std::string> subgraph_passes;
bool use_xpu = Has("use_xpu") && Get<bool>("use_xpu");
if (use_xpu) {
subgraph_passes = xpu_support_subgraph_passes;
} else {
subgraph_passes = support_subgraph_passes;
}
if (graph->IsMainGraph() &&
std::count(subgraph_passes.begin(), subgraph_passes.end(), Type())) {
for (size_t i = 1; i < graph->SubGraphsSize(); i++) {
auto *sub_graph = graph->GetSubGraph(i);
if (!sub_graph->Has(framework::ir::kParamScopeAttr)) {
......
......@@ -807,6 +807,20 @@ struct Layers {
return unary_op("logical_not", input);
}
VarDesc* stack(std::vector<VarDesc*> inputs, int axis = -1) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("stack");
std::vector<std::string> input_names;
for (auto* input : inputs) {
input_names.push_back(input->Name());
}
op->SetInput("X", input_names);
op->SetAttr("axis", axis);
op->SetOutput("Y", {out->Name()});
return out;
}
private:
VarDesc* lod_tensor(std::string name,
std::vector<int64_t> shape = {},
......
......@@ -60,10 +60,11 @@ class DeleteIsolatedNodePass : public Pass {
void DeleteIsolatedNodePass::ApplyImpl(Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
PADDLE_ENFORCE(graph->IsMainGraph(),
platform::errors::PreconditionNotMet(
"Pass(apply in main graph) will delete isolated nodes in "
"all subgraphs. Do not apply pass in subgraph."));
if (!graph->IsMainGraph()) {
VLOG(3) << "Pass(apply in main graph) will delete isolated nodes in all "
"subgraphs.";
return;
}
std::unordered_set<std::string> reserved_persistable_node_names;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct StackPattern : public PatternBase {
StackPattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(stack);
// declare variable node's name
PATTERN_DECL_NODE(stack_out);
};
StackPattern::StackPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* stack = pattern->NewNode(stack_repr())
->assert_is_op("stack")
->assert_more([](Node* node) {
auto input_names = node->Op()->Input("X");
auto first_name = input_names[0];
for (auto name : input_names) {
if (name != first_name) return false;
}
return true;
});
auto* stack_out = pattern->NewNode(stack_out_repr())
->assert_is_op_output("stack", "Y")
->assert_more([](Node* node) {
std::map<std::string, std::string> support_out_ops{
{"elementwise_add", "Y"},
{"fused_multi_transformer", "SrcMask"}};
auto var_name = node->Name();
for (auto* out_node : node->outputs) {
auto op_type = out_node->Op()->Type();
if (support_out_ops.count(op_type) == 0)
return false;
auto out_op_in_names =
out_node->Op()->Input(support_out_ops[op_type]);
if (std::find(out_op_in_names.begin(),
out_op_in_names.end(),
var_name) == out_op_in_names.end())
return false;
}
return true;
});
stack->LinksTo({stack_out});
}
} // namespace patterns
/*
"stack" can be replaced by "unsqueeze" if:
1. "stack inputs" are the same。
1. "stack output" is "elementwise_add input" or "fused_multi_transformer
src_mask input".
*/
class StackFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
const std::string name_scope_{"stack_fuse_pass"};
};
void StackFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
GraphPatternDetector gpd;
patterns::StackPattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle StackFusePass fuse";
GET_IR_NODE(stack);
GET_IR_NODE(stack_out);
stack->RenameOp("unsqueeze2");
auto* op_desc = stack->Op();
int axis = op_desc->GetAttrIfExists<int>("axis");
op_desc->SetAttr("axes", std::vector<int>{axis});
op_desc->RemoveAttr("axis");
op_desc->MutableInputs()->at("X").resize(1);
auto* stack_in = stack->inputs[0];
IR_NODE_UNLINK(stack_in, stack);
IR_NODE_LINK_TO(stack_in, stack);
auto* outputs = op_desc->MutableOutputs();
(*outputs)["Out"] = outputs->at("Y");
outputs->erase("Y");
auto stack_out_shape = stack_out->Var()->GetShape();
if (axis < 0) axis += stack_out_shape.size();
stack_out_shape[axis] = 1;
stack_out->Var()->SetShape(stack_out_shape);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(stack_fuse_pass, paddle::framework::ir::StackFusePass);
REGISTER_PASS_CAPABILITY(stack_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"stack", 0));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
TEST(StackFusePass, basic) {
Layers layers;
auto* block = layers.Block();
auto* stack_x = layers.data("stack_x", {-1, 64, 64});
auto* stack_out = layers.stack({stack_x, stack_x, stack_x}, 1);
stack_out->SetShape({-1, 3, 64, 64});
auto* add_x = layers.data("add_x", {-1, 24, 64, 64});
layers.elementwise_add(add_x, stack_out);
OpDesc* fused_multi_transformer_op = block->AppendOp();
fused_multi_transformer_op->SetType("fused_multi_transformer");
fused_multi_transformer_op->SetInput("SrcMask", {stack_out->Name()});
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("stack_fuse_pass");
pass->Apply(graph.get());
auto stack_num = GetNumOpNodes(graph, "stack");
PADDLE_ENFORCE_EQ(stack_num,
0,
platform::errors::PreconditionNotMet(
"stack op should be removed from graph, but graph "
"still has %d stack op.",
stack_num));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(stack_fuse_pass);
......@@ -105,6 +105,9 @@ void IRPassManager::CreatePasses(Argument *argument,
new int(argument->mixed_precision_mode()));
pass->Set("model_precision", new int(argument->model_precision()));
// "use_xpu" is used for passes in subgraphs.
pass->Set("use_xpu", new bool(argument->use_xpu()));
if (pass_name == "graph_viz_pass") {
std::string optim_cache_dir = argument->optim_cache_dir();
std::string dot_file_path;
......@@ -260,7 +263,6 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("enable_int8", new bool(lite_enable_int8));
pass->Set("use_gpu", new bool(argument->use_gpu()));
pass->Set("zero_copy", new bool(argument->lite_zero_copy()));
pass->Set("use_xpu", new bool(argument->use_xpu()));
pass->Set("xpu_l3_workspace_size",
new int(argument->xpu_l3_workspace_size()));
pass->Set("use_opencl", new bool(argument->use_opencl()));
......
......@@ -1201,6 +1201,7 @@ void AnalysisPredictor::PrepareArgument() {
argument_->SetDlnnePrecisionMode(config_.dlnne_precision_mode_);
}
argument_->SetUseXpu(config_.use_xpu_);
if (config_.lite_engine_enabled()) {
argument_->SetCpuMathLibraryNumThreads(
config_.cpu_math_library_num_threads());
......@@ -1208,7 +1209,6 @@ void AnalysisPredictor::PrepareArgument() {
argument_->SetLitePassesFilter(config_.lite_passes_filter_);
argument_->SetLiteOpsFilter(config_.lite_ops_filter_);
argument_->SetLiteZeroCopy(config_.lite_zero_copy_);
argument_->SetUseXpu(config_.use_xpu_);
argument_->SetXpuL3WorkspaceSize(config_.xpu_l3_workspace_size_);
argument_->SetXpuLocked(config_.xpu_locked_);
argument_->SetXpuAutotune(config_.xpu_autotune_);
......@@ -1316,7 +1316,7 @@ void AnalysisPredictor::PrepareArgument() {
// processed in a single
if (model_precision_ != phi::DataType::FLOAT32) {
LOG(INFO) << "Model is mixed precision type with " << model_precision_
<< ", we will use a new PassStrategy. Note that only the GPU "
<< ", we will use a new PassStrategy. Note that only GPU/XPU "
"backend is supported for now.";
if (!config_.use_cinn_compiler_) {
const auto &deleted_passes = pass_builder->GetAllDeletedPasses();
......
......@@ -520,15 +520,18 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
passes_.assign({
"delete_dropout_op_pass",
"identity_scale_op_clean_pass",
"delete_op_device_pass",
"constant_folding_pass",
"generate_sequence_xpu_fuse_pass",
"embedding_with_eltwise_add_xpu_fuse_pass",
"multi_encoder_xpu_fuse_pass",
"multi_encoder_xpu_slice_fuse_pass",
"one_beam_size_fuse_pass",
"stack_fuse_pass",
"fused_multi_transformer_xpu_quant_pass",
"fc_xpu_fuse_pass",
"link_xpu_op_max_pass",
"delete_op_device_pass",
"inplace_op_var_pass",
"delete_isolated_node_pass",
});
use_xpu_ = true;
......
......@@ -27,6 +27,9 @@ void SqueezeInferKernel(const Context& dev_ctx,
DenseTensor* out) {
auto out_dims = out->dims();
dev_ctx.template Alloc<T>(out);
if (x.Holder() == out->Holder()) {
return;
}
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
out->Resize(out_dims); // copy will reset the dims.
}
......
......@@ -32,6 +32,9 @@ void UnsqueezeInferKernel(const Context& dev_ctx,
}
out->Resize(out_dims);
dev_ctx.template Alloc<T>(out);
if (x.Holder() == out->Holder()) {
return;
}
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
out->Resize(out_dims); // copy will reset the dims.
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册