未验证 提交 7f6bb160 编写于 作者: C csy0225 提交者: GitHub

[XPU][Inference] Delete redundant squeeze/unsqueeze op. (#54754)

上级 55704db5
......@@ -236,6 +236,8 @@ if(WITH_XPU)
set(XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils)
pass_library(yolo_box_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(conv2d_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(redundant_squeeze_unsqueeze_elimination_pass inference DIR xpu
DEPS ${XPU_PASS_DEPS})
pass_library(embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(fc_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// Delete redundant squeeze/unsqueeze op
/*
For example:
graph:
Input
|
|
squeeze
|
|
squeeze out
|
|
activation(leaky_relu)
|
|
activation out
|
|
unsqueeze
|
|
Output
------------------------------------------------------
After the pass is applied:
Input
|
|
activation(leaky_relu)
|
|
Output
*/
// Matches the chain squeeze2 -> <act_type> activation -> unsqueeze2 so the
// squeeze/unsqueeze pair can be dropped when their "axes" attributes agree.
struct SqueezeActivationUnsqueezeEliminationPattern : public PatternBase {
  SqueezeActivationUnsqueezeEliminationPattern(PDPattern* pattern,
                                               const std::string& name_scope,
                                               const std::string& act_type);
  // declare operator node's name
  PATTERN_DECL_NODE(squeeze);    // leading squeeze2 op
  PATTERN_DECL_NODE(act);        // activation op of type act_type
  PATTERN_DECL_NODE(unsqueeze);  // trailing unsqueeze2 op
  // declare variable node's name
  PATTERN_DECL_NODE(squeeze_input);  // pattern input var fed to squeeze2
  PATTERN_DECL_NODE(squeeze_out);    // squeeze2 "Out" / activation "X"
  PATTERN_DECL_NODE(act_out);        // activation "Out" / unsqueeze2 "X"
  PATTERN_DECL_NODE(unsqueeze_out);  // pattern output var of unsqueeze2

 private:
  std::string act_type_;  // activation op type this instance matches
};
SqueezeActivationUnsqueezeEliminationPattern::
    SqueezeActivationUnsqueezeEliminationPattern(PDPattern* pattern,
                                                 const std::string& name_scope,
                                                 const std::string& act_type)
    : PatternBase(pattern, name_scope, name_scope), act_type_(act_type) {
  // squeeze2 op with its input/output variables.
  auto* squeeze_op = pattern->NewNode(squeeze_repr())->assert_is_op("squeeze2");
  auto* in_var = pattern->NewNode(squeeze_input_repr())
                     ->assert_is_op_input("squeeze2", "X")
                     ->AsInput();
  auto* squeezed_var = pattern->NewNode(squeeze_out_repr())
                           ->assert_is_op_output("squeeze2", "Out");
  squeeze_op->LinksFrom({in_var}).LinksTo({squeezed_var});
  // Activation op consuming the squeezed tensor.
  auto* act_op = pattern->NewNode(act_repr())->assert_is_op(act_type_);
  auto* act_out_var =
      pattern->NewNode(act_out_repr())->assert_is_op_output(act_type_, "Out");
  squeezed_var->assert_is_op_input(act_type_, "X");
  act_op->LinksFrom({squeezed_var}).LinksTo({act_out_var});
  // Trailing unsqueeze2 producing the pattern's output.
  auto* unsqueeze_op =
      pattern->NewNode(unsqueeze_repr())->assert_is_op("unsqueeze2");
  auto* out_var = pattern->NewNode(unsqueeze_out_repr())
                      ->assert_is_op_output("unsqueeze2", "Out")
                      ->AsOutput();
  act_out_var->assert_is_op_input("unsqueeze2", "X");
  unsqueeze_op->LinksFrom({act_out_var}).LinksTo({out_var});
}
/*
Function Description:Delete redundant squeeze/unsqueeze op
Pattern: custom pattern
For example:
graph:
Input1
|
|
squeeze1
|
|
squeeze1 out Input2
| |
| |
activation1(leaky_relu) squeeze2
| |
| |
activation1 out squeeze2 out
| |
| |
- - - - elementwise operation(elementwise_add) - - - -
|
|
activation2(leaky_relu)
|
|
activation2 out
|
|
- - - - - - - - - - - - - - - - - - -
| | | |
| | | |
unsqueeze 1 ...... unsqueeze n-1 unsqueeze n
| | | |
| | | |
Output 1 ...... Output n-1 Output n
------------------------------------------------------
After the pass is applied:
Input1
|
|
activation1(leaky_relu)
|
|
activation1 out Input2
| |
| |
- - - - elementwise operation(elementwise_add) - - - -
|
|
activation2(leaky_relu)
|
|
activation2 out
|
|
- - - - - - - - - - - - - - - - - - -
| | | |
| | | |
Output 1 ...... Output n-1 Output n
*/
// Matches: (squeeze2 -> act1) and a second squeeze2 joined by an elementwise
// op, followed by act2; act1_in_branch_x selects which branch feeds the
// elementwise "X" input. Trailing unsqueeze2 consumers of act2's output are
// discovered and checked inside the pass handler, not by this pattern.
struct CustomSqueezeUnsqueezeEliminationPattern : public PatternBase {
  CustomSqueezeUnsqueezeEliminationPattern(PDPattern* pattern,
                                           const std::string& name_scope,
                                           const std::string& act1_type,
                                           const std::string& act2_type,
                                           const std::string& elementwise_type,
                                           const bool act1_in_branch_x);
  // declare operator node's name
  PATTERN_DECL_NODE(squeeze1);     // squeeze2 in the activation branch
  PATTERN_DECL_NODE(squeeze2);     // squeeze2 in the plain branch
  PATTERN_DECL_NODE(act1);         // first activation op
  PATTERN_DECL_NODE(elementwise);  // binary elementwise op joining branches
  PATTERN_DECL_NODE(act2);         // second activation op (pattern tail)
  // declare variable node's name
  PATTERN_DECL_NODE(squeeze1_input);
  PATTERN_DECL_NODE(squeeze1_out);
  PATTERN_DECL_NODE(act1_out);
  PATTERN_DECL_NODE(squeeze2_input);
  PATTERN_DECL_NODE(squeeze2_out);
  PATTERN_DECL_NODE(elementwise_out);
  PATTERN_DECL_NODE(act2_out);

 private:
  std::string act1_type_;
  std::string act2_type_;
  std::string elementwise_type_;
  bool act1_in_branch_x_;  // true: act1 branch is elementwise input "X"
};
CustomSqueezeUnsqueezeEliminationPattern::
    CustomSqueezeUnsqueezeEliminationPattern(
        PDPattern* pattern,
        const std::string& name_scope,
        const std::string& act1_type,
        const std::string& act2_type,
        const std::string& elementwise_type,
        const bool act1_in_branch_x)
    : PatternBase(pattern, name_scope, name_scope),
      act1_type_(act1_type),
      act2_type_(act2_type),
      elementwise_type_(elementwise_type),
      act1_in_branch_x_(act1_in_branch_x) {
  // Branch 1: squeeze2 feeding the first activation.
  auto* sq1_op = pattern->NewNode(squeeze1_repr())->assert_is_op("squeeze2");
  auto* sq1_in_var = pattern->NewNode(squeeze1_input_repr())
                         ->assert_is_op_input("squeeze2", "X")
                         ->AsInput();
  auto* sq1_out_var = pattern->NewNode(squeeze1_out_repr())
                          ->assert_is_op_output("squeeze2", "Out");
  sq1_op->LinksFrom({sq1_in_var}).LinksTo({sq1_out_var});
  auto* act1_op = pattern->NewNode(act1_repr())->assert_is_op(act1_type_);
  auto* act1_out_var =
      pattern->NewNode(act1_out_repr())->assert_is_op_output(act1_type_, "Out");
  sq1_out_var->assert_is_op_input(act1_type_, "X");
  act1_op->LinksFrom({sq1_out_var}).LinksTo({act1_out_var});
  // Branch 2: standalone squeeze2.
  auto* sq2_op = pattern->NewNode(squeeze2_repr())->assert_is_op("squeeze2");
  auto* sq2_in_var = pattern->NewNode(squeeze2_input_repr())
                         ->assert_is_op_input("squeeze2", "X")
                         ->AsInput();
  auto* sq2_out_var = pattern->NewNode(squeeze2_out_repr())
                          ->assert_is_op_output("squeeze2", "Out");
  sq2_op->LinksFrom({sq2_in_var}).LinksTo({sq2_out_var});
  // Elementwise op joins both branches; act1_in_branch_x_ decides which
  // branch maps to input "X" and which to "Y".
  auto* ew_op =
      pattern->NewNode(elementwise_repr())->assert_is_op(elementwise_type_);
  auto* ew_out_var = pattern->NewNode(elementwise_out_repr())
                         ->assert_is_op_output(elementwise_type_, "Out");
  if (act1_in_branch_x_) {
    act1_out_var->assert_is_op_input(elementwise_type_, "X");
    sq2_out_var->assert_is_op_input(elementwise_type_, "Y");
  } else {
    act1_out_var->assert_is_op_input(elementwise_type_, "Y");
    sq2_out_var->assert_is_op_input(elementwise_type_, "X");
  }
  ew_op->LinksFrom({act1_out_var, sq2_out_var}).LinksTo({ew_out_var});
  // Second activation produces the pattern output.
  auto* act2_op = pattern->NewNode(act2_repr())->assert_is_op(act2_type_);
  auto* act2_out_var =
      pattern->NewNode(act2_out_repr())->assert_is_op_output(act2_type_, "Out");
  ew_out_var->assert_is_op_input(act2_type_, "X");
  act2_op->LinksFrom({ew_out_var}).LinksTo({act2_out_var});
  act2_out_var->AsOutput();
}
} // namespace patterns
// Removes squeeze2 -> activation -> unsqueeze2 chains whose squeeze and
// unsqueeze "axes" attributes are identical (the pair is then a no-op).
class SqueezeActivationUnsqueezeEliminationPass : public FusePassBase {
 public:
  void ApplyImpl(ir::Graph* graph) const override;

 private:
  // Runs the elimination for a single activation type; returns the number of
  // subgraphs rewritten.
  int ApplyImpl(ir::Graph* graph, const std::string& act_type) const;

  const std::string name_scope_{
      "squeeze_activation_unsqueeze_elimination_pass"};
};
void SqueezeActivationUnsqueezeEliminationPass::ApplyImpl(
    ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);
  // Activation ops the squeeze/unsqueeze chain may contain.
  const std::vector<std::string> act_types{"relu",
                                           "sigmoid",
                                           "tanh",
                                           "gelu",
                                           "leaky_relu",
                                           "hard_swish",
                                           "hard_sigmoid",
                                           "relu6",
                                           "swish"};
  int total_eliminated = 0;
  for (const auto& act_type : act_types) {
    total_eliminated += ApplyImpl(graph, act_type);
  }
  AddStatis(total_eliminated);
}
// Per-activation-type application: detects squeeze2 -> act -> unsqueeze2
// chains and splices the activation directly onto the squeeze input when the
// squeeze/unsqueeze "axes" match. Returns the number of chains eliminated.
int SqueezeActivationUnsqueezeEliminationPass::ApplyImpl(
    ir::Graph* graph, const std::string& act_type) const {
  GraphPatternDetector gpd;
  patterns::SqueezeActivationUnsqueezeEliminationPattern pattern(
      gpd.mutable_pattern(), name_scope_, act_type);
  int found_subgraph_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    VLOG(4) << "handle squeeze activation unsqueeze elimination.";
    /* Get operator node's name */
    GET_IR_NODE(squeeze);
    GET_IR_NODE(act);
    GET_IR_NODE(unsqueeze);
    /* Get variable node's name*/
    GET_IR_NODE(squeeze_input);
    GET_IR_NODE(squeeze_out);
    GET_IR_NODE(act_out);
    GET_IR_NODE(unsqueeze_out);
    // NOTE(review): scope is only null-checked here and never used afterwards;
    // the check guards against the pass running without a parameter scope.
    auto* scope = param_scope();
    PADDLE_ENFORCE_NOT_NULL(
        scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
    // The pair is a no-op only when unsqueeze re-inserts exactly the axes the
    // squeeze removed, i.e. both "axes" attributes are equal.
    std::vector<int> squeeze_axes =
        PADDLE_GET_CONST(std::vector<int>, squeeze->Op()->GetAttr("axes"));
    std::vector<int> unsqueeze_axes =
        PADDLE_GET_CONST(std::vector<int>, unsqueeze->Op()->GetAttr("axes"));
    bool elimination = (squeeze_axes == unsqueeze_axes);
    if (!elimination) return;
    // Rewire the activation to consume the squeeze input directly; its output
    // then carries the original (unsqueezed) input shape.
    auto act_op_desc = act->Op();
    act_op_desc->RenameInput(squeeze_out->Var()->Name(),
                             squeeze_input->Var()->Name());
    act_out->Var()->SetShape(squeeze_input->Var()->GetShape());
    act_op_desc->Flush();
    IR_NODE_LINK_TO(squeeze_input, act);
    // Every consumer of the unsqueeze output now reads the activation output.
    auto unsqueeze_out_link_nodes = unsqueeze_out->outputs;
    for (auto out_link_node : unsqueeze_out_link_nodes) {
      auto op_desc = out_link_node->Op();
      op_desc->RenameInput(unsqueeze_out->Var()->Name(),
                           act_out->Var()->Name());
      op_desc->Flush();
      IR_NODE_LINK_TO(act_out, out_link_node);
    }
    // Drop the bypassed ops and their intermediate variables; stale links are
    // cleaned up by GraphSafeRemoveNodes.
    std::unordered_set<const Node*> delete_nodes{
        squeeze, squeeze_out, unsqueeze, unsqueeze_out};
    GraphSafeRemoveNodes(graph, delete_nodes);
    found_subgraph_count++;
  };
  gpd(graph, handler);
  return found_subgraph_count;
}
// Removes matched squeeze2 pairs feeding an elementwise+activation block when
// all trailing unsqueeze2 consumers restore exactly the squeezed axes.
class CustomSqueezeUnsqueezeEliminationPass : public FusePassBase {
 public:
  void ApplyImpl(ir::Graph* graph) const override;

 private:
  // Runs the elimination for one (act1, act2, elementwise, branch-order)
  // combination; returns the number of subgraphs rewritten.
  int ApplyImpl(ir::Graph* graph,
                const std::string& act1_type,
                const std::string& act2_type,
                const std::string& elementwise_type,
                bool act1_in_branch_x) const;

  const std::string name_scope_{"custom_squeeze_unsqueeze_elimination_pass"};
};
void CustomSqueezeUnsqueezeEliminationPass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);
  // Try every supported (act1, act2, elementwise, branch-order) combination
  // and accumulate the number of rewritten subgraphs.
  const std::vector<std::string> act_types{"relu",
                                           "sigmoid",
                                           "tanh",
                                           "gelu",
                                           "leaky_relu",
                                           "hard_swish",
                                           "hard_sigmoid",
                                           "relu6",
                                           "swish"};
  const std::vector<std::string> elementwise_types{"elementwise_add",
                                                   "elementwise_sub",
                                                   "elementwise_mul",
                                                   "elementwise_div"};
  int total_eliminated = 0;
  for (const auto& act1_type : act_types) {
    for (const auto& act2_type : act_types) {
      for (const auto& elementwise_type : elementwise_types) {
        for (bool act1_in_branch_x : {true, false}) {
          total_eliminated += ApplyImpl(
              graph, act1_type, act2_type, elementwise_type, act1_in_branch_x);
        }
      }
    }
  }
  AddStatis(total_eliminated);
}
// Per-combination application. Eliminates both squeeze2 ops and every
// trailing unsqueeze2 when (a) squeeze1/squeeze2 share the same "axes" and
// (b) every unsqueeze2 consuming act2's output restores those same axes.
// Returns the number of subgraphs rewritten.
//
// NOTE: all validation is performed BEFORE any graph mutation. The previous
// version renamed act1/elementwise inputs and relinked nodes first, and only
// then checked the unsqueeze axes; an axes mismatch made it return with the
// graph half-rewritten (activation reading the squeeze input while squeeze
// ops were left in place), producing an inconsistent program.
int CustomSqueezeUnsqueezeEliminationPass::ApplyImpl(
    ir::Graph* graph,
    const std::string& act1_type,
    const std::string& act2_type,
    const std::string& elementwise_type,
    const bool act1_in_branch_x) const {
  GraphPatternDetector gpd;
  patterns::CustomSqueezeUnsqueezeEliminationPattern pattern(
      gpd.mutable_pattern(),
      name_scope_,
      act1_type,
      act2_type,
      elementwise_type,
      act1_in_branch_x);
  int found_subgraph_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    VLOG(4) << "handle custom squeeze unsqueeze elimination pass.";
    /* Get operator node's name */
    GET_IR_NODE(squeeze1);
    GET_IR_NODE(squeeze2);
    GET_IR_NODE(act1);
    GET_IR_NODE(elementwise);
    GET_IR_NODE(act2);
    /* Get variable node's name*/
    GET_IR_NODE(squeeze1_input);
    GET_IR_NODE(squeeze1_out);
    GET_IR_NODE(act1_out);
    GET_IR_NODE(squeeze2_input);
    GET_IR_NODE(squeeze2_out);
    GET_IR_NODE(elementwise_out);
    GET_IR_NODE(act2_out);
    auto* scope = param_scope();
    PADDLE_ENFORCE_NOT_NULL(
        scope, platform::errors::InvalidArgument("Scope cannot be nullptr."));
    // --- Validation phase (no graph mutation below until it passes) ---
    // Both squeeze ops must erase the same axes; otherwise the elementwise
    // operands would end up with different shapes after elimination.
    std::vector<int> squeeze1_axes =
        PADDLE_GET_CONST(std::vector<int>, squeeze1->Op()->GetAttr("axes"));
    std::vector<int> squeeze2_axes =
        PADDLE_GET_CONST(std::vector<int>, squeeze2->Op()->GetAttr("axes"));
    if (squeeze1_axes != squeeze2_axes) return;
    // Every unsqueeze2 fed by act2's output must restore exactly the squeezed
    // axes. Collect them now; bail out before touching the graph if any
    // mismatches, so a rejected subgraph is left fully intact.
    std::vector<Node*> unsqueeze_nodes;
    for (auto* out_link_node : act2_out->outputs) {
      auto* consumer_desc = out_link_node->Op();
      if (consumer_desc->Type() != "unsqueeze2") continue;
      std::vector<int> unsqueeze_axes =
          PADDLE_GET_CONST(std::vector<int>, consumer_desc->GetAttr("axes"));
      if (unsqueeze_axes != squeeze1_axes) return;
      unsqueeze_nodes.push_back(out_link_node);
    }
    // --- Mutation phase ---
    std::unordered_set<const Node*> delete_nodes{
        squeeze1, squeeze1_out, squeeze2, squeeze2_out};
    // act1 reads squeeze1's input directly; its output keeps the unsqueezed
    // input shape.
    auto* act1_op_desc = act1->Op();
    act1_op_desc->RenameInput(squeeze1_out->Var()->Name(),
                              squeeze1_input->Var()->Name());
    act1_out->Var()->SetShape(squeeze1_input->Var()->GetShape());
    act1_op_desc->Flush();
    IR_NODE_LINK_TO(squeeze1_input, act1);
    // elementwise reads squeeze2's input directly.
    auto* elementwise_op_desc = elementwise->Op();
    elementwise_op_desc->RenameInput(squeeze2_out->Var()->Name(),
                                     squeeze2_input->Var()->Name());
    elementwise_out->Var()->SetShape(squeeze2_input->Var()->GetShape());
    elementwise_op_desc->Flush();
    IR_NODE_LINK_TO(squeeze2_input, elementwise);
    // act2's output now carries the unsqueezed shape as well.
    act2_out->Var()->SetShape(elementwise_out->Var()->GetShape());
    // Bypass each trailing unsqueeze: consumers of its output read act2_out.
    const std::string act2_out_var_name = act2_out->Var()->Name();
    for (auto* unsqueeze_node : unsqueeze_nodes) {
      delete_nodes.insert(unsqueeze_node);
      std::string unsqueeze_out_var_name =
          unsqueeze_node->Op()->Output("Out")[0];
      for (auto* unsqueeze_out_node : unsqueeze_node->outputs) {
        // Find the unsqueeze "Out" var node among the op's output links.
        if (unsqueeze_out_node->Name() != unsqueeze_out_var_name) continue;
        delete_nodes.insert(unsqueeze_out_node);
        for (auto* next_node : unsqueeze_out_node->outputs) {
          auto* next_op_desc = next_node->Op();
          next_op_desc->RenameInput(unsqueeze_out_var_name, act2_out_var_name);
          next_op_desc->Flush();
          IR_NODE_LINK_TO(act2_out, next_node);
        }
      }
    }
    GraphSafeRemoveNodes(graph, delete_nodes);
    found_subgraph_count++;
  };
  gpd(graph, handler);
  return found_subgraph_count;
}
// Umbrella pass registered as "redundant_squeeze_unsqueeze_elimination_pass";
// its ApplyImpl runs both elimination variants defined above over the graph.
class RedundantSqueezeUnsqueezeEliminationPass : public FusePassBase {
 protected:
  void ApplyImpl(ir::Graph* graph) const override;

