未验证 提交 24efc13b 编写于 作者: 周周周 提交者: GitHub

[Paddle Inference] remove redunant op (#54442)

* remove redunant
上级 a0d59f9d
...@@ -89,7 +89,7 @@ pass_library(conv_elementwise_add2_act_fuse_pass inference) ...@@ -89,7 +89,7 @@ pass_library(conv_elementwise_add2_act_fuse_pass inference)
pass_library(conv_elementwise_add_fuse_pass inference) pass_library(conv_elementwise_add_fuse_pass inference)
pass_library(transpose_flatten_concat_fuse_pass inference) pass_library(transpose_flatten_concat_fuse_pass inference)
pass_library(inplace_op_var_pass inference) pass_library(inplace_op_var_pass inference)
pass_library(identity_scale_op_clean_pass base) pass_library(identity_op_clean_pass base)
pass_library(sync_batch_norm_pass base) pass_library(sync_batch_norm_pass base)
pass_library(runtime_context_cache_pass base) pass_library(runtime_context_cache_pass base)
pass_library(quant_conv2d_dequant_fuse_pass inference) pass_library(quant_conv2d_dequant_fuse_pass inference)
...@@ -102,7 +102,6 @@ pass_library(delete_weight_dequant_linear_op_pass inference) ...@@ -102,7 +102,6 @@ pass_library(delete_weight_dequant_linear_op_pass inference)
pass_library(delete_quant_dequant_linear_op_pass inference) pass_library(delete_quant_dequant_linear_op_pass inference)
pass_library(delete_dropout_op_pass inference) pass_library(delete_dropout_op_pass inference)
pass_library(delete_concat_op_pass inference) pass_library(delete_concat_op_pass inference)
pass_library(delete_c_identity_op_pass inference)
pass_library(preln_residual_bias_fuse_pass inference) pass_library(preln_residual_bias_fuse_pass inference)
pass_library(constant_folding_pass inference) pass_library(constant_folding_pass inference)
pass_library(auto_mixed_precision_pass inference) pass_library(auto_mixed_precision_pass inference)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/delete_c_identity_op_pass.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
void DeleteCIdentityOpPattern::operator()() {
auto any_op_out = pattern->NewNode(any_op_out_repr())
->assert_is_op_input("c_identity", "X")
->AsInput();
auto c_identity_op =
pattern->NewNode(c_identity_op_repr())->assert_is_op("c_identity");
auto c_identity_op_out = pattern->NewNode(c_identity_op_out_repr())
->assert_is_op_output("c_identity", "Out")
->AsIntermediate();
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
c_identity_op->LinksFrom({any_op_out});
c_identity_op_out->LinksFrom({c_identity_op});
any_op2->LinksFrom({c_identity_op_out});
}
} // namespace patterns
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(any_op_out); \
GET_IR_NODE(c_identity_op); \
GET_IR_NODE(c_identity_op_out); \
GET_IR_NODE(any_op2);
void DeleteCIdentityOpPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "delete_c_identity_op_pattern";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
patterns::DeleteCIdentityOpPattern pattern(gpd.mutable_pattern(),
pattern_name);
pattern();
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
IR_NODE_LINK_TO(any_op_out, any_op2);
std::string any_op_out_name = any_op_out->Var()->Name();
std::string c_identity_op_out_name = c_identity_op_out->Var()->Name();
auto* any_op2_desc = any_op2->Op();
auto var_map = any_op2_desc->Inputs();
std::string arg_name = "";
for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(),
name_m.second.end(),
c_identity_op_out_name) != name_m.second.end()) {
arg_name = name_m.first;
}
}
if (arg_name.size() == 0) {
LOG(INFO) << "Delete c_identity op pass: can not find the input "
<< c_identity_op_out_name;
return;
}
// modify the any_op2's inputs
for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(),
name_m.second.end(),
c_identity_op_out_name) != name_m.second.end()) {
std::vector<std::string> new_inputs;
for (auto& i_n : name_m.second) {
if (i_n != c_identity_op_out_name) {
new_inputs.push_back(i_n);
}
}
new_inputs.push_back(any_op_out_name);
any_op2_desc->SetInput(name_m.first, new_inputs);
any_op2_desc->Flush();
}
}
any_op2_desc->Flush();
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph, {c_identity_op, c_identity_op_out});
};
gpd(graph, handler);
}
DeleteCIdentityOpPass::DeleteCIdentityOpPass() {
AddOpCompat(OpCompat("c_identity"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_c_identity_op_pass,
paddle::framework::ir::DeleteCIdentityOpPass);
REGISTER_PASS_CAPABILITY(delete_c_identity_op_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().LE(
"c_identity", 1));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct DeleteCIdentityOpPattern : public PatternBase {
DeleteCIdentityOpPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "delete_c_identity_op_pattern") {}
void operator()();
PATTERN_DECL_NODE(any_op_out);
PATTERN_DECL_NODE(c_identity_op);
PATTERN_DECL_NODE(c_identity_op_out);
PATTERN_DECL_NODE(any_op2);
};
} // namespace patterns
class Graph;
class DeleteCIdentityOpPass : public FusePassBase {
public:
DeleteCIdentityOpPass();
virtual ~DeleteCIdentityOpPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/identity_scale_op_clean_pass.h" #include "paddle/fluid/framework/ir/identity_op_clean_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
...@@ -23,17 +23,16 @@ namespace ir { ...@@ -23,17 +23,16 @@ namespace ir {
class Graph; class Graph;
void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const { void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init("identity_scale_op_clean", graph); FusePassBase::Init("identity_scale_op_clean", graph);
// pre_op -> scale_in -> scale_op -> scale_out // pre_op -> useless_op_in -> useless_op -> useless_op_out
// -> // ->
// pre_op -> scale_out // pre_op -> useless_op_out
GraphPatternDetector detector; GraphPatternDetector detector;
auto scale_in = auto useless_op_in =
detector.mutable_pattern() detector.mutable_pattern()
->NewNode("scale_in") ->NewNode("useless_op_in")
->assert_is_op_input("scale")
->assert_has_n_outputs(1) ->assert_has_n_outputs(1)
->assert_var_not_persistable() ->assert_var_not_persistable()
->assert_more([](Node* x) { ->assert_more([](Node* x) {
...@@ -45,27 +44,51 @@ void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const { ...@@ -45,27 +44,51 @@ void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const {
} }
return true; return true;
}); });
auto scale_op = detector.mutable_pattern()
->NewNode("scale_fuse")
->assert_is_op("scale")
->assert_op_attr<float>("scale", 1.)
->assert_op_attr<float>("bias", 0.);
auto scale_out = detector.mutable_pattern()
->NewNode("scale_out")
->assert_is_op_output("scale");
scale_op->LinksFrom({scale_in}).LinksTo({scale_out}); // This useless_op must have only one input and one output!
auto useless_op =
detector.mutable_pattern()
->NewNode("useless_op")
->assert_has_n_inputs(1)
->assert_has_n_outputs(1)
->assert_more([](Node* x) {
if (!x->IsOp()) {
return false;
}
if (x->Op()->Type() == "scale") {
auto scale = x->Op()->GetAttrIfExists<float>("scale");
auto bias = x->Op()->GetAttrIfExists<float>("bias");
if (std::abs(bias) <= 1e-6 && std::abs(scale - 1) <= 1e-6) {
return true;
}
}
if (x->Op()->Type() == "cast") {
auto in_dtype = x->Op()->GetAttrIfExists<int>("in_dtype");
auto out_dtype = x->Op()->GetAttrIfExists<int>("out_dtype");
if (in_dtype == out_dtype) {
return true;
}
}
if (x->Op()->Type() == "c_identity") {
return true;
}
// you can add more cases here.
return false;
});
auto useless_op_out = detector.mutable_pattern()->NewNode("useless_op_out");
useless_op->LinksFrom({useless_op_in}).LinksTo({useless_op_out});
int found_subgraph_count = 0; int found_subgraph_count = 0;
GraphPatternDetector::handle_t handler = GraphPatternDetector::handle_t handler =
[&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) { [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
Node* scale_op_var = subgraph.at(scale_op); Node* useless_op_var = subgraph.at(useless_op);
Node* scale_in_var = subgraph.at(scale_in); Node* useless_op_in_var = subgraph.at(useless_op_in);
Node* scale_out_var = subgraph.at(scale_out); Node* useless_op_out_var = subgraph.at(useless_op_out);
const std::string scale_in_name = scale_in_var->Name(); const std::string useless_op_in_name = useless_op_in_var->Name();
const std::string scale_out_name = scale_out_var->Name(); const std::string useless_op_out_name = useless_op_out_var->Name();
// Remove links in graph // Remove links in graph
GraphSafeRemoveNodes(graph, {scale_in_var, scale_op_var}); GraphSafeRemoveNodes(graph, {useless_op_in_var, useless_op_var});
// Modify pre_op_desc // Modify pre_op_desc
// Link pre_op directly to scale_out // Link pre_op directly to scale_out
for (auto& node : graph->Nodes()) { for (auto& node : graph->Nodes()) {
...@@ -76,16 +99,16 @@ void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const { ...@@ -76,16 +99,16 @@ void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const {
auto names = out_var_map.second; auto names = out_var_map.second;
bool reset = false; bool reset = false;
for (size_t i = 0; i < names.size(); i++) { for (size_t i = 0; i < names.size(); i++) {
if (names[i] == scale_in_name) { if (names[i] == useless_op_in_name) {
reset = true; reset = true;
names[i] = scale_out_name; names[i] = useless_op_out_name;
break; break;
} }
} }
if (reset) { if (reset) {
op_desc->SetOutput(out_var_map.first, names); op_desc->SetOutput(out_var_map.first, names);
op_desc->Flush(); op_desc->Flush();
IR_NODE_LINK_TO(node, scale_out_var); IR_NODE_LINK_TO(node, useless_op_out_var);
break; break;
} }
} }
...@@ -102,9 +125,10 @@ void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const { ...@@ -102,9 +125,10 @@ void IdentityScaleOpCleanPass::ApplyImpl(ir::Graph* graph) const {
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
REGISTER_PASS(identity_scale_op_clean_pass, REGISTER_PASS(identity_op_clean_pass,
paddle::framework::ir::IdentityScaleOpCleanPass); paddle::framework::ir::IdentityOpCleanPass);
REGISTER_PASS_CAPABILITY(identity_scale_op_clean_pass) REGISTER_PASS_CAPABILITY(identity_op_clean_pass)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ( paddle::framework::compatible::OpVersionComparatorCombination()
"scale", 0)); .EQ("scale", 0)
.LE("c_identity", 1));
...@@ -22,12 +22,12 @@ namespace ir { ...@@ -22,12 +22,12 @@ namespace ir {
class Graph; class Graph;
class IdentityScaleOpCleanPass : public FusePassBase { class IdentityOpCleanPass : public FusePassBase {
protected: protected:
void ApplyImpl(ir::Graph* graph) const override; void ApplyImpl(ir::Graph* graph) const override;
private: private:
virtual ~IdentityScaleOpCleanPass() = default; virtual ~IdentityOpCleanPass() = default;
}; };
} // namespace ir } // namespace ir
......
...@@ -54,7 +54,7 @@ static const std::vector<std::string> support_subgraph_passes = { ...@@ -54,7 +54,7 @@ static const std::vector<std::string> support_subgraph_passes = {
static const std::vector<std::string> xpu_support_subgraph_passes = { static const std::vector<std::string> xpu_support_subgraph_passes = {
"delete_dropout_op_pass", "delete_dropout_op_pass",
"delete_concat_op_pass", "delete_concat_op_pass",
"identity_scale_op_clean_pass", "identity_op_clean_pass",
"delete_op_device_pass", "delete_op_device_pass",
"constant_folding_pass", "constant_folding_pass",
"delete_elementwise_mul_op_pass", "delete_elementwise_mul_op_pass",
......
...@@ -94,12 +94,11 @@ const std::vector<std::string> kTRTSubgraphPasses({ ...@@ -94,12 +94,11 @@ const std::vector<std::string> kTRTSubgraphPasses({
"delete_quant_dequant_filter_op_pass", // "delete_quant_dequant_filter_op_pass", //
"trt_delete_weight_dequant_linear_op_pass", // "trt_delete_weight_dequant_linear_op_pass", //
"delete_quant_dequant_linear_op_pass", // "delete_quant_dequant_linear_op_pass", //
"identity_scale_op_clean_pass", // "identity_op_clean_pass", //
"add_support_int8_pass", // "add_support_int8_pass", //
"simplify_with_basic_ops_pass", // "simplify_with_basic_ops_pass", //
"trt_embedding_eltwise_layernorm_fuse_pass", // "trt_embedding_eltwise_layernorm_fuse_pass", //
"preln_embedding_eltwise_layernorm_fuse_pass", // "preln_embedding_eltwise_layernorm_fuse_pass", //
"delete_c_identity_op_pass", //
"trt_multihead_matmul_fuse_pass_v2", // "trt_multihead_matmul_fuse_pass_v2", //
"trt_multihead_matmul_fuse_pass_v3", // "trt_multihead_matmul_fuse_pass_v3", //
"multihead_matmul_roformer_fuse_pass", // "multihead_matmul_roformer_fuse_pass", //
...@@ -175,7 +174,7 @@ const std::vector<std::string> kLiteSubgraphPasses({ ...@@ -175,7 +174,7 @@ const std::vector<std::string> kLiteSubgraphPasses({
// running errors. After fusion operator supports low precision, delete this. // running errors. After fusion operator supports low precision, delete this.
const std::vector<std::string> kGpuLowerPrecisionPasses{ const std::vector<std::string> kGpuLowerPrecisionPasses{
"map_op_to_another_pass", "map_op_to_another_pass",
"identity_scale_op_clean_pass", "identity_op_clean_pass",
"simplify_with_basic_ops_pass", "simplify_with_basic_ops_pass",
"silu_fuse_pass", "silu_fuse_pass",
"delete_quant_dequant_linear_op_pass", "delete_quant_dequant_linear_op_pass",
...@@ -222,7 +221,7 @@ const std::vector<std::string> kCINNCompilerPasses{ ...@@ -222,7 +221,7 @@ const std::vector<std::string> kCINNCompilerPasses{
GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) { GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
passes_.assign({ passes_.assign({
"map_op_to_another_pass", // "map_op_to_another_pass", //
"identity_scale_op_clean_pass", // "identity_op_clean_pass", //
"is_test_pass", // "is_test_pass", //
"simplify_with_basic_ops_pass", // "simplify_with_basic_ops_pass", //
"delete_quant_dequant_linear_op_pass", // "delete_quant_dequant_linear_op_pass", //
...@@ -511,7 +510,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) { ...@@ -511,7 +510,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
passes_.assign({ passes_.assign({
"delete_dropout_op_pass", "delete_dropout_op_pass",
"delete_concat_op_pass", "delete_concat_op_pass",
"identity_scale_op_clean_pass", "identity_op_clean_pass",
"delete_repeated_ops_pass", "delete_repeated_ops_pass",
"delete_op_device_pass", "delete_op_device_pass",
"constant_folding_pass", "constant_folding_pass",
......
...@@ -56,7 +56,7 @@ class TestDeleteCIdentityPass(PassAutoScanTest): ...@@ -56,7 +56,7 @@ class TestDeleteCIdentityPass(PassAutoScanTest):
self.run_and_statis( self.run_and_statis(
max_examples=2, max_examples=2,
min_success_num=2, min_success_num=2,
passes=["delete_c_identity_op_pass"], passes=["identity_op_clean_pass"],
) )
......
...@@ -18,20 +18,10 @@ import hypothesis.strategies as st ...@@ -18,20 +18,10 @@ import hypothesis.strategies as st
from auto_scan_test import PassAutoScanTest from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig from program_config import OpConfig, ProgramConfig, TensorConfig
import paddle.inference as paddle_infer
class TestIdentityScaleCleanPass(PassAutoScanTest): class TestIdentityScaleCleanPass(PassAutoScanTest):
def sample_predictor_configs(self, program_config): def sample_predictor_configs(self, program_config):
config = self.create_trt_inference_config() config = self.create_inference_config(use_gpu=True)
config.enable_tensorrt_engine(
max_batch_size=8,
workspace_size=0,
min_subgraph_size=0,
precision_mode=paddle_infer.PrecisionType.Float32,
use_static=False,
use_calib_mode=False,
)
yield config, ['relu'], (1e-5, 1e-5) yield config, ['relu'], (1e-5, 1e-5)
def sample_program_config(self, draw): def sample_program_config(self, draw):
...@@ -61,9 +51,7 @@ class TestIdentityScaleCleanPass(PassAutoScanTest): ...@@ -61,9 +51,7 @@ class TestIdentityScaleCleanPass(PassAutoScanTest):
return program_config return program_config
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(max_examples=25, passes=["identity_op_clean_pass"])
max_examples=25, passes=["identity_scale_op_clean_pass"]
)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -156,7 +156,7 @@ class TrtConvertRangeStaticTest(TrtLayerAutoScanTest): ...@@ -156,7 +156,7 @@ class TrtConvertRangeStaticTest(TrtLayerAutoScanTest):
def generate_input2(): def generate_input2():
return np.array([1]).astype(np.int32) return np.array([1]).astype(np.int32)
for in_dtype in [2, 5]: for in_dtype in [2]:
self.in_dtype = in_dtype self.in_dtype = in_dtype
dics = [{}] dics = [{}]
ops_config = [ ops_config = [
......
...@@ -102,7 +102,7 @@ class TestDeleteRepeatedShapePass(PassAutoScanTest): ...@@ -102,7 +102,7 @@ class TestDeleteRepeatedShapePass(PassAutoScanTest):
class TestDeleteRepeatedSlicePass(PassAutoScanTest): class TestDeleteRepeatedSlicePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config): def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True) config = self.create_inference_config(use_xpu=True)
yield config, ['slice', 'cast', 'cast', 'cast'], (1e-5, 1e-5) yield config, ['slice'], (1e-5, 1e-5)
def sample_program_config(self, draw): def sample_program_config(self, draw):
slice_x = draw( slice_x = draw(
...@@ -122,15 +122,6 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest): ...@@ -122,15 +122,6 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest):
decrease_axis=[0], decrease_axis=[0],
outputs={"Out": ["slice0_out"]}, outputs={"Out": ["slice0_out"]},
) )
cast_op0 = OpConfig(
"cast",
inputs={
"X": ["slice0_out"],
},
in_dtype=5,
out_dtype=5,
outputs={"Out": ["cast0_out"]},
)
slice_op1 = OpConfig( slice_op1 = OpConfig(
"slice", "slice",
inputs={ inputs={
...@@ -142,15 +133,6 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest): ...@@ -142,15 +133,6 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest):
decrease_axis=[0], decrease_axis=[0],
outputs={"Out": ["slice1_out"]}, outputs={"Out": ["slice1_out"]},
) )
cast_op1 = OpConfig(
"cast",
inputs={
"X": ["slice1_out"],
},
in_dtype=5,
out_dtype=5,
outputs={"Out": ["cast1_out"]},
)
slice_op2 = OpConfig( slice_op2 = OpConfig(
"slice", "slice",
inputs={ inputs={
...@@ -162,16 +144,7 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest): ...@@ -162,16 +144,7 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest):
decrease_axis=[0], decrease_axis=[0],
outputs={"Out": ["slice2_out"]}, outputs={"Out": ["slice2_out"]},
) )
cast_op2 = OpConfig( ops = [slice_op0, slice_op1, slice_op2]
"cast",
inputs={
"X": ["slice2_out"],
},
in_dtype=5,
out_dtype=5,
outputs={"Out": ["cast2_out"]},
)
ops = [slice_op0, cast_op0, slice_op1, cast_op1, slice_op2, cast_op2]
program_config = ProgramConfig( program_config = ProgramConfig(
ops=ops, ops=ops,
...@@ -179,7 +152,7 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest): ...@@ -179,7 +152,7 @@ class TestDeleteRepeatedSlicePass(PassAutoScanTest):
inputs={ inputs={
"slice_x": TensorConfig(shape=slice_x), "slice_x": TensorConfig(shape=slice_x),
}, },
outputs=["cast0_out", "cast1_out", "cast2_out"], outputs=["slice0_out", "slice1_out", "slice2_out"],
) )
return program_config return program_config
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册