diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h index b29130ff22eaf6a0a83db9ca75675c4f1273151d..44a10031b97eb09da964b2c9be49a2d703a3334d 100644 --- a/lite/api/paddle_use_passes.h +++ b/lite/api/paddle_use_passes.h @@ -49,6 +49,7 @@ USE_MIR_PASS(xpu_subgraph_pass); USE_MIR_PASS(int64_to_int32_pass); // USE_MIR_PASS(identity_cast_eliminate_pass); USE_MIR_PASS(mlu_subgraph_pass); +USE_MIR_PASS(identity_cast_eliminate_pass); USE_MIR_PASS(mlu_postprocess_pass); USE_MIR_PASS(weight_quantization_preprocess_pass); USE_MIR_PASS(quantized_op_attributes_inference_pass); diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt index 230f107592bcebb2d70391614ab05cc38236c567..09db85c4093f274d220742845647c8133b173ae3 100644 --- a/lite/core/mir/CMakeLists.txt +++ b/lite/core/mir/CMakeLists.txt @@ -24,6 +24,7 @@ lite_cc_library(mir_passes fusion/__xpu__resnet_fuse_pass.cc fusion/__xpu__multi_encoder_fuse_pass.cc elimination/identity_scale_eliminate_pass.cc + elimination/identity_cast_eliminate_pass.cc elimination/elementwise_mul_constant_eliminate_pass.cc static_kernel_pick_pass.cc variable_place_inference_pass.cc diff --git a/lite/core/mir/elimination/identity_cast_eliminate_pass.cc b/lite/core/mir/elimination/identity_cast_eliminate_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..d11e3ef28149f248d64b642e629c2c0c7d9f0a1c --- /dev/null +++ b/lite/core/mir/elimination/identity_cast_eliminate_pass.cc @@ -0,0 +1,91 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <memory>
+#include "lite/core/mir/pass.h"
+#include "lite/core/mir/pass_registry.h"
+#include "lite/core/mir/pattern_matcher_high_api.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+namespace {
+
+// Fuser that matches `preop -> x -> cast -> out` where the cast's "in_dtype"
+// and "out_dtype" both equal `dtype`, then drops the cast and makes `preop`
+// write `out` directly.
+class CastEliminator : public FuseBase {
+ public:
+  explicit CastEliminator(const int dtype) : dtype_(dtype) {}
+
+  void BuildPattern() override {
+    // The producer is rewritten later; conditional_block and cast never match.
+    auto* pre_op = OpNode("preop")
+                       ->assert_is_not_op_type("conditional_block")
+                       ->assert_is_not_op_type("cast");
+    auto* x = VarNode("x")->assert_is_op_input("cast", "X");
+    auto* cast_op = OpNode("cast", "cast")
+                        ->assert_op_attr("in_dtype", dtype_)
+                        ->assert_op_attr("out_dtype", dtype_);
+    auto* out = VarNode("out")->assert_is_op_output("cast", "Out");
+    *pre_op >> *x >> *cast_op >> *out;
+  }
+
+ private:
+  // Renames the producer's output `x` to `out`, removes the cast node and
+  // links the producer straight to `out`.
+  // NOTE(review): the `x` var node is not removed here -- confirm it is
+  // cleaned up elsewhere.
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
+    const auto* cast_info = matched.at("cast")->AsStmt().op_info();
+    VLOG(6) << "in_dtype : " << cast_info->GetAttr<int>("in_dtype");
+    VLOG(6) << "out_dtype : " << cast_info->GetAttr<int>("out_dtype");
+
+    auto& pre_op = matched.at("preop")->AsStmt();
+    auto op_info = *pre_op.op_info();
+    op_info.UpdateAllOutputs(matched.at("x")->AsArg().name,
+                             matched.at("out")->AsArg().name);
+    pre_op.ResetOp(op_info, graph->valid_places());
+
+    GraphSafeRemoveNodes(graph, {matched.at("cast")});
+    IR_NODE_LINK_TO(matched.at("preop"), matched.at("out"));
+  }
+
+  int dtype_ = -1;  // dtype code required for both in_dtype and out_dtype
+};
+
+}  // namespace
+
+// Program pass that strips identity casts from the graph. Only
+// int32 -> int32 casts are eliminated for now.
+class IdentityCastEliminatePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override {
+    // Dtype codes noted by the original author: BOOL = 0, INT16 = 1,
+    // INT32 = 2, INT64 = 3, FP16 = 4, FP32 = 5, FP64 = 6, UINT8 = 20,
+    // INT8 = 21.
+    const int INT32 = 2;
+    CastEliminator eliminator_int32(INT32);
+    eliminator_int32(graph.get());
+  }
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(identity_cast_eliminate_pass,
+                  paddle::lite::mir::IdentityCastEliminatePass)
+    .BindTargets({TARGET(kMLU)});
diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h index 79061a7e8160ef739a32416b637797ae8d60d5b6..ab5e6c8d8a09ed29842d3eb76c8fca1f1c4bf0af 100644 --- a/lite/core/optimizer.h +++ b/lite/core/optimizer.h @@ -91,9 +91,8 @@ class Optimizer { "bm_subgraph_pass", "rknpu_subgraph_pass", "int64_to_int32_pass", - // "identity_cast_eliminate_pass", + "identity_cast_eliminate_pass", "mlu_subgraph_pass", - "static_kernel_pick_pass", // pick original kernel from graph
"variable_place_inference_pass", // inference arg/var's