Unverified · Commit 7c8c9b7d · Authored by: leolishaohao · Committed by: GitHub

[XPU] add squeeze_excitation_block_xpu op&pass to optimize ppocr_v3_det model (#56773)

* [XPU] add squeeze_excitation_block_xpu op&pass to optimize ppocr_v3_det model test=kunlun

* fix

* fix CodeStyle

* remove xpu name
Parent c170074d
@@ -290,6 +290,8 @@ if(WITH_XPU)
  pass_library(fast_where_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
  pass_library(fast_layernorm_xpu_fuse_pass inference DIR xpu DEPS
               ${XPU_PASS_DEPS})
  pass_library(squeeze_excitation_fuse_pass inference DIR xpu DEPS
               ${XPU_PASS_DEPS})
  pass_library(elementwise_mul_add_fuse_pass inference DIR xpu DEPS
               ${XPU_PASS_DEPS})
endif()
@@ -615,4 +617,8 @@ if(WITH_XPU)
    test_fast_where_xpu_fuse_pass
    SRCS xpu/fast_where_xpu_fuse_pass_test.cc
    DEPS fast_where_xpu_fuse_pass)
  cc_test(
    test_squeeze_excitation_fuse_pass
    SRCS xpu/squeeze_excitation_fuse_pass_test.cc
    DEPS squeeze_excitation_fuse_pass)
endif()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
/*
Squeeze-and-Excitation Block Fusion for SE-ResNet
Origin subgraph:

          Input
          /    \
         |      \
         |       \
         |        |
         |  Global Pooling
         |        |
         |   conv2d_xpu
         |        |
         |        |
         |   conv2d_xpu
          \       |
           \      |
        elementwise_mul
               |
            Output

------------------------------------------------------
After the pass is applied:

                    in_Input
      in_Filter      |      in_FilterMax
               \     |     /
                \    |    /
  in_Branch ------ squeeze_excitation_block ------ in_Bias
                     |
                     |
                     |
                 out_Output
*/
class SqueezeExcitationFusePass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
int ApplyImpl(ir::Graph* graph,
const std::string& op_type,
const std::string& act_type,
bool with_branch,
bool with_bias) const;
const std::string name_scope_{"squeeze_excitation_fuse_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
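The two ApplyImpl overloads above follow the usual XPU fuse-pass shape: the public override enumerates pattern variants while the private overload matches and rewrites one variant, returning a count. A minimal sketch of that dispatch, assuming exactly that contract (the variant lists below are illustrative, not taken from this commit):

// Illustrative only: a plausible body for the public override declared above,
// assuming the private overload returns the number of subgraphs it fused.
void SqueezeExcitationFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph cannot be nullptr."));
  Init(name_scope_, graph);

  int found_subgraph_count = 0;
  // Try every combination the private overload can match.
  for (auto with_branch : {true, false}) {
    for (auto with_bias : {true, false}) {
      for (auto act_type : {"relu", "sigmoid", "hard_sigmoid", ""}) {
        found_subgraph_count += ApplyImpl(
            graph, /*op_type=*/"conv2d_xpu", act_type, with_branch, with_bias);
      }
    }
  }
  AddStatis(found_subgraph_count);
}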
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
TEST(SqueezeExcitationFusePass, V1) {
Layers layers;
auto* block = layers.Block();
auto* pool2d_inp = layers.data("pool2d_inp", {1, 24, 14, 14});
auto* pool2d_out = layers.pool2d(pool2d_inp, false);
auto* conv2d_xpu_op1_out = layers.data("conv2d_xpu_op1_out");
OpDesc* conv2d_xpu_op1 = block->AppendOp();
conv2d_xpu_op1->SetType("conv2d_xpu");
conv2d_xpu_op1->SetInput("x", {pool2d_out->Name()});
conv2d_xpu_op1->SetOutput("out", {conv2d_xpu_op1_out->Name()});
auto* conv2d_xpu_op2_out = layers.data("conv2d_xpu_op2_out");
OpDesc* conv2d_xpu_op2 = block->AppendOp();
conv2d_xpu_op2->SetType("conv2d_xpu");
conv2d_xpu_op2->SetInput("x", {conv2d_xpu_op1_out->Name()});
conv2d_xpu_op2->SetOutput("out", {conv2d_xpu_op2_out->Name()});
layers.elementwise_mul(pool2d_inp, conv2d_xpu_op2_out);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
auto pass = PassRegistry::Instance().Get("squeeze_excitation_fuse_pass");
pass->Apply(graph.get());
auto num = GetNumOpNodes(graph, "pool2d") +
GetNumOpNodes(graph, "conv2d_xpu") +
GetNumOpNodes(graph, "elementwise_mul");
PADDLE_ENFORCE_EQ(num,
0,
platform::errors::PreconditionNotMet(
"pool2d/conv2d_xpu/elementwise_mul ops should be "
"removed from graph, but graph "
"still has %d ops. ",
num));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(squeeze_excitation_fuse_pass);
@@ -547,6 +547,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
      "fc_xpu_fuse_pass",
      "conv2d_xpu_fuse_pass",
      "conv2d_transpose_xpu_fuse_pass",
      "squeeze_excitation_fuse_pass",
      "add_activation_xpu_fuse_pass",
      "add_layernorm_xpu_fuse_pass",
      "fast_layernorm_xpu_fuse_pass",
......
@@ -208,6 +208,16 @@
    data_type : input
    optional : bias_qk

- op : squeeze_excitation_block
  args : (Tensor x, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, int[] act_type, float[] act_param, int[] filter_dims)
  output : Tensor(out)
  infer_meta :
    func : SqueezeExcitationInferMeta
  kernel :
    func : squeeze_excitation_block
    data_type : x
  optional : bias, branch

- op : yolo_box_xpu
  args : (Tensor x, Tensor x_max, Tensor grid, Tensor stride, Tensor anchor_grid, float offset)
  output : Tensor(out), Tensor(out_max)
......
@@ -1005,6 +1005,7 @@ XPUOpMap& get_kl2_ops() {
    {"sequence_conv_grad", XPUKernelSet({phi::DataType::FLOAT32})},
    {"sequence_unpad", XPUKernelSet({phi::DataType::FLOAT32})},
    // Fused op
    {"squeeze_excitation_block", XPUKernelSet({phi::DataType::FLOAT32})},
    {"resnet_basic_block_grad", XPUKernelSet({phi::DataType::FLOAT32})},
    {"resnet_basic_block", XPUKernelSet({phi::DataType::FLOAT32})},
    {"fused_gemm_epilogue",
......
@@ -964,4 +964,29 @@ void FusedScaleBiasReluConvBnstatsInferMeta(
  eq_bias->set_dims(c_dims);
}
void SqueezeExcitationInferMeta(const MetaTensor& x,
const MetaTensor& filter,
const MetaTensor& filter_max,
const MetaTensor& bias,
const MetaTensor& branch,
const std::vector<int>& act_type,
const std::vector<float>& act_param,
const std::vector<int>& filter_dims,
MetaTensor* out) {
auto in_dims = x.dims();
// do some checks
PADDLE_ENFORCE_EQ(
in_dims.size(),
4,
phi::errors::InvalidArgument(
"The input should be a 4-D Tensor. But "
"received: input's dimension is %u, input's shape is [%s].",
in_dims.size(),
in_dims));
std::vector<int64_t> out_shape(
{in_dims[0], filter_dims[1], in_dims[2], in_dims[3]});
// set output dims
out->set_dims(DDim(out_shape.data(), out_shape.size()));
}
} // namespace phi
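For reference, the rule above reduces to out = [n, filter_dims[1], h, w]. A self-contained restatement (the function name is ours; filter_dims is assumed to hold {reduction_ratio, out_channels}, which matches how the XPU kernel below computes its bias2 offset):

#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors SqueezeExcitationInferMeta's shape rule: channels come from the
// packed filter description, batch and spatial dims pass through unchanged.
std::vector<int64_t> SqueezeExcitationOutShape(
    const std::vector<int64_t>& x_dims,     // NCHW input dims
    const std::vector<int>& filter_dims) {  // assumed {r, out_channels}
  assert(x_dims.size() == 4);               // enforced by the InferMeta
  return {x_dims[0], filter_dims[1], x_dims[2], x_dims[3]};
}

// Example: x = {1, 24, 14, 14}, filter_dims = {4, 24} -> {1, 24, 14, 14},
// i.e. an SE block preserves the input shape.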
@@ -234,4 +234,14 @@ void FusedScaleBiasReluConvBnstatsInferMeta(
    MetaTensor* eq_scale,
    MetaTensor* eq_bias);

void SqueezeExcitationInferMeta(const MetaTensor& x,
                                const MetaTensor& filter,
                                const MetaTensor& filter_max,
                                const MetaTensor& bias,
                                const MetaTensor& branch,
                                const std::vector<int>& act_type,
                                const std::vector<float>& act_param,
                                const std::vector<int>& filter_dims,
                                MetaTensor* out);

} // namespace phi
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "glog/logging.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace fusion {
template <typename T, typename TW, typename Context>
void SqueezeExcitationKernelImpl(const Context& ctx,
const DenseTensor& x,
const DenseTensor& filter,
const DenseTensor& filter_max,
const paddle::optional<DenseTensor>& bias,
const paddle::optional<DenseTensor>& branch,
const std::vector<int>& act_type,
const std::vector<float>& act_param,
const std::vector<int>& filter_dims,
DenseTensor* out) {
using XPUTypeX = typename XPUTypeTrait<T>::Type;
using XPUTypeW = typename XPUTypeTrait<TW>::Type;
auto* weight1_ptr = filter.data<TW>();
auto weight_len = filter.numel();
auto weight1_len = weight_len / 2;
auto* weight2_ptr = weight1_ptr + weight1_len;
auto input_dims = x.dims();
int batch = static_cast<int>(input_dims[0]);
int channel = static_cast<int>(input_dims[1]);
int h = static_cast<int>(input_dims[2]);
int w = static_cast<int>(input_dims[3]);
auto* input_data = reinterpret_cast<const XPUTypeX*>(x.data<T>());
const XPUTypeX* branch_data = nullptr;
auto* branch_tensor = branch.get_ptr();
if (branch_tensor != nullptr) {
branch_data = reinterpret_cast<const XPUTypeX*>(branch_tensor->data<T>());
}
const float* bias1_ptr =
bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data<float>();
const float* bias2_ptr = (bias1_ptr != nullptr)
? (bias1_ptr + filter_dims[1] / filter_dims[0])
: nullptr;
  // Each weight's max buffer holds 6 floats; filter_max packs the two
  // buffers back to back.
  int max_ptr_size = 6;
  const float* w1_maxptr = filter_max.data<float>();
  const float* w2_maxptr = w1_maxptr + max_ptr_size;
auto* out_data =
reinterpret_cast<XPUTypeX*>(ctx.template Alloc<XPUTypeX>(out));
  // Build the three xdnn activations (squeeze act, excitation act, block
  // act). act_type values 5 and 15 carry an extra parameter: the leaky_relu
  // alpha and the hard_sigmoid slope, respectively.
  std::vector<xpu::Activation_t> act;
  for (size_t i = 0; i < 3; i++) {
    xpu::Activation_t cur_act = (xpu::Activation_t::act_enum)act_type[i];
    if (act_type[i] == 5) {  // leaky_relu
      cur_act.leaky_alpha = act_param[i];
    } else if (act_type[i] == 15) {  // hard_sigmoid
      cur_act.hard_sigmoid_slope = act_param[i];
    }
    act.push_back(cur_act);
  }
int r = xpu::squeeze_excitation_block<T, int16_t, int16_t>(
/* baidu::xpu::api::Context* ctx */ ctx.x_context(),
/* const T* x */ input_data,
/* const TW* weight1 */ reinterpret_cast<const XPUTypeW*>(weight1_ptr),
/* const TW* weight2 */ reinterpret_cast<const XPUTypeW*>(weight2_ptr),
/* T* y */ out_data,
      /* int64_t n */ batch,
      /* int64_t c */ channel,
/* int64_t h */ h,
/* int64_t w */ w,
/* int64_t r */ filter_dims[0],
/* const float* w1_maxptr */ reinterpret_cast<const float*>(w1_maxptr),
/* const float* w2_maxptr */ reinterpret_cast<const float*>(w2_maxptr),
      /* const float* bias1 */ bias1_ptr,
/* const float* bias2 */ bias2_ptr,
/* const T* branch */ branch_data,
/* const Activation_t& excitation_act1 */ act[0],
/* const Activation_t& excitation_act2 */ act[1],
/* const Activation_t& block_act */ act[2]);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "squeeze_excitation_block");
}
template <typename T, typename Context>
void SqueezeExcitationKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& filter,
const DenseTensor& filter_max,
const paddle::optional<DenseTensor>& bias,
const paddle::optional<DenseTensor>& branch,
const std::vector<int>& act_type,
const std::vector<float>& act_param,
const std::vector<int>& filter_dims,
DenseTensor* out) {
SqueezeExcitationKernelImpl<T, int16_t, Context>(ctx,
x,
filter,
filter_max,
bias,
branch,
act_type,
act_param,
filter_dims,
out);
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(squeeze_excitation_block,
XPU,
ALL_LAYOUT,
phi::fusion::SqueezeExcitationKernel,
float) {}
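The kernel reads both 1x1 convolution weights from a single filter tensor (weight1 in the first half, weight2 in the second) and both 6-float max buffers from filter_max. A hypothetical host-side packing that would satisfy that layout; the function name and the int16_t quantized weight type are assumptions based on the kernel code above, not code from this PR:

#include <cstdint>
#include <cstring>
#include <vector>

// Concatenate the squeeze and excitation conv weights the way the XPU
// kernel expects: weight2 starts at filter.numel() / 2. The two weights
// must therefore have equal element counts ([c/r, c, 1, 1] and
// [c, c/r, 1, 1] both flatten to c*c/r elements).
std::vector<int16_t> PackSqueezeExcitationFilter(
    const std::vector<int16_t>& weight1,    // squeeze conv, quantized
    const std::vector<int16_t>& weight2) {  // excitation conv, quantized
  std::vector<int16_t> packed(weight1.size() + weight2.size());
  std::memcpy(packed.data(), weight1.data(),
              weight1.size() * sizeof(int16_t));
  std::memcpy(packed.data() + weight1.size(), weight2.data(),
              weight2.size() * sizeof(int16_t));
  return packed;
}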
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestSqueezeExcitationFusePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ["squeeze_excitation_block"], (1e-3, 1e-3)
def sample_program_config(self, draw):
def generate_data(shape):
return np.random.random(shape).astype(np.float32)
x_shape = draw(
st.lists(
st.integers(min_value=1, max_value=12), min_size=4, max_size=4
)
)
x_shape[1] = 24
oc = 6
conv2d_op1_w_shape = [oc, x_shape[1], 1, 1]
conv2d_op1_b_shape = [oc]
conv2d_op2_w_shape = [x_shape[1], oc, 1, 1]
conv2d_op2_b_shape = [x_shape[1]]
        # Randomly choose whether to add a relu operator
has_relu = draw(st.sampled_from([True, False]))
pool2d_op = OpConfig(
type="pool2d",
inputs={"X": ["pool2d_x"]},
outputs={"Out": ["pool2d_out"]},
adaptive=True,
data_format="NCHW",
global_pooling=False,
ksize=[1, 1],
pooling_type="avg",
)
ops = [pool2d_op]
conv2d_op = OpConfig(
"conv2d",
inputs={
"Input": ["pool2d_out"],
"Filter": ["conv2d_weight"],
},
outputs={"Output": ["conv2d_out"]},
data_format="NCHW",
dilations=[1, 1],
padding_algorithm="EXPLICIT",
groups=1,
paddings=[0, 0, 0, 0],
strides=[1, 1],
has_bias=False,
)
ew_bias_op = OpConfig(
"elementwise_add",
inputs={"X": ["conv2d_out"], "Y": ["ew_bias"]},
outputs={"Out": ["add_out"]},
axis=1,
)
ops.extend([conv2d_op, ew_bias_op])
conv2d_input = "add_out"
        # optional activation after the first conv
if has_relu:
relu_op = OpConfig(
"relu", inputs={"X": ["add_out"]}, outputs={"Out": ["relu_out"]}
)
conv2d_input = "relu_out"
ops.append(relu_op)
conv2d_op2 = OpConfig(
"conv2d",
inputs={
"Input": [conv2d_input],
"Filter": ["conv2d_weight2"],
},
outputs={"Output": ["conv2d_out2"]},
data_format="NCHW",
dilations=[1, 1],
padding_algorithm="EXPLICIT",
groups=1,
paddings=[0, 0, 0, 0],
strides=[1, 1],
has_bias=False,
)
ew_bias_op2 = OpConfig(
"elementwise_add",
inputs={"X": ["conv2d_out2"], "Y": ["ew_bias2"]},
outputs={"Out": ["add_out2"]},
axis=1,
)
ops.extend([conv2d_op2, ew_bias_op2])
ele_mul_input = "add_out2"
        # optional activation after the second conv
if has_relu:
relu_op2 = OpConfig(
"relu",
inputs={"X": ["add_out2"]},
outputs={"Out": ["relu_out2"]},
)
ele_mul_input = "relu_out2"
ops.append(relu_op2)
ew_mul_op = OpConfig(
"elementwise_mul",
inputs={"X": ["pool2d_x"], "Y": [ele_mul_input]},
outputs={"Out": ["ew_mul_out"]},
axis=-1,
)
ops.append(ew_mul_op)
program_config = ProgramConfig(
ops=ops,
weights={
"conv2d_weight": TensorConfig(
data_gen=partial(generate_data, conv2d_op1_w_shape)
),
"ew_bias": TensorConfig(shape=conv2d_op1_b_shape),
"conv2d_weight2": TensorConfig(
data_gen=partial(generate_data, conv2d_op2_w_shape)
),
"ew_bias2": TensorConfig(shape=conv2d_op2_b_shape),
},
inputs={
"pool2d_x": TensorConfig(shape=x_shape),
},
outputs=ops[-1].outputs["Out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
passes=["squeeze_excitation_fuse_pass"],
)
if __name__ == "__main__":
unittest.main()