diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 12457c2e3c1d1ee35c59a38b7273d15925e4f27a..9fa7527827ae4ab80239a939066a9f87eee8658c 100755
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -128,6 +128,7 @@ pass_library(dense_fc_to_sparse_pass inference)
 pass_library(dense_multihead_matmul_to_sparse_pass inference)
 pass_library(delete_cast_op_pass inference)
 pass_library(delete_elementwise_mul_op_pass inference)
+pass_library(delete_repeated_ops_pass inference)
 pass_library(generate_pass DEPS pass_desc_proto)
 target_link_libraries(generate_pass pass_desc_proto)
diff --git a/paddle/fluid/framework/ir/delete_cast_op_pass.cc b/paddle/fluid/framework/ir/delete_cast_op_pass.cc
index 3bf2e53e40533bbb5529acb3cc8c52a1375f0e98..9db286927362b53a6d2a04efee9911dea495c760 100644
--- a/paddle/fluid/framework/ir/delete_cast_op_pass.cc
+++ b/paddle/fluid/framework/ir/delete_cast_op_pass.cc
@@ -19,6 +19,8 @@
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/kernels/assign_kernel.h"
+#include "paddle/phi/kernels/cast_kernel.h"
 
 namespace phi {
 class DenseTensor;
@@ -623,6 +625,93 @@ int DeleteCastOpPass::ApplyCastScatterPass(ir::Graph* graph) const {
   return found_subgraph_count;
 }
 
+namespace patterns {
+struct CastLookupTablePattern : public PatternBase {
+  CastLookupTablePattern(PDPattern* pattern, const std::string& name_scope);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(lookup_table);
+  PATTERN_DECL_NODE(cast);
+  // declare variable node's name
+  PATTERN_DECL_NODE(lookup_table_w);
+  PATTERN_DECL_NODE(lookup_table_out);
+  PATTERN_DECL_NODE(cast_out);
+};
+
+CastLookupTablePattern::CastLookupTablePattern(PDPattern* pattern,
+                                               const std::string& name_scope)
+    : PatternBase(pattern, name_scope, name_scope) {
+  auto* lookup_table_w = pattern->NewNode(lookup_table_w_repr())
+                             ->assert_is_op_input("lookup_table_v2", "W")
+                             ->assert_is_persistable_var();
+  auto* lookup_table =
+      pattern->NewNode(lookup_table_repr())->assert_is_op("lookup_table_v2");
+  auto* lookup_table_out = pattern->NewNode(lookup_table_out_repr())
+                               ->assert_is_op_output("lookup_table_v2", "Out")
+                               ->assert_is_op_input("cast", "X")
+                               ->assert_has_n_outputs(1);
+  auto* cast =
+      pattern->NewNode(cast_repr())
+          ->assert_is_op("cast")
+          ->assert_more([](Node* node) {
+            auto* op_desc = node->Op();
+            auto in_dtype = op_desc->GetAttrIfExists<int>("in_dtype");
+            auto out_dtype = op_desc->GetAttrIfExists<int>("out_dtype");
+            return in_dtype == static_cast<int>(proto::VarType::FP32) &&
+                   out_dtype == static_cast<int>(proto::VarType::FP16);
+          });
+  auto* cast_out =
+      pattern->NewNode(cast_out_repr())->assert_is_op_output("cast", "Out");
+
+  lookup_table->LinksFrom({lookup_table_w}).LinksTo({lookup_table_out});
+  cast->LinksFrom({lookup_table_out}).LinksTo({cast_out});
+}
+}  // namespace patterns
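Note on the pattern above: restricting the cast to fp32->fp16 and requiring lookup_table_out to have a single consumer is what makes the fold safe, since rounding each element to fp16 commutes with gathering rows by id. A minimal numpy sketch of the equivalence (illustrative names only, not Paddle APIs):

# Why folding cast(fp32->fp16) into the embedding table preserves results:
# casting the whole table first and then gathering equals gathering first
# and then casting the gathered rows.
import numpy as np

w_fp32 = np.random.rand(100, 8).astype(np.float32)  # lookup_table W
ids = np.array([3, 7, 42])

out_original = w_fp32[ids].astype(np.float16)   # lookup_table -> cast
out_optimized = w_fp32.astype(np.float16)[ids]  # fp16 W -> lookup_table

assert np.array_equal(out_original, out_optimized)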
+
+int DeleteCastOpPass::ApplyCastLookupTablePass(ir::Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::CastLookupTablePattern pattern(gpd.mutable_pattern(), name_scope_);
+
+  int found_subgraph_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle ApplyCastLookupTablePass fuse";
+    GET_IR_NODE_FROM_SUBGRAPH(lookup_table, lookup_table, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(cast, cast, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(lookup_table_w, lookup_table_w, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(lookup_table_out, lookup_table_out, pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(cast_out, cast_out, pattern);
+    auto* scope = param_scope();
+
+    auto* w_tensor =
+        scope->Var(lookup_table_w->Name())->GetMutable<phi::DenseTensor>();
+    lookup_table_w->Var()->SetDataType(proto::VarType::FP16);
+    if (w_tensor->dtype() != phi::DataType::FLOAT16) {
+      auto* cpu_ctx = static_cast<phi::CPUContext*>(
+          platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
+      phi::DenseTensor w_fp32_tensor;
+      w_fp32_tensor.Resize(w_tensor->dims());
+      w_fp32_tensor.set_type(w_tensor->dtype());
+      phi::AssignKernel(*cpu_ctx, *w_tensor, &w_fp32_tensor);
+      w_tensor->set_type(phi::DataType::FLOAT16);
+      phi::CastKernel<float>(
+          *cpu_ctx, w_fp32_tensor, phi::DataType::FLOAT16, w_tensor);
+    }
+
+    for (auto* next_op : cast_out->outputs) {
+      next_op->Op()->RenameInput(cast_out->Name(), lookup_table_out->Name());
+      IR_NODE_LINK_TO(lookup_table_out, next_op);
+    }
+
+    std::unordered_set<const Node*> delete_nodes{cast, cast_out};
+    GraphSafeRemoveNodes(graph, delete_nodes);
+    found_subgraph_count++;
+  };
+
+  gpd(graph, handler);
+  return found_subgraph_count;
+}
+
 namespace patterns {
 struct CastPattern : public PatternBase {
   CastPattern(PDPattern* pattern, const std::string& name_scope);
@@ -718,6 +807,15 @@ void DeleteCastOpPass::ApplyImpl(ir::Graph* graph) const {
               << " cast_scatter_cast subgraph";
   }
 
+  found_subgraph_count = 0;
+  for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
+    found_subgraph_count += ApplyCastLookupTablePass(graph->GetSubGraph(i));
+  }
+  if (found_subgraph_count > 0) {
+    LOG(INFO) << "--- delete " << found_subgraph_count
+              << " lookup_table_cast subgraph";
+  }
+
   found_subgraph_count = 0;
   for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
     found_subgraph_count += ApplyCastPass(graph->GetSubGraph(i));
diff --git a/paddle/fluid/framework/ir/delete_cast_op_pass.h b/paddle/fluid/framework/ir/delete_cast_op_pass.h
index 37132af07e17fd86b53f2e5de09b4a0edfcf650c..7aa18415e0fb46f2e3fb97e0785cbe5c0e65b8f1 100644
--- a/paddle/fluid/framework/ir/delete_cast_op_pass.h
+++ b/paddle/fluid/framework/ir/delete_cast_op_pass.h
@@ -124,6 +124,21 @@ class DeleteCastOpPass : public FusePassBase {
   */
   int ApplyCastScatterPass(ir::Graph* graph) const;
 
+  /*
+  Origin subgraph:
+       ids   w(fp32)
+         \    /
+      lookup_table
+           |
+     cast(fp32->fp16)
+
+  Optimized subgraph:
+       ids   w(fp16)
+         \    /
+      lookup_table
+  */
+  int ApplyCastLookupTablePass(ir::Graph* graph) const;
+
   // Delete cast if its "in_dtype" is the same as "out_dtype"
   int ApplyCastPass(ir::Graph* graph) const;
 
diff --git a/paddle/fluid/framework/ir/delete_cast_op_pass_test.cc b/paddle/fluid/framework/ir/delete_cast_op_pass_test.cc
index 1885f945840332485cb02fdb0b8bbf1627f465d4..11d1339f35d249046dd22dfa9445013736b85d25 100644
--- a/paddle/fluid/framework/ir/delete_cast_op_pass_test.cc
+++ b/paddle/fluid/framework/ir/delete_cast_op_pass_test.cc
@@ -20,6 +20,16 @@ namespace paddle {
 namespace framework {
 namespace ir {
 
+void AddVarToScope(Scope* param_scope,
+                   const std::string& name,
+                   const DDim& dims) {
+  auto* tensor = param_scope->Var(name)->GetMutable<phi::DenseTensor>();
+  tensor->Resize(dims);
+  auto* cpu_ctx = static_cast<phi::CPUContext*>(
+      platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
+  cpu_ctx->Alloc<float>(tensor);
+}
+
 VarDesc* Data(paddle::framework::BlockDesc* block,
               std::string name,
               std::vector<int64_t> shape = {},
@@ -255,6 +265,36 @@ TEST(ApplyCastScatterPass, basic) {
                         cast_num_in_graph));
 }
 
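The handler above does two things: it rewrites W in place (stash the fp32 data via AssignKernel, retag the tensor as fp16, cast the stash back into it via CastKernel), and it re-points every consumer of cast_out at lookup_table_out before removing the cast. A plain-Python sketch of the rewiring step on a toy op list (illustrative structures, not Paddle IR):

# Rewiring consumers: rename cast_out -> lt_out in downstream inputs,
# then drop the now-dangling cast op, mirroring RenameInput +
# GraphSafeRemoveNodes.
ops = [
    {"type": "lookup_table_v2", "in": ["ids", "w"], "out": ["lt_out"]},
    {"type": "cast", "in": ["lt_out"], "out": ["cast_out"]},
    {"type": "relu", "in": ["cast_out"], "out": ["relu_out"]},
]

for op in ops:
    op["in"] = ["lt_out" if v == "cast_out" else v for v in op["in"]]
ops = [op for op in ops if op["type"] != "cast"]

assert ops[-1]["in"] == ["lt_out"]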
+TEST(ApplyCastLookupTablePass, basic) {
+  paddle::framework::ProgramDesc program;
+  auto* block = program.MutableBlock(0);
+  auto* lookup_table_w = Data(block, "lookup_table_w", {1}, true);
+  auto* lookup_table_out = Data(block, "lookup_table_out", {1});
+  OpDesc* lookup_table = block->AppendOp();
+  lookup_table->SetType("lookup_table_v2");
+  lookup_table->SetInput("W", {lookup_table_w->Name()});
+  lookup_table->SetOutput("Out", {lookup_table_out->Name()});
+  auto* cast_out = AddCast(block, lookup_table_out, 5, 4);
+  OpDesc* relu = block->AppendOp();
+  relu->SetType("relu");
+  relu->SetInput("X", {cast_out->Name()});
+  relu->SetOutput("Out", {"relu_out"});
+
+  std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
+  auto scope = new Scope();
+  AddVarToScope(scope, lookup_table_w->Name(), {1});
+  graph->Set("__param_scope__", scope);
+  auto pass = PassRegistry::Instance().Get("delete_cast_op_pass");
+  pass->Apply(graph.get());
+  int cast_num_in_graph = GetOpNum(graph->GetSubGraph(0), "cast");
+  PADDLE_ENFORCE_EQ(GetOpNum(graph->GetSubGraph(0), "cast"),
+                    0,
+                    platform::errors::PreconditionNotMet(
+                        "graph should have 0 cast after delete_cast_op_pass, "
+                        "but actually has %d.",
+                        cast_num_in_graph));
+}
+
 TEST(ApplyCastPass, basic) {
   paddle::framework::ProgramDesc program;
   auto* block = program.MutableBlock(0);
diff --git a/paddle/fluid/framework/ir/delete_repeated_ops_pass.cc b/paddle/fluid/framework/ir/delete_repeated_ops_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..cf5bb15c207e5d00c7b0f423752e5aa26509ca99
--- /dev/null
+++ b/paddle/fluid/framework/ir/delete_repeated_ops_pass.cc
@@ -0,0 +1,255 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string>
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace phi {
+class DenseTensor;
+}  // namespace phi
+
+namespace paddle {
+namespace framework {
+class Scope;
+}  // namespace framework
+}  // namespace paddle
+
+namespace paddle {
+namespace framework {
+namespace ir {
+namespace patterns {
+
+struct VarWithRepeatedOpsPattern : public PatternBase {
+  VarWithRepeatedOpsPattern(PDPattern* pattern,
+                            const std::string& name_scope,
+                            const std::string& op_type);
+
+  // declare variable node's name
+  PATTERN_DECL_NODE(in_var);
+
+  std::string op_type_;
+};
+
+VarWithRepeatedOpsPattern::VarWithRepeatedOpsPattern(
+    PDPattern* pattern,
+    const std::string& name_scope,
+    const std::string& op_type)
+    : PatternBase(pattern, name_scope, name_scope), op_type_(op_type) {
+  pattern->NewNode(in_var_repr())
+      ->assert_is_var()
+      ->assert_more([&](Node* node) {
+        auto out_nodes = node->outputs;
+        if (out_nodes.size() <= 1) return false;
+        int op_counts = 0;
+        for (auto* next_op : out_nodes) {
+          if (next_op->Name() == op_type_) {
+            op_counts++;
+          }
+        }
+        return op_counts > 1;
+      });
+}
+
+}  // namespace patterns
+
+/*
+Delete repeated ops, for example:
+Origin subgraph:
+      (input_variable)
+      /      |      \    ...
+  shape    shape   shape ...
+    |        |       |   ...
+   op0      op1     op2  ...
+
+Optimized subgraph:
+      (input_variable)
+             |
+           shape
+      /      |      \    ...
+   op0      op1     op2  ...
+*/
+class DeleteRepeatedOpsPass : public FusePassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+
+ private:
+  int DeleteShapePass(ir::Graph* graph) const;
+
+  int DeleteSlicePass(ir::Graph* graph) const;
+
+  const std::string name_scope_{"delete_repeated_ops_pass"};
+};
+
+int DeleteRepeatedOpsPass::DeleteShapePass(ir::Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::VarWithRepeatedOpsPattern pattern(
+      gpd.mutable_pattern(), name_scope_, "shape");
+
+  int delete_counts = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle DeleteShapePass";
+    GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern);
+
+    std::vector<Node*> shapes;
+    for (auto* next_op : in_var->outputs) {
+      if (next_op->Name() != "shape") continue;
+      bool shape_out_has_control_flow_ops = false;
+      for (auto* shape_out_op : next_op->outputs[0]->outputs) {
+        if (shape_out_op->Name() == "while" ||
+            shape_out_op->Name() == "conditional_block") {
+          shape_out_has_control_flow_ops = true;
+          break;
+        }
+      }
+      if (!shape_out_has_control_flow_ops) {
+        shapes.push_back(next_op);
+      }
+    }
+    if (shapes.size() <= 1) return;
+
+    auto* first_shape_out = shapes[0]->outputs[0];
+    auto first_shape_out_name = first_shape_out->Name();
+    std::unordered_set<const Node*> delete_nodes;
+    for (size_t i = 1; i < shapes.size(); i++) {
+      auto* cur_shape = shapes[i];
+      auto* cur_shape_out = cur_shape->outputs[0];
+      auto cur_shape_out_name = cur_shape_out->Name();
+      for (auto* shape_out_op : cur_shape_out->outputs) {
+        shape_out_op->Op()->Rename(cur_shape_out_name, first_shape_out_name);
+        IR_NODE_LINK_TO(first_shape_out, shape_out_op);
+      }
+      delete_nodes.insert(cur_shape);
+      delete_nodes.insert(cur_shape_out);
+      delete_counts++;
+    }
+
+    GraphSafeRemoveNodes(graph, delete_nodes);
+  };
+
+  gpd(graph, handler);
+  return delete_counts;
+}
+
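DeleteShapePass keeps the first shape op on a variable and re-points users of every duplicate's output at the first output (skipping outputs consumed by while/conditional_block, whose sub-block inputs are not safely renameable). The keep-first-and-rewire idea in plain Python (toy structures, not Paddle IR):

# Sketch of the keep-first strategy: all "shape" ops reading the same
# input are interchangeable, so duplicates are dropped and their
# consumers are rewired to the surviving output.
ops = [
    {"type": "shape", "in": "x", "out": "s0"},
    {"type": "shape", "in": "x", "out": "s1"},
    {"type": "cast", "in": "s0", "out": "c0"},
    {"type": "cast", "in": "s1", "out": "c1"},
]

shapes = [op for op in ops if op["type"] == "shape" and op["in"] == "x"]
keep_out = shapes[0]["out"]
drop_outs = {op["out"] for op in shapes[1:]}
ops = [
    op for op in ops if not (op["type"] == "shape" and op["out"] in drop_outs)
]
for op in ops:
    if op["in"] in drop_outs:
        op["in"] = keep_out

assert [op["out"] for op in ops] == ["s0", "c0", "c1"]
assert all(op["in"] == "s0" for op in ops if op["type"] == "cast")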
+std::string GenSliceAttrKey(OpDesc* slice_op_desc) {
+  std::string attr_key;
+  auto starts = slice_op_desc->GetAttrIfExists<std::vector<int>>("starts");
+  auto ends = slice_op_desc->GetAttrIfExists<std::vector<int>>("ends");
+  auto axes = slice_op_desc->GetAttrIfExists<std::vector<int>>("axes");
+  attr_key += "starts_";
+  for (auto start : starts) {
+    attr_key += std::to_string(start) + "_";
+  }
+  attr_key += "ends_";
+  for (auto end : ends) {
+    attr_key += std::to_string(end) + "_";
+  }
+  attr_key += "axes_";
+  for (auto axis : axes) {
+    attr_key += std::to_string(axis) + "_";
+  }
+  return attr_key;
+}
+
+int DeleteRepeatedOpsPass::DeleteSlicePass(ir::Graph* graph) const {
+  GraphPatternDetector gpd;
+  patterns::VarWithRepeatedOpsPattern pattern(
+      gpd.mutable_pattern(), name_scope_, "slice");
+
+  int delete_counts = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* graph) {
+    VLOG(4) << "handle DeleteSlicePass";
+    GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern);
+
+    std::map<std::string, std::vector<Node*>> slice_ops;
+    for (auto* next_op : in_var->outputs) {
+      if (next_op->Name() != "slice") continue;
+      auto* slice = next_op;
+      bool slice_out_has_control_flow_ops = false;
+      for (auto* slice_out_op : slice->outputs[0]->outputs) {
+        if (slice_out_op->Name() == "while" ||
+            slice_out_op->Name() == "conditional_block") {
+          slice_out_has_control_flow_ops = true;
+          break;
+        }
+      }
+      if (slice_out_has_control_flow_ops) continue;
+      auto attr_key = GenSliceAttrKey(slice->Op());
+      slice_ops[attr_key].push_back(slice);
+    }
+    for (auto iter = slice_ops.begin(); iter != slice_ops.end();) {
+      if (iter->second.size() <= 1) {
+        iter = slice_ops.erase(iter);
+      } else {
+        iter++;
+      }
+    }
+
+    for (auto iter : slice_ops) {
+      auto slices = iter.second;
+      auto* first_slice_out = slices[0]->outputs[0];
+      auto first_slice_out_name = first_slice_out->Name();
+      std::unordered_set<const Node*> delete_nodes;
+      for (size_t i = 1; i < slices.size(); i++) {
+        auto* cur_slice = slices[i];
+        auto* cur_slice_out = cur_slice->outputs[0];
+        auto cur_slice_out_name = cur_slice_out->Name();
+        for (auto* slice_out_op : cur_slice_out->outputs) {
+          slice_out_op->Op()->Rename(cur_slice_out_name, first_slice_out_name);
+          IR_NODE_LINK_TO(first_slice_out, slice_out_op);
+        }
+        delete_nodes.insert(cur_slice);
+        delete_nodes.insert(cur_slice_out);
+        delete_counts++;
+      }
+      GraphSafeRemoveNodes(graph, delete_nodes);
+    }
+  };
+
+  gpd(graph, handler);
+  return delete_counts;
+}
+
+void DeleteRepeatedOpsPass::ApplyImpl(ir::Graph* graph) const {
+  PADDLE_ENFORCE_NOT_NULL(
+      graph, platform::errors::PreconditionNotMet("graph should not be null."));
+  Init(name_scope_, graph);
+
+  int delete_counts = DeleteShapePass(graph);
+  if (delete_counts > 0) {
+    LOG(INFO) << "--- delete " << delete_counts << " repeated shape ops";
+  }
+
+  delete_counts = DeleteSlicePass(graph);
+  if (delete_counts > 0) {
+    LOG(INFO) << "--- delete " << delete_counts << " repeated slice ops";
+  }
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(delete_repeated_ops_pass,
+              paddle::framework::ir::DeleteRepeatedOpsPass);
+
+REGISTER_PASS_CAPABILITY(delete_repeated_ops_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination().EQ(
+            "shape", 0));
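Unlike shape, two slice ops on the same input only produce identical results when their attributes match, so GenSliceAttrKey serializes starts/ends/axes into a string key and only buckets with more than one member are deduplicated. A plain-Python sketch of that grouping (hypothetical helper mirroring the C++ key format):

# Bucket slices by their (starts, ends, axes) signature; only buckets
# with 2+ members are candidates for keep-first deduplication.
def gen_slice_attr_key(attrs):
    key = "starts_" + "".join(str(v) + "_" for v in attrs["starts"])
    key += "ends_" + "".join(str(v) + "_" for v in attrs["ends"])
    key += "axes_" + "".join(str(v) + "_" for v in attrs["axes"])
    return key

slices = [
    {"out": "s0", "starts": [0], "ends": [1], "axes": [0]},
    {"out": "s1", "starts": [0], "ends": [1], "axes": [0]},  # duplicate of s0
    {"out": "s2", "starts": [1], "ends": [2], "axes": [0]},  # different slice
]

groups = {}
for s in slices:
    groups.setdefault(gen_slice_attr_key(s), []).append(s)
dups = {k: v for k, v in groups.items() if len(v) > 1}

assert len(dups) == 1
assert [s["out"] for s in next(iter(dups.values()))] == ["s0", "s1"]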
"delete_dropout_op_pass", "delete_concat_op_pass", "identity_scale_op_clean_pass", + "delete_repeated_ops_pass", "delete_op_device_pass", "constant_folding_pass", "delete_elementwise_mul_op_pass", diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_delete_repeated_ops_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_delete_repeated_ops_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..b9e789ac8b6116ca9e4950807aaac840217cc53c --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_delete_repeated_ops_pass.py @@ -0,0 +1,195 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import hypothesis.strategies as st +from auto_scan_test import PassAutoScanTest +from program_config import OpConfig, ProgramConfig, TensorConfig + + +class TestDeleteRepeatedShapePass(PassAutoScanTest): + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_xpu=True) + yield config, ['shape', 'cast', 'cast', 'cast'], (1e-5, 1e-5) + + def sample_program_config(self, draw): + x_shape = draw( + st.lists( + st.integers(min_value=1, max_value=20), min_size=2, max_size=4 + ) + ) + + shape_op0 = OpConfig( + "shape", + inputs={ + "Input": ["shape_x"], + }, + outputs={"Out": ["shape0_out"]}, + ) + cast_op0 = OpConfig( + "cast", + inputs={ + "X": ["shape0_out"], + }, + in_dtype=2, + out_dtype=5, + outputs={"Out": ["cast0_out"]}, + ) + shape_op1 = OpConfig( + "shape", + inputs={ + "Input": ["shape_x"], + }, + outputs={"Out": ["shape1_out"]}, + ) + cast_op1 = OpConfig( + "cast", + inputs={ + "X": ["shape1_out"], + }, + in_dtype=2, + out_dtype=5, + outputs={"Out": ["cast1_out"]}, + ) + shape_op2 = OpConfig( + "shape", + inputs={ + "Input": ["shape_x"], + }, + outputs={"Out": ["shape2_out"]}, + ) + cast_op2 = OpConfig( + "cast", + inputs={ + "X": ["shape2_out"], + }, + in_dtype=2, + out_dtype=5, + outputs={"Out": ["cast2_out"]}, + ) + ops = [shape_op0, cast_op0, shape_op1, cast_op1, shape_op2, cast_op2] + + program_config = ProgramConfig( + ops=ops, + weights={}, + inputs={ + "shape_x": TensorConfig(shape=x_shape), + }, + outputs=["cast0_out", "cast1_out", "cast2_out"], + ) + return program_config + + def test(self): + self.run_and_statis( + quant=False, + max_examples=25, + passes=["delete_repeated_ops_pass"], + ) + + +class TestDeleteRepeatedSlicePass(PassAutoScanTest): + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_xpu=True) + yield config, ['slice', 'cast', 'cast', 'cast'], (1e-5, 1e-5) + + def sample_program_config(self, draw): + slice_x = draw( + st.lists( + st.integers(min_value=1, max_value=20), min_size=2, max_size=4 + ) + ) + + slice_op0 = OpConfig( + "slice", + inputs={ + "Input": ["slice_x"], + }, + starts=[0], + ends=[1], + axes=[0], + decrease_axis=[0], + outputs={"Out": ["slice0_out"]}, + ) + cast_op0 = OpConfig( + "cast", + inputs={ + "X": ["slice0_out"], + }, 
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_delete_repeated_ops_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_delete_repeated_ops_pass.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9e789ac8b6116ca9e4950807aaac840217cc53c
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_xpu_delete_repeated_ops_pass.py
@@ -0,0 +1,195 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import hypothesis.strategies as st
+from auto_scan_test import PassAutoScanTest
+from program_config import OpConfig, ProgramConfig, TensorConfig
+
+
+class TestDeleteRepeatedShapePass(PassAutoScanTest):
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_xpu=True)
+        yield config, ['shape', 'cast', 'cast', 'cast'], (1e-5, 1e-5)
+
+    def sample_program_config(self, draw):
+        x_shape = draw(
+            st.lists(
+                st.integers(min_value=1, max_value=20), min_size=2, max_size=4
+            )
+        )
+
+        shape_op0 = OpConfig(
+            "shape",
+            inputs={
+                "Input": ["shape_x"],
+            },
+            outputs={"Out": ["shape0_out"]},
+        )
+        cast_op0 = OpConfig(
+            "cast",
+            inputs={
+                "X": ["shape0_out"],
+            },
+            in_dtype=2,
+            out_dtype=5,
+            outputs={"Out": ["cast0_out"]},
+        )
+        shape_op1 = OpConfig(
+            "shape",
+            inputs={
+                "Input": ["shape_x"],
+            },
+            outputs={"Out": ["shape1_out"]},
+        )
+        cast_op1 = OpConfig(
+            "cast",
+            inputs={
+                "X": ["shape1_out"],
+            },
+            in_dtype=2,
+            out_dtype=5,
+            outputs={"Out": ["cast1_out"]},
+        )
+        shape_op2 = OpConfig(
+            "shape",
+            inputs={
+                "Input": ["shape_x"],
+            },
+            outputs={"Out": ["shape2_out"]},
+        )
+        cast_op2 = OpConfig(
+            "cast",
+            inputs={
+                "X": ["shape2_out"],
+            },
+            in_dtype=2,
+            out_dtype=5,
+            outputs={"Out": ["cast2_out"]},
+        )
+        ops = [shape_op0, cast_op0, shape_op1, cast_op1, shape_op2, cast_op2]
+
+        program_config = ProgramConfig(
+            ops=ops,
+            weights={},
+            inputs={
+                "shape_x": TensorConfig(shape=x_shape),
+            },
+            outputs=["cast0_out", "cast1_out", "cast2_out"],
+        )
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=25,
+            passes=["delete_repeated_ops_pass"],
+        )
+
+
+class TestDeleteRepeatedSlicePass(PassAutoScanTest):
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_xpu=True)
+        yield config, ['slice', 'cast', 'cast', 'cast'], (1e-5, 1e-5)
+
+    def sample_program_config(self, draw):
+        slice_x = draw(
+            st.lists(
+                st.integers(min_value=1, max_value=20), min_size=2, max_size=4
+            )
+        )
+
+        slice_op0 = OpConfig(
+            "slice",
+            inputs={
+                "Input": ["slice_x"],
+            },
+            starts=[0],
+            ends=[1],
+            axes=[0],
+            decrease_axis=[0],
+            outputs={"Out": ["slice0_out"]},
+        )
+        cast_op0 = OpConfig(
+            "cast",
+            inputs={
+                "X": ["slice0_out"],
+            },
+            in_dtype=5,
+            out_dtype=5,
+            outputs={"Out": ["cast0_out"]},
+        )
+        slice_op1 = OpConfig(
+            "slice",
+            inputs={
+                "Input": ["slice_x"],
+            },
+            starts=[0],
+            ends=[1],
+            axes=[0],
+            decrease_axis=[0],
+            outputs={"Out": ["slice1_out"]},
+        )
+        cast_op1 = OpConfig(
+            "cast",
+            inputs={
+                "X": ["slice1_out"],
+            },
+            in_dtype=5,
+            out_dtype=5,
+            outputs={"Out": ["cast1_out"]},
+        )
+        slice_op2 = OpConfig(
+            "slice",
+            inputs={
+                "Input": ["slice_x"],
+            },
+            starts=[0],
+            ends=[1],
+            axes=[0],
+            decrease_axis=[0],
+            outputs={"Out": ["slice2_out"]},
+        )
+        cast_op2 = OpConfig(
+            "cast",
+            inputs={
+                "X": ["slice2_out"],
+            },
+            in_dtype=5,
+            out_dtype=5,
+            outputs={"Out": ["cast2_out"]},
+        )
+        ops = [slice_op0, cast_op0, slice_op1, cast_op1, slice_op2, cast_op2]
+
+        program_config = ProgramConfig(
+            ops=ops,
+            weights={},
+            inputs={
+                "slice_x": TensorConfig(shape=slice_x),
+            },
+            outputs=["cast0_out", "cast1_out", "cast2_out"],
+        )
+        return program_config
+
+    def test(self):
+        self.run_and_statis(
+            quant=False,
+            max_examples=25,
+            passes=["delete_repeated_ops_pass"],
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()