// Copyright (c) 2022 CINN Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/cinn/ir/schedule_desc.h" #include #include #include #include #include "paddle/cinn/common/macros.h" #include "paddle/cinn/ir/ir_schedule.h" #include "paddle/cinn/utils/string.h" namespace cinn { namespace ir { // ------ Following codes are about `Apply` functions registry of variaous types of ScheduleDesc::Step class PackedStepContext; // uniformed function prototype of a scheduling operation in IRSchedule using StepApplyFunc = std::vector (*)(PackedStepContext*); // format the inputs, attrs, uniformed function of a scheduling step class StepKindInfo { public: // compatible for Registry::EntryType std::string name; // format: {"", "", ...} StepKindInfo& Inputs(std::vector&& inputs) { inputs_ = inputs; return *this; } // format: {"", "", ...} StepKindInfo& Attrs(std::vector&& attrs) { attrs_ = attrs; return *this; } // format: APPLY_FUNC_UNIFORM(...) StepKindInfo& SetApplyFn(StepApplyFunc&& func) { apply_func_ = func; return *this; } // execute the Apply function of this type std::vector Apply(PackedStepContext* context) const { return apply_func_(context); } private: friend class PackedStepContext; std::vector inputs_; std::vector attrs_; StepApplyFunc apply_func_{nullptr}; }; // StepKindInfo register for all scheduling steps class StepKindRegistry : public Registry { public: StepKindRegistry() = default; private: CINN_DISALLOW_COPY_AND_ASSIGN(StepKindRegistry); }; // PackedStepContext is the param of a uniformed `Apply` function, which is used to be an // auxiliary structure to interact with in/out arguments of the original scheduling function in IRSchedule class PackedStepContext { public: explicit PackedStepContext(const ScheduleDesc::Step& desc, const StepKindInfo* step_kind, IRSchedule* schedule) : ir_schedule_(schedule) { Build(desc, step_kind); } // get the pointer of current IRSchedule object IRSchedule* ScheduleHandler() const { return ir_schedule_; } // get the idx-th input whose signature is Expr Expr InputAt(size_t idx) const { CHECK_LT(idx, input_range_.size()) << "idx overranges"; const auto& range = input_range_.at(idx); CHECK(range.second - range.first == 1) << "not single param"; return inputs_[range.first]; } // get the idx-th input whose signature is `std::vector` std::vector InputsAt(size_t idx) const { CHECK_LT(idx, input_range_.size()) << "idx overranges"; const auto& range = input_range_.at(idx); std::vector results; for (size_t s = range.first; s < range.second; ++s) { results.emplace_back(inputs_[s]); } return results; } // get the idx-th attribute value with correct type template const AttrType& AttrAt(size_t idx) const { try { return absl::get(attrs_.at(idx)); } catch (absl::bad_variant_access& ex) { LOG(FATAL) << "Attribute cast error, idx:" << idx << ", get tpye:" << typeid(AttrType).name() << ", real index:" << attrs_.at(idx).index(); throw ex; } } private: void Build(const ScheduleDesc::Step& desc, const StepKindInfo* step_kind) { // build inputs size_t input_idx = 0; for (auto&& param_name : step_kind->inputs_) { auto arg_it = desc.inputs.find(param_name); CHECK(arg_it != desc.inputs.end()) << "Can't find param:" << param_name; auto&& args = arg_it->second; inputs_.insert(inputs_.end(), std::make_move_iterator(args.begin()), std::make_move_iterator(args.end())); input_range_.emplace_back(input_idx, input_idx + args.size()); input_idx += args.size(); } // build attrs size_t attr_idx = 0; for (auto&& attr_name : step_kind->attrs_) { auto attr_it = desc.attrs.find(attr_name); CHECK(attr_it != desc.attrs.end()) << "Can't find attribute:" << attr_name; attrs_.emplace_back(attr_it->second); ++attr_idx; } } IRSchedule* ir_schedule_; std::vector inputs_; std::vector> input_range_; std::vector attrs_; }; #define CINN_SPECIALIZE_ApplyCallHelper(attr_type) \ template \ struct ApplyCallHelper { \ template \ static std::vector Apply(PackedStepContext* ctx, PreviousArgs... pargs) { \ using rf_attr_type = std::remove_reference::type; \ using rc_attr_type = std::remove_const::type; \ const auto& arg = ctx->AttrAt(attr_idx); \ return ApplyCallHelper::template Apply( \ ctx, std::forward(pargs)..., arg); \ } \ } template struct TypeTag {}; // used for converting a member function of the IRSchedule to be a free function // with the first parameter is a pointer to the IRSchedule. template struct FreeFuncConverter; template struct FreeFuncConverter { static Return Apply(IRSchedule* sch, Args... args) { return (sch->*impl_fn)(std::forward(args)...); } }; template struct FreeFuncConverter { static Return Apply(IRSchedule* sch, Args... args) { return (sch->*impl_fn)(std::forward(args)...); } }; // used for formatting scheduling functions with variaous function signatures to be uniformed form template struct ApplyFuncImpl; template struct ApplyFuncImpl { static std::vector Apply(PackedStepContext* ctx) { return ApplyCallHelper>::template Apply<0, 0, 0>(ctx); } private: template struct ApplyCallHelper; // the signature of input parameters of a scheduling operation only can // be one of IRSchedule, Expr or std::vector template struct ApplyCallHelper { template static std::vector Apply(PackedStepContext* ctx) { static_assert(in_idx == 0, "IRSchedule* must be the first argument"); IRSchedule* ir_schedule = ctx->ScheduleHandler(); return ApplyCallHelper::template Apply(ctx, ir_schedule); } }; template struct ApplyCallHelper { template static std::vector Apply(PackedStepContext* ctx, PreviousArgs... pargs) { auto arg = ctx->InputAt(in_idx - 1); return ApplyCallHelper::template Apply( ctx, std::forward(pargs)..., arg); } }; template struct ApplyCallHelper { template static std::vector Apply(PackedStepContext* ctx, PreviousArgs... pargs) { auto arg = ctx->InputAt(in_idx - 1); return ApplyCallHelper::template Apply( ctx, std::forward(pargs)..., arg); } }; template struct ApplyCallHelper&, Tail...> { template static std::vector Apply(PackedStepContext* ctx, PreviousArgs... pargs) { auto arg = ctx->InputsAt(in_idx - 1); return ApplyCallHelper::template Apply( ctx, std::forward(pargs)..., arg); } }; CINN_SPECIALIZE_ApplyCallHelper(bool); CINN_SPECIALIZE_ApplyCallHelper(int); CINN_SPECIALIZE_ApplyCallHelper(float); CINN_SPECIALIZE_ApplyCallHelper(const std::string&); CINN_SPECIALIZE_ApplyCallHelper(const std::vector&); CINN_SPECIALIZE_ApplyCallHelper(const std::vector&); CINN_SPECIALIZE_ApplyCallHelper(const std::vector&); CINN_SPECIALIZE_ApplyCallHelper(const std::vector&); CINN_SPECIALIZE_ApplyCallHelper(int64_t); CINN_SPECIALIZE_ApplyCallHelper(double); CINN_SPECIALIZE_ApplyCallHelper(const std::vector&); CINN_SPECIALIZE_ApplyCallHelper(const std::vector&); template struct ApplyReturnHelper; template struct ApplyReturnHelper { static std::vector Apply(Args... args) { impl_fn(std::forward(args)...); return {}; } }; template struct ApplyReturnHelper { static std::vector Apply(Args... args) { auto ret = impl_fn(std::forward(args)...); return {ret}; } }; template struct ApplyReturnHelper> { static std::vector Apply(Args... args) { return impl_fn(std::forward(args)...); } }; // end: base template template struct ApplyCallHelper> { template static std::vector Apply(PackedStepContext* ctx, PreviousArgs... pargs) { static_assert(out_idx == 0, "Output is exported from return value"); return ApplyReturnHelper::Apply(std::forward(pargs)...); } }; }; #define APPLY_FUNC_UNIFORM(...) ::cinn::ir::ApplyFuncImpl::Apply #define FREE_FUNCTION_CONVERTER(...) ::cinn::ir::FreeFuncConverter::Apply #define CINN_BUILD_STEP_KIND(TypeName) \ static ::cinn::ir::StepKindInfo& __step_kind_registrar_##TypeName = \ ::cinn::ir::StepKindRegistry::Global()->__REGISTER_OR_GET__(#TypeName) // register StepKindInfo for every type of scheduling operation // clang-format off CINN_BUILD_STEP_KIND(GetAllBlocks) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER( static_cast (IRSchedule::*)() const>(&IRSchedule::GetAllBlocks)))); CINN_BUILD_STEP_KIND(GetChildBlocks) .Inputs({"expr"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER( static_cast (IRSchedule::*)(const Expr&) const>(&IRSchedule::GetChildBlocks)))); CINN_BUILD_STEP_KIND(GetLoops) .Inputs({"block"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER( static_cast (IRSchedule::*)(const Expr&) const>(&IRSchedule::GetLoops)))); CINN_BUILD_STEP_KIND(GetLoopsWithName) .Attrs({"block_name"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER( static_cast (IRSchedule::*)(const std::string&) const>(&IRSchedule::GetLoops)))); CINN_BUILD_STEP_KIND(GetBlock) .Attrs({"block_name"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER( static_cast(&IRSchedule::GetBlock)))); CINN_BUILD_STEP_KIND(Split) .Inputs({"loop", "factors"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER( static_cast (IRSchedule::*)(const Expr&, const std::vector&)>(&IRSchedule::Split)))); CINN_BUILD_STEP_KIND(Fuse) .Inputs({"loops"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER( static_cast&)>(&IRSchedule::Fuse)))); CINN_BUILD_STEP_KIND(FuseWithName) .Attrs({"block_name", "loops_index"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER( static_cast&)>(&IRSchedule::Fuse)))); CINN_BUILD_STEP_KIND(FuseWithBlock) .Inputs({"block"}) .Attrs({"loops_index"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER( static_cast&)>(&IRSchedule::Fuse)))); CINN_BUILD_STEP_KIND(ComputeAt) .Inputs({"block", "loop"}) .Attrs({"keep_unit_loops"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::ComputeAt))); CINN_BUILD_STEP_KIND(SimpleComputeAt) .Inputs({"block", "loop"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::SimpleComputeAt))); CINN_BUILD_STEP_KIND(ReverseComputeAt) .Inputs({"block", "loop"}) .Attrs({"keep_unit_loops"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::ReverseComputeAt))); CINN_BUILD_STEP_KIND(GetRootBlock) .Inputs({"expr"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::GetRootBlock))); CINN_BUILD_STEP_KIND(CacheRead) .Inputs({"block"}) .Attrs({"read_buffer_index", "memory_type"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::CacheRead))); CINN_BUILD_STEP_KIND(CacheWrite) .Inputs({"block"}) .Attrs({"write_buffer_index", "memory_type"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::CacheWrite))); CINN_BUILD_STEP_KIND(SyncThreads) .Inputs({"ir_node"}) .Attrs({"after_node"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::SyncThreads))); CINN_BUILD_STEP_KIND(SetBuffer) .Inputs({"block"}) .Attrs({"memory_type", "fixed"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::SetBuffer))); CINN_BUILD_STEP_KIND(Reorder) .Inputs({"loops"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER( static_cast&)>(&IRSchedule::Reorder)))); CINN_BUILD_STEP_KIND(ReorderWithBlock) .Inputs({"block"}) .Attrs({"loops_index"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER( static_cast&)>(&IRSchedule::Reorder)))); CINN_BUILD_STEP_KIND(ReorderWithName) .Attrs({"block_name", "loops_index"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER( static_cast&)>(&IRSchedule::Reorder)))); CINN_BUILD_STEP_KIND(Parallel) .Inputs({"loop"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::Parallel))); CINN_BUILD_STEP_KIND(Vectorize) .Inputs({"loop"}) .Attrs({"factor"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::Vectorize))); CINN_BUILD_STEP_KIND(Unroll) .Inputs({"loop"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::Unroll))); CINN_BUILD_STEP_KIND(ComputeInline) .Inputs({"schedule_block"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::ComputeInline))); CINN_BUILD_STEP_KIND(ReverseComputeInline) .Inputs({"schedule_block"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::ReverseComputeInline))); CINN_BUILD_STEP_KIND(Bind) .Inputs({"loop"}) .Attrs({"thread_axis"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::Bind))); CINN_BUILD_STEP_KIND(Rfactor) .Inputs({"rf_loop"}) .Attrs({"rf_axis"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::Rfactor))); CINN_BUILD_STEP_KIND(MergeExprs) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::MergeExprs))); template void Annotate(IRSchedule* ir_sch, const Expr&, const std::string&, AttrType); template <> void Annotate(IRSchedule* ir_sch, const Expr& block, const std::string& key, int value) { ir_sch->Annotate(block, key, value); } template <> void Annotate(IRSchedule* ir_sch, const Expr& block, const std::string& key, bool value) { ir_sch->Annotate(block, key, value); } template <> void Annotate(IRSchedule* ir_sch, const Expr& block, const std::string& key, float value) { ir_sch->Annotate(block, key, value); } void AnnotateStringAttr(IRSchedule* ir_sch, const Expr& block, const std::string& key, const std::string& value) { ir_sch->Annotate(block, key, value); } CINN_BUILD_STEP_KIND(AnnotateIntAttr) .Inputs({"block"}) .Attrs({"key", "value"}) .SetApplyFn(APPLY_FUNC_UNIFORM(Annotate)); CINN_BUILD_STEP_KIND(AnnotateBoolAttr) .Inputs({"block"}) .Attrs({"key", "value"}) .SetApplyFn(APPLY_FUNC_UNIFORM(Annotate)); CINN_BUILD_STEP_KIND(AnnotateFloatAttr) .Inputs({"block"}) .Attrs({"key", "value"}) .SetApplyFn(APPLY_FUNC_UNIFORM(Annotate)); CINN_BUILD_STEP_KIND(AnnotateStringAttr) .Inputs({"block"}) .Attrs({"key", "value"}) .SetApplyFn(APPLY_FUNC_UNIFORM(AnnotateStringAttr)); CINN_BUILD_STEP_KIND(Unannotate) .Inputs({"block"}) .Attrs({"key"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::Unannotate))); CINN_BUILD_STEP_KIND(FlattenLoops) .Inputs({"loops"}) .Attrs({"force_flat"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::FlattenLoops))); CINN_BUILD_STEP_KIND(SamplePerfectTile) .Inputs({"loop"}) .Attrs({"n", "max_innermost_factor", "decision"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::SamplePerfectTile))); CINN_BUILD_STEP_KIND(TagPostSchedule) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::TagPostSchedule))); CINN_BUILD_STEP_KIND(SampleCategorical) .Attrs({"candidates", "probs", "decision"}) .SetApplyFn(APPLY_FUNC_UNIFORM(FREE_FUNCTION_CONVERTER(&IRSchedule::SampleCategorical))); // clang-format on // ------ Following codes are about member function implement of the ScheduleDesc class void AttrVariantToProto(const utils::Attribute& attr, proto::ScheduleDesc_Attr* attr_proto) { #define SET_DESC_SINGLE_ITEM(index, built_type, proto_type, proto_field) \ case index: \ attr_proto->set_dtype(proto::ScheduleDesc_Attr_DataType_##proto_type); \ attr_proto->set_##proto_field(absl::get(attr)); \ break; #define SET_DESC_REPEATED_ITEM(index, built_type, proto_type, proto_field) \ case index: { \ attr_proto->set_dtype(proto::ScheduleDesc_Attr_DataType_##proto_type); \ const auto& values = absl::get(attr); \ attr_proto->mutable_##proto_field()->Reserve(values.size()); \ *attr_proto->mutable_##proto_field() = {values.begin(), values.end()}; \ break; \ } switch (attr.index()) { SET_DESC_SINGLE_ITEM(0, bool, BOOLEAN, b); SET_DESC_SINGLE_ITEM(1, float, FLOAT, f); SET_DESC_SINGLE_ITEM(2, int, INT, i); SET_DESC_SINGLE_ITEM(3, std::string, STRING, s); SET_DESC_REPEATED_ITEM(4, std::vector, BOOLEANS, bools); SET_DESC_REPEATED_ITEM(5, std::vector, INTS, ints); SET_DESC_REPEATED_ITEM(6, std::vector, FLOATS, floats); SET_DESC_REPEATED_ITEM(7, std::vector, STRINGS, strings); SET_DESC_SINGLE_ITEM(8, int64_t, LONG, l); SET_DESC_SINGLE_ITEM(9, double, DOUBLE, d); SET_DESC_REPEATED_ITEM(10, std::vector, LONGS, longs); SET_DESC_REPEATED_ITEM(11, std::vector, DOUBLES, doubles); default: LOG(FATAL) << "Invalid index:" << attr.index(); } #undef SET_DESC_SINGLE_ITEM #undef SET_DESC_REPEATED_ITEM } utils::Attribute AttrProtoToVariant(const proto::ScheduleDesc_Attr& attr) { utils::Attribute value; #define PARSE_DESC_SINGLE_ITEM(proto_type, proto_field, built_type) \ case proto::ScheduleDesc_Attr_DataType_##proto_type: \ value = built_type(attr.proto_field()); \ break; #define PARSE_DESC_REPEATED_ITEM(proto_type, proto_field, built_type) \ case proto::ScheduleDesc_Attr_DataType_##proto_type: \ value = built_type({attr.proto_field().begin(), attr.proto_field().end()}); \ break; switch (attr.dtype()) { PARSE_DESC_SINGLE_ITEM(BOOLEAN, b, bool); PARSE_DESC_SINGLE_ITEM(INT, i, int); PARSE_DESC_SINGLE_ITEM(FLOAT, f, float); PARSE_DESC_SINGLE_ITEM(STRING, s, std::string); PARSE_DESC_REPEATED_ITEM(BOOLEANS, bools, std::vector); PARSE_DESC_REPEATED_ITEM(INTS, ints, std::vector); PARSE_DESC_REPEATED_ITEM(FLOATS, floats, std::vector); PARSE_DESC_REPEATED_ITEM(STRINGS, strings, std::vector); PARSE_DESC_SINGLE_ITEM(LONG, l, int64_t); PARSE_DESC_SINGLE_ITEM(DOUBLE, d, double); PARSE_DESC_REPEATED_ITEM(LONGS, longs, std::vector); PARSE_DESC_REPEATED_ITEM(DOUBLES, doubles, std::vector); default: LOG(FATAL) << "Invalid type:" << attr.DebugString(); } #undef PARSE_DESC_SINGLE_ITEM #undef PARSE_DESC_REPEATED_ITEM return value; } // Expr hash functor, presents how to hash an Expr struct ExprHash { size_t operator()(const Expr& e) const { return std::hash()(e.ptr()); } }; // Expr equal functor, presents whether a Expr pair is equal struct ExprEqual { bool operator()(const Expr& lhs, const Expr& rhs) const { return lhs.get() == rhs.get(); } }; void ScheduleDesc::Append(Step&& step) { steps_.emplace_back(std::move(step)); } void ScheduleDesc::Pop() { if (!steps_.empty()) { steps_.pop_back(); } } void ScheduleDesc::Replay(IRSchedule* schedule, bool without_post_schedule) const { ReplayWithProto(this->ToProto(), schedule, without_post_schedule); } proto::ScheduleDesc ScheduleDesc::ToProto() const { // map each Expr to a formatted name (e1, e2, ...) absl::flat_hash_map expr2name; proto::ScheduleDesc desc_proto; for (auto&& step : steps_) { auto* step_proto = desc_proto.add_steps(); step_proto->set_type(step.type); // inputs of a step must refer to Exprs resulted by preceding steps for (auto&& param2exprs : step.inputs) { const std::string& param_name = param2exprs.first; auto* expr_desc = step_proto->add_inputs(); expr_desc->set_parameter(param_name); for (auto&& expr : param2exprs.second) { auto expr_it = expr2name.find(expr); CHECK(expr_it != expr2name.end()) << "Can't find expr of param_name: " << param_name; expr_desc->add_arguments(expr_it->second); } } // each output Expr is represented by a formatted name, to be refered by suceeding steps for (auto&& expr : step.outputs) { std::string local_name = "e" + std::to_string(expr2name.size()); expr2name.emplace(expr, local_name); step_proto->add_outputs(expr2name.at(expr)); } for (auto&& attr2value : step.attrs) { auto* attr_proto = step_proto->add_attrs(); const auto& attr_value = attr2value.second; VLOG(5) << "Attr.index:" << attr_value.index(); attr_proto->set_name(attr2value.first); AttrVariantToProto(attr_value, attr_proto); } } return desc_proto; } std::vector ScheduleDesc::ReplayWithProto(const proto::ScheduleDesc& desc_proto, IRSchedule* sch, bool without_post_schedule) { VLOG(4) << "proto::ScheduleDesc:\n" << desc_proto.DebugString(); if (desc_proto.steps().empty()) { LOG(WARNING) << "Input proto::ScheduleDesc is empty"; return {}; } // map a formatted name (e1, e2, ...) to an Expr absl::flat_hash_map name2expr; std::vector last_outputs; // resotre each scheduling step and apply to the new IRSchedule object for (auto&& step_proto : desc_proto.steps()) { VLOG(4) << "Replay step:\n" << step_proto.DebugString(); ScheduleDesc::Step step; step.type = step_proto.type(); CHECK(!step.type.empty()) << "Name of StepKind is empty"; if (without_post_schedule && step.type == "TagPostSchedule") { break; } const StepKindInfo* step_kind = StepKindRegistry::Global()->Find(step.type); CHECK(step_kind) << "Can't find StepKind:" << step.type; for (auto&& param2args : step_proto.inputs()) { for (auto&& arg : param2args.arguments()) { auto arg_it = name2expr.find(arg); CHECK(arg_it != name2expr.end()) << "Cant't find argument:" << arg; step.inputs[param2args.parameter()].emplace_back(arg_it->second); } } for (auto&& attr : step_proto.attrs()) { step.attrs[attr.name()] = AttrProtoToVariant(attr); } PackedStepContext context(step, step_kind, sch); step.outputs = step_kind->Apply(&context); CHECK_EQ(step_proto.outputs().size(), step.outputs.size()) << "Output size not matched"; for (size_t i = 0; i < step.outputs.size(); ++i) { name2expr[step_proto.outputs(i)] = step.outputs.at(i); } last_outputs = std::move(step.outputs); } return last_outputs; } ScheduleDesc ScheduleDesc::ForkAndUpdate(int step_idx, utils::Attribute decision, bool without_post_schedule) const { int n_valid_step = 0; if (!without_post_schedule) { n_valid_step = steps_.size(); } else { for (const auto& step : steps_) { if (step.type != "TagPostSchedule") { ++n_valid_step; } else { break; } } } std::vector new_steps(steps_.begin(), steps_.begin() + n_valid_step); new_steps[step_idx].attrs["decision"] = decision; return ScheduleDesc(std::move(new_steps)); } } // namespace ir } // namespace cinn