未验证 提交 552ed8d8 编写于 作者: C csy0225 提交者: GitHub

Delete repeat ops add gather squeeze unsqueeze (#55371)

上级 bc153701
......@@ -129,6 +129,7 @@ pass_library(dense_multihead_matmul_to_sparse_pass inference)
pass_library(delete_cast_op_pass inference)
pass_library(delete_elementwise_mul_op_pass inference)
pass_library(delete_repeated_ops_pass inference)
pass_library(fused_continuous_same_ops_pass inference)
pass_library(sigmoid_elementmul_fuse_pass inference)
pass_library(generate_pass DEPS pass_desc_proto)
target_link_libraries(generate_pass pass_desc_proto)
......
......@@ -101,18 +101,18 @@ class DeleteRepeatedOpsPass : public FusePassBase {
void ApplyImpl(ir::Graph* graph) const override;
private:
void DeleteRepeatedOps(
ir::Graph* graph,
const std::string& op_type,
std::function<std::string(OpDesc*)> gen_op_key_fn) const;
void DeleteRepeatedOps(ir::Graph* graph,
const std::string& op_type,
std::function<std::string(Node*)> gen_op_key_fn) const;
const std::string name_scope_{"delete_repeated_ops_pass"};
mutable int delete_op_count{0};
};
void DeleteRepeatedOpsPass::DeleteRepeatedOps(
ir::Graph* graph,
const std::string& op_type,
std::function<std::string(OpDesc*)> gen_op_key_fn) const {
std::function<std::string(Node*)> gen_op_key_fn) const {
GraphPatternDetector gpd;
patterns::VarWithRepeatedOpsPattern pattern(
gpd.mutable_pattern(), name_scope_, op_type);
......@@ -140,7 +140,7 @@ void DeleteRepeatedOpsPass::DeleteRepeatedOps(
}
}
if (out_op_is_invalid) continue;
auto attr_key = gen_op_key_fn(op->Op());
auto attr_key = gen_op_key_fn(op);
ops_map[attr_key].push_back(op);
}
for (auto iter = ops_map.begin(); iter != ops_map.end();) {
......@@ -173,16 +173,18 @@ void DeleteRepeatedOpsPass::DeleteRepeatedOps(
};
gpd(graph, handler);
delete_op_count += delete_counts;
if (delete_counts > 0) {
LOG(INFO) << "--- delete " << delete_counts << " repeated " << op_type
<< " ops";
}
}
std::string GenShapeAttrKey(OpDesc* slice_op_desc) { return ""; }
std::string GenShapeAttrKey(Node* shape_op_node) { return ""; }
std::string GenSliceAttrKey(OpDesc* slice_op_desc) {
std::string GenSliceAttrKey(Node* slice_op_node) {
std::string attr_key;
auto slice_op_desc = slice_op_node->Op();
auto starts = slice_op_desc->GetAttrIfExists<std::vector<int>>("starts");
auto ends = slice_op_desc->GetAttrIfExists<std::vector<int>>("ends");
auto axes = slice_op_desc->GetAttrIfExists<std::vector<int>>("axes");
......@@ -207,21 +209,24 @@ std::string GenSliceAttrKey(OpDesc* slice_op_desc) {
return attr_key;
}
std::string GenCastAttrKey(OpDesc* cast_op_desc) {
std::string GenCastAttrKey(Node* cast_op_node) {
auto cast_op_desc = cast_op_node->Op();
auto in_dtype = cast_op_desc->GetAttrIfExists<int>("in_dtype");
auto out_dtype = cast_op_desc->GetAttrIfExists<int>("out_dtype");
return "in_dtype_" + std::to_string(in_dtype) + "_out_dtype_" +
std::to_string(out_dtype);
}
std::string GenAddAttrKey(OpDesc* add_op_desc) {
std::string GenAddAttrKey(Node* add_op_node) {
auto add_op_desc = add_op_node->Op();
std::string x_name = add_op_desc->Input("X")[0];
std::string y_name = add_op_desc->Input("Y")[0];
auto axis = add_op_desc->GetAttrIfExists<int>("axis");
return x_name + "_" + y_name + "_axis_" + std::to_string(axis);
}
std::string GenScaleAttrKey(OpDesc* scale_op_desc) {
std::string GenScaleAttrKey(Node* scale_op_node) {
auto scale_op_desc = scale_op_node->Op();
auto scale = scale_op_desc->GetAttrIfExists<float>("scale");
auto bias = scale_op_desc->GetAttrIfExists<float>("bias");
auto bias_after_scale =
......@@ -230,17 +235,53 @@ std::string GenScaleAttrKey(OpDesc* scale_op_desc) {
"_bias_after_scale_" + std::to_string(bias_after_scale);
}
std::string GenGatherAttrKey(Node* gather_op_node) {
std::string input_names{""};
for (auto input_var : gather_op_node->inputs) {
input_names += input_var->Var()->Name();
}
auto gather_op_desc = gather_op_node->Op();
auto axis = gather_op_desc->GetAttrIfExists<int>("axis");
return "axis_" + std::to_string(axis) + "_input_names_" + input_names;
}
std::string GenSqueeze2AttrKey(Node* squeeze2_op_node) {
auto squeeze2_op_desc = squeeze2_op_node->Op();
auto axes = squeeze2_op_desc->GetAttrIfExists<std::vector<int>>("axes");
std::string attr_key{""};
attr_key += "axes_";
for (auto axis : axes) {
attr_key += std::to_string(axis) + "_";
}
return attr_key;
}
void DeleteRepeatedOpsPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
DeleteRepeatedOps(graph, "shape", GenShapeAttrKey);
DeleteRepeatedOps(graph, "slice", GenSliceAttrKey);
DeleteRepeatedOps(graph, "cast", GenCastAttrKey);
DeleteRepeatedOps(graph, "elementwise_add", GenAddAttrKey);
DeleteRepeatedOps(graph, "scale", GenScaleAttrKey);
DeleteRepeatedOps(graph, "cast", GenCastAttrKey);
int repeat_time = 0;
int total_delete_op_count = 0;
// This pass needs to loop run until there are no nodes in the graph that need
// to be deleted.
while (true) {
delete_op_count = 0;
DeleteRepeatedOps(graph, "shape", GenShapeAttrKey);
DeleteRepeatedOps(graph, "slice", GenSliceAttrKey);
DeleteRepeatedOps(graph, "cast", GenCastAttrKey);
DeleteRepeatedOps(graph, "elementwise_add", GenAddAttrKey);
DeleteRepeatedOps(graph, "scale", GenScaleAttrKey);
DeleteRepeatedOps(graph, "gather", GenGatherAttrKey);
DeleteRepeatedOps(graph, "squeeze2", GenSqueeze2AttrKey);
DeleteRepeatedOps(graph, "unsqueeze2", GenSqueeze2AttrKey);
LOG(INFO) << "Round " << repeat_time++
<< ": delete op counts: " << delete_op_count;
total_delete_op_count += delete_op_count;
if (delete_op_count == 0) {
break; // No node need to delete.
}
}
LOG(INFO) << "Total delete op counts: " << total_delete_op_count;
}
} // namespace ir
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct ContinuousSameOpsPattern : public PatternBase {
ContinuousSameOpsPattern(PDPattern* pattern,
const std::string& name_scope,
const std::string& op_type);
PATTERN_DECL_NODE(first_in_var_node);
PATTERN_DECL_NODE(first_out_var_node);
PATTERN_DECL_NODE(second_out_var_node);
// declare op node's name
PATTERN_DECL_NODE(first_op_node);
PATTERN_DECL_NODE(second_op_node);
std::string op_type_;
};
ContinuousSameOpsPattern::ContinuousSameOpsPattern(
PDPattern* pattern,
const std::string& name_scope,
const std::string& op_type)
: PatternBase(pattern, name_scope, name_scope), op_type_(op_type) {
auto* first_in_var_node =
pattern->NewNode(first_in_var_node_repr())
->assert_var_not_persistable()
->assert_is_op_input(op_type_, "X")
->AsInput()
->assert_more([&](Node* node) {
// assert pre op type is not same.
auto input_nodes = node->inputs;
if (input_nodes.size() != 1) return false;
if (!input_nodes.empty() && input_nodes[0]->IsOp() &&
input_nodes[0]->Op()->Type() == op_type_) {
return false;
}
return true;
});
auto* first_op_node =
pattern->NewNode(first_op_node_repr())->assert_is_op(op_type_);
auto* first_out_var_node = pattern->NewNode(first_out_var_node_repr())
->assert_var_not_persistable()
->assert_is_op_output(op_type_, "Out")
->assert_has_n_outputs(1);
first_op_node->LinksFrom({first_in_var_node}).LinksTo({first_out_var_node});
auto* second_op_node =
pattern->NewNode(second_op_node_repr())->assert_is_op(op_type_);
auto* second_out_var_node = pattern->NewNode(second_out_var_node_repr())
->assert_var_not_persistable()
->assert_is_op_output(op_type_, "Out")
->AsOutput();
second_op_node->LinksFrom({first_out_var_node})
.LinksTo({second_out_var_node});
}
} // namespace patterns
/*
Fused continuous same ops into one.
Origin graph:
input
|
|
unsqueeze2
|
|
unsqueeze2
|
|
unsqueeze2
|
|
out
After:
input
|
|
unsqueeze2
|
|
out
*/
class FusedContinuousSameOpsPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
void FusedReshapeOps(ir::Graph* graph) const;
void FusedUnsqueezeOps(ir::Graph* graph) const;
const std::string name_scope_{"fused_continuous_same_ops_pass"};
mutable int delete_op_count{0};
};
void FusedContinuousSameOpsPass::FusedReshapeOps(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::ContinuousSameOpsPattern pattern(
gpd.mutable_pattern(), name_scope_, "reshape2");
int delete_counts = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle fused continuous reshape ops.";
GET_IR_NODE_FROM_SUBGRAPH(first_in_var_node, first_in_var_node, pattern);
GET_IR_NODE_FROM_SUBGRAPH(first_out_var_node, first_out_var_node, pattern);
GET_IR_NODE_FROM_SUBGRAPH(
second_out_var_node, second_out_var_node, pattern);
GET_IR_NODE_FROM_SUBGRAPH(first_op_node, first_op_node, pattern);
GET_IR_NODE_FROM_SUBGRAPH(second_op_node, second_op_node, pattern);
auto first_node_attr_shape =
first_op_node->Op()->GetAttrIfExists<std::vector<int>>("shape");
if (first_node_attr_shape.empty()) return;
auto second_node_attr_shape =
second_op_node->Op()->GetAttrIfExists<std::vector<int>>("shape");
if (second_node_attr_shape.empty()) return;
second_op_node->Op()->RenameInput(first_out_var_node->Name(),
first_in_var_node->Name());
IR_NODE_LINK_TO(first_in_var_node, second_op_node);
GraphSafeRemoveNodes(graph, {first_op_node, first_out_var_node});
delete_counts++;
};
gpd(graph, handler);
delete_op_count += delete_counts;
if (delete_counts > 0) {
LOG(INFO) << "--- delete " << delete_counts << " repeated "
<< "reshape2"
<< " ops";
}
}
void FusedContinuousSameOpsPass::FusedUnsqueezeOps(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::ContinuousSameOpsPattern pattern(
gpd.mutable_pattern(), name_scope_, "unsqueeze2");
int delete_counts = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle fused continuous unsqueeze ops.";
GET_IR_NODE_FROM_SUBGRAPH(first_in_var_node, first_in_var_node, pattern);
GET_IR_NODE_FROM_SUBGRAPH(first_out_var_node, first_out_var_node, pattern);
GET_IR_NODE_FROM_SUBGRAPH(
second_out_var_node, second_out_var_node, pattern);
GET_IR_NODE_FROM_SUBGRAPH(first_op_node, first_op_node, pattern);
GET_IR_NODE_FROM_SUBGRAPH(second_op_node, second_op_node, pattern);
auto first_node_attr_axes =
first_op_node->Op()->GetAttrIfExists<std::vector<int>>("axes");
if (first_node_attr_axes.empty()) return;
auto second_node_attr_axes =
second_op_node->Op()->GetAttrIfExists<std::vector<int>>("axes");
if (second_node_attr_axes.empty()) return;
second_op_node->Op()->RenameInput(first_out_var_node->Name(),
first_in_var_node->Name());
second_node_attr_axes.insert(second_node_attr_axes.begin(),
first_node_attr_axes.begin(),
first_node_attr_axes.end());
second_op_node->Op()->SetAttr("axes", second_node_attr_axes);
IR_NODE_LINK_TO(first_in_var_node, second_op_node);
GraphSafeRemoveNodes(graph, {first_op_node, first_out_var_node});
delete_counts++;
};
gpd(graph, handler);
delete_op_count += delete_counts;
if (delete_counts > 0) {
LOG(INFO) << "--- delete " << delete_counts << " repeated "
<< "unsqueeze2"
<< " ops";
}
}
void FusedContinuousSameOpsPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
int repeat_time = 0;
int total_delete_op_count = 0;
// This pass needs to loop run until there are no nodes in the graph that need
// to be deleted.
while (true) {
delete_op_count = 0;
FusedReshapeOps(graph);
FusedUnsqueezeOps(graph);
LOG(INFO) << "Round " << repeat_time++
<< ": delete op counts: " << delete_op_count;
total_delete_op_count += delete_op_count;
if (delete_op_count == 0) {
LOG(INFO) << "--- no nodes need to delete --- break";
break; // No node need to delete.
}
}
LOG(INFO) << "Total delete op counts: " << total_delete_op_count;
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(fused_continuous_same_ops_pass,
paddle::framework::ir::FusedContinuousSameOpsPass);
REGISTER_PASS_CAPABILITY(fused_continuous_same_ops_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"reshape2", 0))
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"unsqueeze2", 0));
......@@ -507,8 +507,9 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"delete_assign_op_pass",
"delete_dropout_op_pass",
"delete_concat_op_pass",
"identity_op_clean_pass",
"delete_repeated_ops_pass",
"identity_op_clean_pass",
"fused_continuous_same_ops_pass",
"reshape_unstack_concat_fuse_pass",
"delete_op_device_pass",
"constant_folding_pass",
......
......@@ -13,8 +13,10 @@
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
......@@ -380,5 +382,350 @@ class TestDeleteRepeatedScalePass(PassAutoScanTest):
)
class TestDeleteRepeatedSqueezePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ['scale', 'squeeze2', 'relu', 'relu', 'relu'], (
1e-5,
1e-5,
)
def sample_program_config(self, draw):
scale_x = draw(
st.lists(
st.integers(min_value=1, max_value=20), min_size=2, max_size=4
)
)
scale_x[0] = 1
axis = 0
scale_op0 = OpConfig(
"scale",
inputs={
"X": ["scale_x"],
},
scale=2.0,
bias=1.0,
bias_after_scale=True,
outputs={"Out": ["scale0_out"]},
)
squeeze_op0 = OpConfig(
"squeeze2",
inputs={
"X": ["scale0_out"],
},
axes=[axis],
outputs={"Out": ["squeeze0_out"]},
)
relu_op0 = OpConfig(
"relu",
inputs={
"X": ["squeeze0_out"],
},
outputs={"Out": ["relu0_out"]},
)
scale_op1 = OpConfig(
"scale",
inputs={
"X": ["scale_x"],
},
scale=2.0,
bias=1.0,
bias_after_scale=True,
outputs={"Out": ["scale1_out"]},
)
squeeze_op1 = OpConfig(
"squeeze2",
inputs={
"X": ["scale1_out"],
},
axes=[axis],
outputs={"Out": ["squeeze1_out"]},
)
relu_op1 = OpConfig(
"relu",
inputs={
"X": ["squeeze1_out"],
},
outputs={"Out": ["relu1_out"]},
)
scale_op2 = OpConfig(
"scale",
inputs={
"X": ["scale_x"],
},
scale=2.0,
bias=1.0,
bias_after_scale=True,
outputs={"Out": ["scale2_out"]},
)
squeeze_op2 = OpConfig(
"squeeze2",
inputs={
"X": ["scale2_out"],
},
axes=[axis],
outputs={"Out": ["squeeze2_out"]},
)
relu_op2 = OpConfig(
"relu",
inputs={
"X": ["squeeze2_out"],
},
outputs={"Out": ["relu2_out"]},
)
ops = [
scale_op0,
squeeze_op0,
relu_op0,
scale_op1,
squeeze_op1,
relu_op1,
scale_op2,
squeeze_op2,
relu_op2,
]
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"scale_x": TensorConfig(shape=scale_x),
},
outputs=["relu0_out", "relu1_out", "relu2_out"],
)
return program_config
class TestDeleteRepeatedUnSqueezePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ['scale', 'unsqueeze2', 'relu', 'relu', 'relu'], (
1e-5,
1e-5,
)
def sample_program_config(self, draw):
scale_x = draw(
st.lists(
st.integers(min_value=1, max_value=20), min_size=2, max_size=4
)
)
axis = 0
scale_op0 = OpConfig(
"scale",
inputs={
"X": ["scale_x"],
},
scale=2.0,
bias=1.0,
bias_after_scale=True,
outputs={"Out": ["scale0_out"]},
)
unsqueeze_op0 = OpConfig(
"unsqueeze2",
inputs={
"X": ["scale0_out"],
},
axes=[axis],
outputs={"Out": ["unsqueeze0_out"]},
)
relu_op0 = OpConfig(
"relu",
inputs={
"X": ["unsqueeze0_out"],
},
outputs={"Out": ["relu0_out"]},
)
scale_op1 = OpConfig(
"scale",
inputs={
"X": ["scale_x"],
},
scale=2.0,
bias=1.0,
bias_after_scale=True,
outputs={"Out": ["scale1_out"]},
)
unsqueeze_op1 = OpConfig(
"unsqueeze2",
inputs={
"X": ["scale1_out"],
},
axes=[axis],
outputs={"Out": ["unsqueeze1_out"]},
)
relu_op1 = OpConfig(
"relu",
inputs={
"X": ["unsqueeze1_out"],
},
outputs={"Out": ["relu1_out"]},
)
scale_op2 = OpConfig(
"scale",
inputs={
"X": ["scale_x"],
},
scale=2.0,
bias=1.0,
bias_after_scale=True,
outputs={"Out": ["scale2_out"]},
)
unsqueeze_op2 = OpConfig(
"unsqueeze2",
inputs={
"X": ["scale2_out"],
},
axes=[axis],
outputs={"Out": ["unsqueeze2_out"]},
)
relu_op2 = OpConfig(
"relu",
inputs={
"X": ["unsqueeze2_out"],
},
outputs={"Out": ["relu2_out"]},
)
ops = [
scale_op0,
unsqueeze_op0,
relu_op0,
scale_op1,
unsqueeze_op1,
relu_op1,
scale_op2,
unsqueeze_op2,
relu_op2,
]
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"scale_x": TensorConfig(shape=scale_x),
},
outputs=["relu0_out", "relu1_out", "relu2_out"],
)
return program_config
class TestDeleteRepeatedGatherPass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ['scale', 'gather', 'relu', 'relu', 'relu'], (1e-5, 1e-5)
def sample_program_config(self, draw):
scale_x = draw(
st.lists(
st.integers(min_value=3, max_value=20), min_size=2, max_size=4
)
)
axis = 0
def generate_index(*args, **kwargs):
return np.array([0]).astype(np.int64)
gather_index = np.array([0]).astype(np.int64)
scale_op0 = OpConfig(
"scale",
inputs={
"X": ["scale_x"],
},
scale=2.0,
bias=1.0,
bias_after_scale=True,
outputs={"Out": ["scale0_out"]},
)
gather_op0 = OpConfig(
"gather",
inputs={"X": ["scale0_out"], "Index": ["gather_index"]},
axis=axis,
outputs={"Out": ["gather0_out"]},
)
relu_op0 = OpConfig(
"relu",
inputs={
"X": ["gather0_out"],
},
outputs={"Out": ["relu0_out"]},
)
scale_op1 = OpConfig(
"scale",
inputs={
"X": ["scale_x"],
},
scale=2.0,
bias=1.0,
bias_after_scale=True,
outputs={"Out": ["scale1_out"]},
)
gather_op1 = OpConfig(
"gather",
inputs={"X": ["scale1_out"], "Index": ["gather_index"]},
axis=axis,
outputs={"Out": ["gather1_out"]},
)
relu_op1 = OpConfig(
"relu",
inputs={
"X": ["gather1_out"],
},
outputs={"Out": ["relu1_out"]},
)
scale_op2 = OpConfig(
"scale",
inputs={
"X": ["scale_x"],
},
scale=2.0,
bias=1.0,
bias_after_scale=True,
outputs={"Out": ["scale2_out"]},
)
gather_op2 = OpConfig(
"gather",
inputs={"X": ["scale2_out"], "Index": ["gather_index"]},
axis=axis,
outputs={"Out": ["gather2_out"]},
)
relu_op2 = OpConfig(
"relu",
inputs={
"X": ["gather2_out"],
},
outputs={"Out": ["relu2_out"]},
)
ops = [
scale_op0,
gather_op0,
relu_op0,
scale_op1,
gather_op1,
relu_op1,
scale_op2,
gather_op2,
relu_op2,
]
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"scale_x": TensorConfig(shape=scale_x),
"gather_index": TensorConfig(data_gen=partial(generate_index)),
},
outputs=["relu0_out", "relu1_out", "relu2_out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
passes=["delete_repeated_ops_pass"],
)
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import hypothesis.strategies as st
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestFusedSameUnSqueezePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ['scale', 'unsqueeze2'], (1e-5, 1e-5)
def sample_program_config(self, draw):
scale_x = draw(
st.lists(
st.integers(min_value=1, max_value=20), min_size=1, max_size=3
)
)
first_unsqueeze_axis = 0
second_unsqueeze_axis = 1
third_unsqueeze_axis = 2
scale_op0 = OpConfig(
"scale",
inputs={
"X": ["scale_x"],
},
scale=2.0,
bias=1.0,
bias_after_scale=True,
outputs={"Out": ["scale0_out"]},
)
unsqueeze_op0 = OpConfig(
"unsqueeze2",
inputs={
"X": ["scale0_out"],
},
axes=[first_unsqueeze_axis],
outputs={"Out": ["unsqueeze0_out"]},
)
unsqueeze_op1 = OpConfig(
"unsqueeze2",
inputs={
"X": ["unsqueeze0_out"],
},
axes=[second_unsqueeze_axis],
outputs={"Out": ["unsqueeze1_out"]},
)
unsqueeze_op2 = OpConfig(
"unsqueeze2",
inputs={
"X": ["unsqueeze1_out"],
},
axes=[third_unsqueeze_axis],
outputs={"Out": ["unsqueeze2_out"]},
)
ops = [scale_op0, unsqueeze_op0, unsqueeze_op1, unsqueeze_op2]
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"scale_x": TensorConfig(shape=scale_x),
},
outputs=["unsqueeze2_out"],
)
return program_config
class TestFusedSameReshapePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ['scale', 'reshape2'], (1e-5, 1e-5)
def sample_program_config(self, draw):
scale_x = draw(
st.sampled_from([[8, 16], [16, 32], [64, 16], [16, 8], [16, 16]])
)
first_reshape_shape = [-1, 16, 4]
second_reshape_shape = [-1, 8]
scale_op0 = OpConfig(
"scale",
inputs={
"X": ["scale_x"],
},
scale=2.0,
bias=1.0,
bias_after_scale=True,
outputs={"Out": ["scale0_out"]},
)
reshape_op0 = OpConfig(
"reshape2",
inputs={
"X": ["scale0_out"],
},
shape=first_reshape_shape,
outputs={"Out": ["reshape0_out"]},
)
reshape_op1 = OpConfig(
"reshape2",
inputs={
"X": ["reshape0_out"],
},
shape=second_reshape_shape,
outputs={"Out": ["reshape1_out"]},
)
ops = [scale_op0, reshape_op0, reshape_op1]
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"scale_x": TensorConfig(shape=scale_x),
},
outputs=["reshape1_out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
min_success_num=5,
passes=["fused_continuous_same_ops_pass"],
)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册