// schedule_desc.h
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <absl/container/flat_hash_map.h>

#include <map>
#include <string>
#include <utility>
#include <vector>

#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/schedule_desc.pb.h"
#include "paddle/cinn/utils/registry.h"
#include "paddle/cinn/utils/type_defs.h"

namespace cinn {
namespace ir {

// A ScheduleDesc describes the scheduling process of an ir::ModuleExpr: it
// records all transform/getter operations executed by a corresponding
// ir::IRSchedule. A ScheduleDesc can be serialized to JSON format and saved to
// a file. When deserialized, it can be re-applied to a new IRSchedule that is
// initialized with a semantically equivalent original ir::ModuleExpr, and
// then produces the same result.

class IRSchedule;  // forward declartion to avoid cross-reference
class ScheduleDesc {
 public:
  // each operation executed through IRSchedule is recorded as a step
  struct Step {
    std::string type;  // step name
    absl::flat_hash_map<std::string, std::vector<Expr>> inputs;
    utils::AttributeMap attrs;
    std::vector<Expr> outputs;
    Step() = default;
    Step(std::string type_i,
         absl::flat_hash_map<std::string, std::vector<Expr>> inputs_i,
         utils::AttributeMap attrs_i,
         std::vector<Expr> outputs_i)
        : type(type_i), inputs(inputs_i), attrs(attrs_i), outputs(outputs_i) {}
  };

  /**
55 56
   * \brief Re-applied a scheduling process represented as a proto::ScheduleDesc
   * to a new IRSchedule object.
57 58
   * @param desc_proto The proto of the ScheduleDesc to be re-applied.
   * @param sch The original IRSchedule to be replayed the description on.
59 60
   * @param without_post_schedule Determine whether to delete the post
   * schedules.
61
   */
62 63 64 65
  static std::vector<Expr> ReplayWithProto(
      const proto::ScheduleDesc& desc_proto,
      IRSchedule* sch,
      bool without_post_schedule = false);
66 67 68 69 70 71 72 73 74 75 76 77 78 79

  ScheduleDesc() = default;

  ScheduleDesc(const std::vector<Step>& steps) : steps_(steps) {}

  ScheduleDesc(std::vector<Step>&& steps) : steps_(steps) {}

  // Append a new step
  void Append(Step&& step);

  // Pop the last step
  void Pop();

  /**
80 81
   * \brief Replay this description to a new IRSchedule that is initialzied by a
   * semantics-euqal original ModuleExpr.
82
   * @param schedule The original IRSchedule to be replayed the description on.
83 84
   * @param without_post_schedule Determine whether to delete the post
   * schedules.
85 86 87 88 89 90 91 92 93 94 95 96 97 98
   */
  void Replay(IRSchedule* schedule, bool without_post_schedule = false) const;

  // convert to a proto::ScheduleDesc object
  proto::ScheduleDesc ToProto() const;

  // return detail string of a ScheduleDesc for debug;
  std::string DebugString() const { return ToProto().DebugString(); }

  std::vector<Step> Steps() const { return steps_; }

  bool Empty() const { return steps_.empty(); }

  /**
99 100
   * \brief Fork this ScheduleDesc and update a step of the new ScheduleDesc
   * with a new decision.
101 102
   * @param step_idx The index of the step to be update.
   * @param decision The new decision.
103 104
   * @param without_post_schedule Determine whether to delete the post
   * schedules.
105 106
   * @return The new ScheduleDesc.
   */
107 108 109
  ScheduleDesc ForkAndUpdate(int step_idx,
                             utils::Attribute decision,
                             bool without_post_schedule) const;
110 111 112 113 114 115 116

 private:
  std::vector<Step> steps_;  // all operations are recorded in order.
};

}  // namespace ir
}  // namespace cinn