Unverified · Commit 17d6d932 authored by wz1qqx, committed by GitHub

[XPU]fuse small ops of idg models (#54245)
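Summary: this change adds an XPU inference pass, fold_interp_outsize_fuse_pass, which folds the shape -> cast -> slice -> concat -> split -> cast chain that computes bilinear_interp_v2's OutSize into a direct read of the persistable concat input (cast to int32 at fuse time). It also extends the Layers::split test helper and adds the CastToInt32 utility that the pass relies on.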

Parent a087b9cb
@@ -252,6 +252,8 @@ if(WITH_XPU)
xpu DEPS ${XPU_PASS_DEPS})
pass_library(add_activation_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(fold_interp_outsize_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
endif()
cc_library(
@@ -536,4 +538,8 @@ if(WITH_XPU)
test_multi_encoder_xpu_adaptive_seqlen_fuse_pass
SRCS xpu/multi_encoder_xpu_adaptive_seqlen_fuse_pass_test.cc
DEPS multi_encoder_xpu_adaptive_seqlen_fuse_pass)
cc_test(
test_fold_interp_outsize_fuse_pass
SRCS xpu/fold_interp_outsize_fuse_pass_test.cc
DEPS fold_interp_outsize_fuse_pass)
endif()
@@ -361,20 +361,31 @@ struct Layers {
return outs;
}
  std::vector<VarDesc*> split(VarDesc* x,
                              int num_or_section = 0,
                              int axis = 0,
                              std::vector<int> sections = {-1}) {
    // When num_or_section is 0, the number of outputs comes from the
    // explicit sections list; otherwise the input is split evenly.
    int out_num = num_or_section;
    if (num_or_section == 0) {
      out_num = sections.size();
    }
    std::vector<VarDesc*> outs(out_num);
    for (int i = 0; i < out_num; i++) {
      outs[i] = lod_tensor(unique_name());
    }
    std::vector<std::string> out_names(out_num);
    for (int i = 0; i < out_num; i++) {
      out_names[i] = outs[i]->Name();
    }
    OpDesc* op = program_.MutableBlock(0)->AppendOp();
    op->SetType("split");
    op->SetInput("X", {x->Name()});
    op->SetOutput("Out", out_names);
    if (num_or_section == 0) {
      op->SetAttr("sections", sections);
    } else {
      op->SetAttr("num_or_section", num_or_section);
    }
    op->SetAttr("axis", axis);
    op->SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
                static_cast<int>(OpRole::kForward));
......
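For orientation, a minimal sketch of the two calling conventions the updated helper now supports; the layers and x names are assumed to come from a Layers-based test, as elsewhere in pass_tester_helper.h:

    // Even split: sets the "num_or_section" attribute on the split op.
    std::vector<VarDesc*> by_num = layers.split(x, /*num_or_section=*/2);
    // Sectioned split: passing num_or_section == 0 sets "sections" instead,
    // which is what the new fold_interp_outsize_fuse_pass test relies on.
    std::vector<VarDesc*> by_sections =
        layers.split(x, /*num_or_section=*/0, /*axis=*/0, /*sections=*/{2, 2});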
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/xpu/fold_interp_outsize_fuse_pass.h"
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct DetectorFusePattern : public PatternBase {
DetectorFusePattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(shape);
PATTERN_DECL_NODE(cast1);
PATTERN_DECL_NODE(slice);
PATTERN_DECL_NODE(concat);
PATTERN_DECL_NODE(split);
PATTERN_DECL_NODE(cast2);
PATTERN_DECL_NODE(bilinear_interp);
// declare variable node's name
PATTERN_DECL_NODE(x);
PATTERN_DECL_NODE(shape_out);
PATTERN_DECL_NODE(cast1_out);
PATTERN_DECL_NODE(slice_out);
PATTERN_DECL_NODE(concat_y);
PATTERN_DECL_NODE(concat_out);
PATTERN_DECL_NODE(split_out_0);
PATTERN_DECL_NODE(split_out_1);
PATTERN_DECL_NODE(cast2_out);
};
DetectorFusePattern::DetectorFusePattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* x = pattern->NewNode(x_repr())
->assert_is_op_input("shape", "Input")
->assert_is_op_input("bilinear_interp_v2", "X");
auto* shape = pattern->NewNode(shape_repr())->assert_is_op("shape");
auto* shape_out = pattern->NewNode(shape_out_repr())
->assert_is_op_output("shape", "Out")
->assert_is_op_input("cast", "X");
shape->LinksFrom({x}).LinksTo({shape_out});
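  // Dtype codes below follow framework.proto: 2 is INT32 and 3 is INT64,
  // so cast1 matches an INT32 -> INT64 cast and cast2 the reverse.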
auto* cast1 = pattern->NewNode(cast1_repr())
->assert_is_op("cast")
->assert_more([&](Node* node) {
auto* op_desc = node->Op();
return op_desc->GetAttrIfExists<int>("in_dtype") == 2 &&
op_desc->GetAttrIfExists<int>("out_dtype") == 3;
});
auto* cast1_out = pattern->NewNode(cast1_out_repr())
->assert_is_op_output("cast", "Out")
->assert_is_op_input("slice", "Input");
cast1->LinksFrom({shape_out}).LinksTo({cast1_out});
auto* slice =
pattern->NewNode(slice_repr())
->assert_is_op("slice")
->assert_more([&](Node* node) {
auto* op_desc = node->Op();
return op_desc->GetAttrIfExists<std::vector<int>>("axes") ==
std::vector<int>{0} &&
op_desc->GetAttrIfExists<std::vector<int>>("starts") ==
std::vector<int>{0} &&
op_desc->GetAttrIfExists<std::vector<int>>("ends") ==
std::vector<int>{2};
});
auto* slice_out = pattern->NewNode(slice_out_repr())
->assert_is_op_output("slice", "Out")
->assert_is_op_nth_input("concat", "X", 0);
slice->LinksFrom({cast1_out}).LinksTo({slice_out});
auto* concat = pattern->NewNode(concat_repr())
->assert_is_op("concat")
->assert_more([&](Node* node) {
auto* op_desc = node->Op();
return op_desc->GetAttrIfExists<int>("axis") == 0;
});
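  // The second concat input must be persistable (a weight), so its value
  // holds the target output size and is already known at fuse time.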
auto* concat_y = pattern->NewNode(concat_y_repr())
->assert_is_op_nth_input("concat", "X", 1)
->assert_is_persistable_var();
auto* concat_out = pattern->NewNode(concat_out_repr())
->assert_is_op_output("concat", "Out")
->assert_is_op_input("split", "X");
concat->LinksFrom({slice_out, concat_y}).LinksTo({concat_out});
auto* split = pattern->NewNode(split_repr())
->assert_is_op("split")
->assert_more([&](Node* node) {
auto* op_desc = node->Op();
return op_desc->GetAttrIfExists<int>("axis") == 0 &&
(op_desc->GetAttrIfExists<std::vector<int>>(
"sections") == std::vector<int>{2, 2} ||
op_desc->GetAttrIfExists<int>("num") == 2);
});
auto* split_out_0 = pattern->NewNode(split_out_0_repr())
->assert_is_op_nth_output("split", "Out", 0);
auto* split_out_1 = pattern->NewNode(split_out_1_repr())
->assert_is_op_nth_output("split", "Out", 1)
->assert_is_op_input("cast", "X");
split->LinksFrom({concat_out}).LinksTo({split_out_0, split_out_1});
auto* cast2 = pattern->NewNode(cast2_repr())
->assert_is_op("cast")
->assert_more([&](Node* node) {
auto* op_desc = node->Op();
return op_desc->GetAttrIfExists<int>("in_dtype") == 3 &&
op_desc->GetAttrIfExists<int>("out_dtype") == 2;
});
auto* cast2_out = pattern->NewNode(cast2_out_repr())
->assert_is_op_output("cast", "Out")
->assert_is_op_input("bilinear_interp_v2", "OutSize");
cast2->LinksFrom({split_out_1}).LinksTo({cast2_out});
auto* bilinear_interp = pattern->NewNode(bilinear_interp_repr())
->assert_is_op("bilinear_interp_v2");
bilinear_interp->LinksFrom({x, cast2_out});
}
} // namespace patterns
void FoldInterpOutsizeFusePass::DetectorFuse(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::DetectorFusePattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle DetectorFuse";
/* declare operator node's name */
GET_IR_NODE(shape);
GET_IR_NODE(cast1);
GET_IR_NODE(slice);
GET_IR_NODE(concat);
GET_IR_NODE(split);
GET_IR_NODE(cast2);
GET_IR_NODE(bilinear_interp);
/* declare variable node's name*/
GET_IR_NODE(x);
GET_IR_NODE(shape_out);
GET_IR_NODE(cast1_out);
GET_IR_NODE(slice_out);
GET_IR_NODE(concat_y);
GET_IR_NODE(concat_out);
GET_IR_NODE(split_out_0);
GET_IR_NODE(split_out_1);
GET_IR_NODE(cast2_out);
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
auto* concat_y_t =
scope->GetVar(concat_y->Name())->GetMutable<phi::DenseTensor>();
// concat_y int64 --> int32
auto tensor_type = concat_y_t->dtype();
if (tensor_type == phi::DataType::INT64) {
CastToInt32(concat_y_t, nullptr);
}
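    // Rewire bilinear_interp_v2 to read OutSize directly from concat_y:
    // after concat([N, C], concat_y) is split back into {2, 2}, the second
    // half equals concat_y, so the whole computed chain is redundant.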
bilinear_interp->Op()->RenameInput(cast2_out->Name(), concat_y->Name());
IR_NODE_UNLINK(x, shape);
IR_NODE_UNLINK(cast2_out, bilinear_interp);
IR_NODE_LINK_TO(concat_y, bilinear_interp);
// delete useless node
std::unordered_set<const Node*> delete_nodes = {shape,
cast1,
slice,
concat,
split,
cast2,
shape_out,
cast1_out,
slice_out,
concat_out,
split_out_0,
split_out_1,
cast2_out};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
AddStatis(found_subgraph_count);
}
void FoldInterpOutsizeFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
DetectorFuse(graph);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fold_interp_outsize_fuse_pass,
paddle::framework::ir::FoldInterpOutsizeFusePass);
REGISTER_PASS_CAPABILITY(fold_interp_outsize_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"shape", 0));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
class FoldInterpOutsizeFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
  /*
  Origin subgraph:
            x
          /   \
         |   shape
         |     |
         |    cast
         |     |
         |   slice
         |     |
         |   concat
         |     |
         |   split
         |    |    \
         |    |     \
         | outvar_1  outvar_0
         |    |
         |   cast
         |   /
          \ /
    bilinear_interp_v2

  Fused subgraph:
            x
            |  concat_y
            |  /
    bilinear_interp_v2
  */
void DetectorFuse(ir::Graph* graph) const;
const std::string name_scope_{"fold_interp_outsize_fuse_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
TEST(DetectorFuse, basic) {
Layers layers;
auto* block = layers.Block();
auto* shape_x = layers.data("shape_x", {1, 18, 288, 288});
auto* concat_y =
layers.data("concat_y", {576, 576}, true, proto::VarType::INT64);
auto* shape_out = layers.shape(shape_x);
auto* cast1_out = layers.cast(shape_out, 2, 3);
auto* slice_out = layers.slice(cast1_out, {0}, {0}, {2});
auto* concat_out = layers.concat({slice_out, concat_y}, 0);
auto split_outs = layers.split(concat_out, 0, 0, {2, 2});
auto* split_out_1 = split_outs[1];
auto* cast2_out = layers.cast(split_out_1, 3, 2);
OpDesc* bilinear_interp_v2_op = block->AppendOp();
bilinear_interp_v2_op->SetType("bilinear_interp_v2");
bilinear_interp_v2_op->SetInput("X", {shape_x->Name()});
bilinear_interp_v2_op->SetInput("OutSize", {cast2_out->Name()});
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("fold_interp_outsize_fuse_pass");
pass->Apply(graph.get());
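  // After the fuse, only the bilinear_interp_v2 op should remain.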
auto ops_num = GetNumOpNodes(graph);
  PADDLE_ENFORCE_EQ(
      ops_num,
      1,
      platform::errors::PreconditionNotMet(
          "graph should only have 1 op node, but received %d.", ops_num));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(fold_interp_outsize_fuse_pass);
@@ -70,6 +70,39 @@ void Transpose2D(phi::DenseTensor* in, phi::DenseTensor* out) {
}
}
void CastToInt32(phi::DenseTensor* in, phi::DenseTensor* out) {
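  // out == nullptr requests an in-place cast: the result is built in a
  // temporary tensor and assigned back to in before returning.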
auto* cpu_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
phi::DenseTensor int32_tensor;
phi::DenseTensor* out_ptr = out == nullptr ? &int32_tensor : out;
out_ptr->Resize(in->dims());
out_ptr->set_type(phi::DataType::INT32);
out_ptr->set_layout(in->layout());
switch (in->dtype()) {
case phi::DataType::INT64:
phi::CastKernel<int64_t>(*cpu_ctx, *in, phi::DataType::INT32, out_ptr);
break;
case phi::DataType::INT32:
if (out == nullptr) {
return;
} else {
phi::AssignKernel(*cpu_ctx, *in, out_ptr);
}
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support int64 and int32, but received dtype is %s.",
phi::DataTypeToString(in->dtype())));
break;
}
if (out == nullptr) {
Assign(*out_ptr, in);
}
}
void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out) {
auto* cpu_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
......
@@ -25,6 +25,8 @@ void Transpose2D(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
void CastToFp32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
void CastToInt32(phi::DenseTensor* in, phi::DenseTensor* out = nullptr);
// 1. Quant weight from fp32 to int16/int31
// 2. Weight data is updated in-place.
// 3. Generate weight max tensor
......
@@ -522,6 +522,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"multi_encoder_xpu_slice_fuse_pass",
"fused_multi_transformer_cachekv_layout_trans_pass",
"one_beam_size_fuse_pass",
"fold_interp_outsize_fuse_pass",
"delete_cast_op_pass",
"stack_fuse_pass",
"fused_multi_transformer_xpu_pass",
......
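With this registration the new pass runs by default in the XPU inference pass pipeline, between one_beam_size_fuse_pass and delete_cast_op_pass; placing it before delete_cast_op_pass presumably lets any casts the fusion leaves behind be cleaned up afterwards.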