// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include "paddle/cinn/hlir/framework/graph.h" #include "paddle/cinn/hlir/framework/op_lowering_impl.h" #include "paddle/cinn/hlir/framework/op_lowering_impl_base.h" #include "paddle/cinn/lang/packed_func.h" #ifndef CINN_WITH_ONLY #include "paddle/cinn/hlir/framework/new_ir/op_lowering_impl.h" #endif namespace cinn { namespace hlir { namespace framework { using common::Target; using GroupPtr = std::shared_ptr; template class OpLowerer { public: explicit OpLowerer(OpLowererImplBase* impl) { impl_.reset(impl); } ~OpLowerer() {} std::vector Lower(const T& group, bool apply_op_schedule = true, bool apply_group_schedule = true) { return impl_->Lower(group, apply_op_schedule, apply_group_schedule); } private: std::shared_ptr> impl_; }; template OpLowerer CreateOpLowerer(const absl::flat_hash_map&, const absl::flat_hash_map&, const Target&); template <> inline OpLowerer CreateOpLowerer( const absl::flat_hash_map& type_dict, const absl::flat_hash_map& shape_dict, const Target& target) { auto* impl_base = new OpLowererImpl(type_dict, shape_dict, target); return OpLowerer(impl_base); } #ifndef CINN_WITH_ONLY template OpLowerer CreateOpLowerer(const Target&); template <> inline OpLowerer CreateOpLowerer(const Target& target) { auto* impl_base = new newir::OpLowererImpl(target); return OpLowerer(impl_base); } #endif } // namespace framework } // namespace hlir } // namespace cinn