Unverified · Commit d13a49d6 authored by jiangfan06, committed by GitHub

[XPU] Add gather_squeeze_pass (#55605)

Parent 9429ec48
@@ -280,6 +280,7 @@ if(WITH_XPU)
  pass_library(matmul_weight_trans_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
  pass_library(reshape2_matmul_xpu_fuse_pass inference DIR xpu DEPS
               ${XPU_PASS_DEPS})
  pass_library(gather_squeeze_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
  pass_library(fast_where_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
endif()
......
@@ -225,6 +225,17 @@ std::string GenAddAttrKey(Node* add_op_node) {
  return x_name + "_" + y_name + "_axis_" + std::to_string(axis);
}

std::string GenTranspose2AttrKey(Node* transpose_op_node) {
  auto transpose_op_desc = transpose_op_node->Op();
  auto axis = transpose_op_desc->GetAttrIfExists<std::vector<int>>("axis");
  std::string attr_key;
  attr_key += "axis_";
  for (auto x : axis) {
    attr_key += std::to_string(x) + "_";
  }
  return attr_key;
}

std::string GenScaleAttrKey(Node* scale_op_node) {
  auto scale_op_desc = scale_op_node->Op();
  auto scale = scale_op_desc->GetAttrIfExists<float>("scale");
@@ -274,6 +285,7 @@ void DeleteRepeatedOpsPass::ApplyImpl(ir::Graph* graph) const {
  DeleteRepeatedOps(graph, "gather", GenGatherAttrKey);
  DeleteRepeatedOps(graph, "squeeze2", GenSqueeze2AttrKey);
  DeleteRepeatedOps(graph, "unsqueeze2", GenSqueeze2AttrKey);
  DeleteRepeatedOps(graph, "transpose2", GenTranspose2AttrKey);
  LOG(INFO) << "Round " << repeat_time++
            << ": delete op counts: " << delete_op_count;
  total_delete_op_count += delete_op_count;
......
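Reviewer note: the key produced by GenTranspose2AttrKey above lets delete_repeated_ops_pass treat two transpose2 ops reading the same input with the same axis list as duplicates. A small Python mirror of the key format (illustrative only; gen_transpose2_attr_key is a hypothetical helper, not part of the commit):

    # Hypothetical mirror of GenTranspose2AttrKey, for illustration only:
    # transpose2 with axis = [0, 2, 1] maps to the key "axis_0_2_1_".
    def gen_transpose2_attr_key(axis):
        return "axis_" + "".join(str(a) + "_" for a in axis)

    assert gen_transpose2_attr_key([0, 2, 1]) == "axis_0_2_1_"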
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/xpu/gather_squeeze_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct GatherSqueeze : public PatternBase {
  GatherSqueeze(PDPattern* pattern, const std::string& name_scope);

  // declare operator node's name
  PATTERN_DECL_NODE(gather);
  PATTERN_DECL_NODE(squeeze2);
  // declare variable node's name
  PATTERN_DECL_NODE(gather_in);
  PATTERN_DECL_NODE(gather_index);
  PATTERN_DECL_NODE(gather_out);
};  // struct GatherSqueeze

GatherSqueeze::GatherSqueeze(PDPattern* pattern, const std::string& name_scope)
    : PatternBase(pattern, name_scope, name_scope) {
  auto* gather_in = pattern->NewNode(gather_in_repr())
                        ->assert_is_op_input("gather", "X")
                        ->assert_more([](Node* node) {
                          for (auto* op : node->outputs) {
                            if (op->Op()->Type() != "gather") {
                              return false;
                            }
                          }
                          return node->outputs.size() >= 2 &&
                                 node->Var()->GetShape().size() >= 2 &&
                                 node->Var()->GetShape().size() < 5;
                        });
  auto* gather_index = pattern->NewNode(gather_index_repr())
                           ->assert_is_op_input("gather", "Index")
                           ->assert_more([](Node* node) {
                             auto shape = node->Var()->GetShape();
                             return shape.size() == 1 && shape[0] == 1;
                           });
  auto* gather = pattern->NewNode(gather_repr())->assert_is_op("gather");
  auto* gather_out = pattern->NewNode(gather_out_repr())
                         ->assert_is_op_output("gather", "Out")
                         ->assert_is_op_input("squeeze2", "X");
  auto* squeeze2 = pattern->NewNode(squeeze2_repr())->assert_is_op("squeeze2");

  gather->LinksFrom({gather_in, gather_index}).LinksTo({gather_out});
  gather_out->LinksTo({squeeze2});
}
} // namespace patterns
void GatherSqueezePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);
  AddTranspose(graph);
}

void GatherSqueezePass::AddTranspose(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  GraphPatternDetector gpd;
  patterns::GatherSqueeze pattern(gpd.mutable_pattern(), name_scope_);
  int found_subgraph_count = 0;

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    VLOG(4) << "handle GatherSqueezePass";
    GET_IR_NODE(gather);
    GET_IR_NODE(gather_in);
    GET_IR_NODE(gather_out);
    GET_IR_NODE(squeeze2);

    bool flag = true;
    auto var_dims = static_cast<int32_t>(gather_in->Var()->GetShape().size());
    auto gather_axis = gather->Op()->GetAttrIfExists<int>("axis");
    auto squeeze_axis =
        squeeze2->Op()->GetAttrIfExists<std::vector<int>>("axes");

    flag = flag && gather_axis == var_dims - 1;
    flag = flag && (squeeze_axis == std::vector<int>{-1} ||
                    squeeze_axis == std::vector<int>{var_dims - 1});

    if (flag) {
      gather->Op()->SetAttr("axis", 0);
      squeeze2->Op()->SetAttr("axes", std::vector<int>{0});

      std::string transpose2_out_name =
          patterns::PDNodeName(name_scope_, "transpose2out");
      VarDesc transpose2_out_vardesc(transpose2_out_name);
      OpDesc transpose2_op_desc(gather->Op()->Block());
      auto gather_in_shape = gather_in->Var()->GetShape();
      auto gather_out_shape = gather_out->Var()->GetShape();

      transpose2_out_vardesc.SetDataType(gather_in->Var()->GetDataType());
      if (var_dims == 2) {
        gather_out->Var()->SetShape({gather_out_shape[1], gather_out_shape[0]});
        transpose2_out_vardesc.SetShape(
            {gather_in_shape[1], gather_in_shape[0]});
        transpose2_op_desc.SetAttr("axis", std::vector<int>{1, 0});
      } else if (var_dims == 3) {
        gather_out->Var()->SetShape(
            {gather_out_shape[2], gather_out_shape[0], gather_out_shape[1]});
        transpose2_out_vardesc.SetShape(
            {gather_in_shape[2], gather_in_shape[0], gather_in_shape[1]});
        transpose2_op_desc.SetAttr("axis", std::vector<int>{2, 0, 1});
      } else {
        gather_out->Var()->SetShape({gather_out_shape[3],
                                     gather_out_shape[0],
                                     gather_out_shape[1],
                                     gather_out_shape[2]});
        transpose2_out_vardesc.SetShape({gather_in_shape[3],
                                         gather_in_shape[0],
                                         gather_in_shape[1],
                                         gather_in_shape[2]});
        transpose2_op_desc.SetAttr("axis", std::vector<int>{3, 0, 1, 2});
      }

      auto* transpose2_out = graph->CreateVarNode(&transpose2_out_vardesc);
      transpose2_op_desc.SetType("transpose2");
      transpose2_op_desc.SetInput("X", {gather_in->Name()});
      transpose2_op_desc.SetOutput("Out", {transpose2_out->Name()});
      auto* transpose2 = graph->CreateOpNode(&transpose2_op_desc);

      gather->Op()->SetInput("X", {transpose2_out->Name()});
      IR_NODE_UNLINK(gather_in, gather);
      IR_NODE_LINK_TO(gather_in, transpose2);
      IR_NODE_LINK_TO(transpose2, transpose2_out);
      IR_NODE_LINK_TO(transpose2_out, gather);

      found_subgraph_count++;
    }
  };

  gpd(graph, handler);
  AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(gather_squeeze_pass, paddle::framework::ir::GatherSqueezePass);
REGISTER_PASS_CAPABILITY(gather_squeeze_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .EQ("gather", 1)
            .EQ("squeeze2", 0));
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
class GatherSqueezePass : public FusePassBase {
 protected:
  void ApplyImpl(ir::Graph* graph) const override;

 private:
  /*
  Add transpose2 before gather + squeeze2.
  For example (rank-3 input):
  graph:
        x
        |
      gather (axis = -1)
        |
      squeeze2 (axes = {-1})
        |
      output
  ------------------------------------------------------
  After the pass is applied:
        x
        |
      transpose2 (axis = {2, 0, 1})
        |
      gather (axis = 0)
        |
      squeeze2 (axes = {0})
        |
      output
  */
  void AddTranspose(ir::Graph* graph) const;

  const std::string name_scope_{"gather_squeeze_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
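Reviewer note: the rewrite described in the header comment can be sanity-checked numerically. A minimal NumPy sketch (illustrative only, not part of the commit), assuming a rank-3 input and a length-1 index:

    import numpy as np

    x = np.random.rand(2, 3, 4).astype(np.float32)
    idx = [1]

    # Original subgraph: gather on the last axis, then squeeze it away.
    before = np.squeeze(np.take(x, idx, axis=-1), axis=-1)

    # Rewritten subgraph: transpose2 (axis = {2, 0, 1}), then gather and
    # squeeze on axis 0.
    t = np.transpose(x, (2, 0, 1))
    after = np.squeeze(np.take(t, idx, axis=0), axis=0)

    assert np.allclose(before, after)  # both paths yield a (2, 3) result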
@@ -509,6 +509,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
      "delete_assign_op_pass",
      "delete_dropout_op_pass",
      "delete_concat_op_pass",
      "gather_squeeze_pass",
      "delete_repeated_ops_pass",
      "identity_op_clean_pass",
      "fused_continuous_same_ops_pass",
......
@@ -26,7 +26,7 @@ XPUOpMap& get_kl2_ops() {
     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
    {"add_layernorm_xpu",
     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
-   {"abs", XPUKernelSet({phi::DataType::FLOAT32})},
+   {"abs", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
    {"abs_grad",
     XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
    {"accuracy", XPUKernelSet({phi::DataType::FLOAT32})},
......
@@ -26,16 +26,18 @@ void AbsGradKernel(const Context& ctx,
                    const DenseTensor& dout,
                    DenseTensor* dx) {
  ctx.template Alloc<T>(dx);
+ using XPUType = typename XPUTypeTrait<T>::Type;
  int r = xpu::abs_grad(ctx.x_context(),
-                       x.data<T>(),
-                       dout.data<T>(),
-                       dout.data<T>(),
-                       dx->data<T>(),
+                       reinterpret_cast<const XPUType*>(x.data<T>()),
+                       reinterpret_cast<const XPUType*>(dout.data<T>()),
+                       reinterpret_cast<const XPUType*>(dout.data<T>()),
+                       reinterpret_cast<XPUType*>(dx->data<T>()),
                        x.numel());
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "abs_grad");
}

}  // namespace phi

-PD_REGISTER_KERNEL(abs_grad, XPU, ALL_LAYOUT, phi::AbsGradKernel, float) {
+PD_REGISTER_KERNEL(
+    abs_grad, XPU, ALL_LAYOUT, phi::AbsGradKernel, float, phi::dtype::float16) {
  kernel->InputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
@@ -22,9 +22,14 @@ namespace phi {

template <typename T, typename Context>
void AbsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
  ctx.template Alloc<T>(out);
- int r = xpu::abs(ctx.x_context(), x.data<T>(), out->data<T>(), x.numel());
+ using XPUType = typename XPUTypeTrait<T>::Type;
+ int r = xpu::abs<XPUType>(ctx.x_context(),
+                           reinterpret_cast<const XPUType*>(x.data<T>()),
+                           reinterpret_cast<XPUType*>(out->data<T>()),
+                           x.numel());
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "abs");
}

}  // namespace phi

-PD_REGISTER_KERNEL(abs, XPU, ALL_LAYOUT, phi::AbsKernel, float) {}
+PD_REGISTER_KERNEL(
+    abs, XPU, ALL_LAYOUT, phi::AbsKernel, float, phi::dtype::float16) {}
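Reviewer note: with both registrations in place, float16 abs should work end to end on XPU. A hedged usage sketch (assumes a Paddle build with XPU support; not part of the commit):

    import paddle

    paddle.set_device("xpu")  # requires an XPU build of Paddle

    x = paddle.randn([4, 4]).astype("float16")
    x.stop_gradient = False

    y = paddle.abs(x)   # dispatches to the new FP16 abs kernel
    y.sum().backward()  # abs_grad now has an FP16 registration too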
@@ -727,5 +727,90 @@ class TestDeleteRepeatedGatherPass(PassAutoScanTest):
        )
class TestDeleteRepeatedTransposePass(PassAutoScanTest):
    def sample_predictor_configs(self, program_config):
        config = self.create_inference_config(use_xpu=True)
        yield config, ['transpose2', 'relu', 'relu', 'relu'], (1e-5, 1e-5)

    def sample_program_config(self, draw):
        batch_size = draw(st.integers(min_value=1, max_value=4))
        H = draw(st.integers(min_value=1, max_value=64))
        W = draw(st.integers(min_value=1, max_value=64))
        in_shape = [batch_size, H, W]
        axis = [0, 2, 1]

        transpose_op0 = OpConfig(
            type='transpose2',
            inputs={
                "X": ["transpose_x"],
            },
            outputs={"Out": ["transpose_output0"]},
            attrs={"axis": axis},
        )
        relu_op0 = OpConfig(
            "relu",
            inputs={
                "X": ["transpose_output0"],
            },
            outputs={"Out": ["relu0_out"]},
        )
        transpose_op1 = OpConfig(
            type='transpose2',
            inputs={
                "X": ["transpose_x"],
            },
            outputs={"Out": ["transpose_output1"]},
            attrs={"axis": axis},
        )
        relu_op1 = OpConfig(
            "relu",
            inputs={
                "X": ["transpose_output1"],
            },
            outputs={"Out": ["relu1_out"]},
        )
        transpose_op2 = OpConfig(
            type='transpose2',
            inputs={
                "X": ["transpose_x"],
            },
            outputs={"Out": ["transpose_output2"]},
            attrs={"axis": axis},
        )
        relu_op2 = OpConfig(
            "relu",
            inputs={
                "X": ["transpose_output2"],
            },
            outputs={"Out": ["relu2_out"]},
        )

        ops = [
            transpose_op0,
            relu_op0,
            transpose_op1,
            relu_op1,
            transpose_op2,
            relu_op2,
        ]
        program_config = ProgramConfig(
            ops=ops,
            weights={},
            inputs={
                "transpose_x": TensorConfig(shape=in_shape),
            },
            outputs=["relu0_out", "relu1_out", "relu2_out"],
        )
        return program_config

    def test(self):
        self.run_and_statis(
            quant=False,
            max_examples=25,
            passes=["delete_repeated_ops_pass"],
        )
if __name__ == "__main__":
    unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestGatherAddTransposePass(PassAutoScanTest):
    def sample_predictor_configs(self, program_config):
        config = self.create_inference_config(use_xpu=True)
        yield config, [
            "transpose2",
            "gather",
            "transpose2",
            "gather",
            "squeeze2",
            "squeeze2",
        ], (1e-3, 1e-3)

    def sample_program_config(self, draw):
        x_shape = draw(
            st.lists(
                st.integers(min_value=1, max_value=4), min_size=3, max_size=3
            )
        )

        def generate_data(shape):
            return np.random.random(shape).astype(np.float32)

        def generate_index(*args, **kwargs):
            return np.array([0]).astype(np.int64)

        axis = 2
        axes = [2]
        gather_op0 = OpConfig(
            "gather",
            inputs={"X": ["gather_in"], "Index": ["gather_index0"]},
            outputs={"Out": ["gather_out0"]},
            axis=axis,
        )
        gather_op1 = OpConfig(
            "gather",
            inputs={"X": ["gather_in"], "Index": ["gather_index1"]},
            outputs={"Out": ["gather_out1"]},
            axis=axis,
        )
        squeeze_op0 = OpConfig(
            "squeeze2",
            inputs={
                "X": ["gather_out0"],
            },
            outputs={"Out": ["squeeze_out0"]},
            axes=axes,
        )
        squeeze_op1 = OpConfig(
            "squeeze2",
            inputs={
                "X": ["gather_out1"],
            },
            outputs={"Out": ["squeeze_out1"]},
            axes=axes,
        )

        ops = [gather_op0, gather_op1, squeeze_op0, squeeze_op1]
        program_config = ProgramConfig(
            ops=ops,
            inputs={
                "gather_in": TensorConfig(
                    data_gen=partial(generate_data, x_shape)
                ),
                "gather_index0": TensorConfig(data_gen=partial(generate_index)),
                "gather_index1": TensorConfig(data_gen=partial(generate_index)),
            },
            weights={},
            outputs=["squeeze_out0", "squeeze_out1"],
        )
        return program_config

    def test(self):
        self.run_and_statis(
            quant=False, max_examples=25, passes=["gather_squeeze_pass"]
        )


if __name__ == "__main__":
    unittest.main()