 private:
  const std::string name_scope_{"redundant_squeeze_unsqueeze_elimination_pass"};
};
void RedundantSqueezeUnsqueezeEliminationPass::ApplyImpl(
    ir::Graph* graph) const {
  VLOG(4) << "handle redundant squeeze unsqueeze elimination.";
  // First strip plain squeeze/act/unsqueeze chains, then the elementwise
  // two-branch variant.
  SqueezeActivationUnsqueezeEliminationPass act_chain_pass;
  act_chain_pass.ApplyImpl(graph);
  CustomSqueezeUnsqueezeEliminationPass elementwise_chain_pass;
  elementwise_chain_pass.ApplyImpl(graph);
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Register the combined pass under the name used by XpuPassStrategy, and pin
// the op versions the elimination logic was written against.
REGISTER_PASS(redundant_squeeze_unsqueeze_elimination_pass,
              paddle::framework::ir::RedundantSqueezeUnsqueezeEliminationPass);
REGISTER_PASS_CAPABILITY(redundant_squeeze_unsqueeze_elimination_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .EQ("squeeze2", 0)
            .LE("leaky_relu", 1)
            .EQ("unsqueeze2", 0));
......@@ -531,6 +531,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"matmul_weight_trans_pass",
"map_matmulv2_to_matmul_xpu_pass",
"reshape2_matmul_xpu_fuse_pass",
"redundant_squeeze_unsqueeze_elimination_pass",
"fc_xpu_fuse_pass",
"conv2d_xpu_fuse_pass",
"add_activation_xpu_fuse_pass",
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestXpuRedundantSqueezeUnsqueezeEliminationPass(PassAutoScanTest):
    """Checks that a squeeze2 -> leaky_relu -> unsqueeze2 chain with matching
    axes is collapsed to a single leaky_relu by the elimination pass."""

    def sample_predictor_configs(self, program_config):
        # After the pass, only the activation op should remain.
        yield self.create_inference_config(use_xpu=True), ["leaky_relu"], (
            1e-5,
            1e-5,
        )

    def sample_program_config(self, draw):
        shape = draw(st.sampled_from([[1, 32, 1, 4]]))
        leaky_alpha = 0.009999999776482582
        squeeze_axes = [2]

        ops = [
            OpConfig(
                "squeeze2",
                inputs={"X": ["squeeze_input"]},
                outputs={"Out": ["squeeze_out"]},
                axes=squeeze_axes,
            ),
            OpConfig(
                "leaky_relu",
                inputs={"X": ["squeeze_out"]},
                outputs={"Out": ["leaky_relu_out"]},
                alpha=leaky_alpha,
            ),
            OpConfig(
                "unsqueeze2",
                inputs={"X": ["leaky_relu_out"]},
                outputs={"Out": ["unsqueeze_out"]},
                axes=squeeze_axes,
            ),
        ]

        def make_input(s):
            return np.random.random(s).astype(np.float32)

        return ProgramConfig(
            ops=ops,
            inputs={
                "squeeze_input": TensorConfig(
                    data_gen=partial(make_input, shape)
                ),
            },
            weights={},
            outputs=ops[-1].outputs["Out"],
        )

    def test(self):
        self.run_and_statis(
            quant=False,
            max_examples=25,
            min_success_num=1,
            passes=["redundant_squeeze_unsqueeze_elimination_pass"],
        )
class TestXpuRedundantSqueezeUnsqueezeEliminationPass2(PassAutoScanTest):
    """Checks the two-branch variant: two squeeze2 ops feeding an
    elementwise_add + leaky_relu block, fanned out to two unsqueeze2 ops.
    The pass must remove every squeeze/unsqueeze with matching axes."""

    def sample_predictor_configs(self, program_config):
        config = self.create_inference_config(use_xpu=True)
        # Only the two activations and the elementwise op should survive.
        yield config, ["leaky_relu", "elementwise_add", "leaky_relu"], (
            1e-5,
            1e-5,
        )

    def sample_program_config(self, draw):
        shape = draw(st.sampled_from([[1, 32, 1, 4]]))
        leaky_alpha = 0.009999999776482582
        axes = [2]

        def squeeze(src, dst):
            return OpConfig(
                "squeeze2", inputs={"X": [src]}, outputs={"Out": [dst]}, axes=axes
            )

        def leaky(src, dst):
            return OpConfig(
                "leaky_relu",
                inputs={"X": [src]},
                outputs={"Out": [dst]},
                alpha=leaky_alpha,
            )

        def unsqueeze(src, dst):
            return OpConfig(
                "unsqueeze2",
                inputs={"X": [src]},
                outputs={"Out": [dst]},
                axes=axes,
            )

        ops = [
            squeeze("squeeze_1_input", "squeeze_1_out"),
            leaky("squeeze_1_out", "leaky_relu_1_out"),
            squeeze("squeeze_2_input", "squeeze_2_out"),
            OpConfig(
                "elementwise_add",
                inputs={"X": ["leaky_relu_1_out"], "Y": ["squeeze_2_out"]},
                outputs={"Out": ["elementwise_add_out"]},
            ),
            leaky("elementwise_add_out", "leaky_relu_2_out"),
            unsqueeze("leaky_relu_2_out", "unsqueeze_1_out"),
            unsqueeze("leaky_relu_2_out", "unsqueeze_2_out"),
        ]

        def make_input(s):
            return np.random.random(s).astype(np.float32)

        return ProgramConfig(
            ops=ops,
            inputs={
                "squeeze_1_input": TensorConfig(
                    data_gen=partial(make_input, shape)
                ),
                "squeeze_2_input": TensorConfig(
                    data_gen=partial(make_input, shape)
                ),
            },
            weights={},
            outputs=["unsqueeze_1_out", "unsqueeze_2_out"],
        )

    def test(self):
        self.run_and_statis(
            quant=False,
            max_examples=25,
            min_success_num=1,
            passes=["redundant_squeeze_unsqueeze_elimination_pass"],
        )
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册