code_generator_helper.h 5.5 KB
Newer Older
1
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14

15 16
#pragma once

17
#include <sstream>
18 19
#include <string>
#include <unordered_map>
20
#include <unordered_set>
21 22
#include <vector>

23 24 25
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/type_defs.h"
#include "paddle/fluid/framework/var_desc.h"
26 27
#include "paddle/fluid/platform/enforce.h"

28 29 30
namespace paddle {
namespace framework {
namespace ir {
31 32
namespace fusion_group {

33 34 35
// Builds the identifier "arg<index>", e.g. "arg3" for index 3.
static inline std::string ArgName(int index) {
  std::string name("arg");
  name += std::to_string(index);
  return name;
}
36

37 38 39
// Builds the identifier "tmp<index>", e.g. "tmp7" for index 7.
static inline std::string TmpName(int index) {
  std::string name("tmp");
  name.append(std::to_string(index));
  return name;
}
40

41 42 43 44
// Builds the indexed access expression "arg<index>[idx]", e.g. "arg2[idx]"
// for index 2.
static inline std::string VarName(int index) {
  std::string expr("arg");
  expr += std::to_string(index);
  expr += "[idx]";
  return expr;
}

45 46
class OperationExpression {
 public:
47
  explicit OperationExpression(std::string op_type, std::vector<int> input_ids,
48 49 50 51 52 53 54
                               std::vector<int> output_ids,
                               std::string rhs_type, std::string lhs_type)
      : op_type_(op_type),
        input_ids_(input_ids),
        output_ids_(output_ids),
        rhs_type_(rhs_type),
        lhs_type_(lhs_type) {}
55

56 57 58
  std::string GetOpType() const { return op_type_; }
  std::vector<int> GetInputIds() const { return input_ids_; }
  std::vector<int> GetOutputIds() const { return output_ids_; }
59 60
  std::string GetRHSType() const { return rhs_type_; }
  std::string GetLHSType() const { return lhs_type_; }
61 62
  void SetAttr(AttributeMap attr) { attr_ = attr; }
  AttributeMap GetAttr() { return attr_; }
63
  // Check whether this operation type is supported in OperationMap.
64
  bool IsSupport() const;
65

66
  std::string GetExpression(std::unordered_set<int>* used) const;
67

68
 private:
69
  // TODO(wangchao): make offset more flexible we add stride and basic offset
70 71
  std::string GetRHS(std::unordered_set<int>* used,
                     size_t exprs_index = 0) const;
72
  std::string GetLHS(size_t i = 0) const;
73 74

 private:
75
  std::string op_type_;
76 77
  std::vector<int> input_ids_;
  std::vector<int> output_ids_;
78
  AttributeMap attr_;
79 80
  std::string rhs_type_;
  std::string lhs_type_;
81 82
};

83 84 85 86 87
// A set of named substitutions: maps an identifier to the expression string
// that should replace it.
class TemplateVariable {
 public:
  // Stores `expression` under `identifier`, overwriting any expression
  // previously stored under the same identifier.
  void Add(const std::string& identifier, const std::string& expression) {
    strings_[identifier] = expression;
  }

  // Removes the entry stored under `identifier`; does nothing if there is
  // none. `expression` is ignored and only kept so that existing two-argument
  // callers continue to compile.
  void Remove(const std::string& identifier, const std::string& expression) {
    (void)expression;
    strings_.erase(identifier);
  }

  // Returns a copy of all identifier -> expression pairs.
  std::unordered_map<std::string, std::string> Get() const { return strings_; }

 private:
  std::unordered_map<std::string, std::string> strings_;
};
104

105 106 107 108 109 110 111 112 113
class CodeTemplate {
 public:
  CodeTemplate() = default;
  explicit CodeTemplate(std::string template_str) {
    template_str_ = template_str;
  }

  std::string Format(TemplateVariable template_var) {
    std::string ret = template_str_;
114
    std::unordered_map<std::string, bool> found;
115

116
    // Word begins with "$" in template_str will be replaced.
117 118 119 120 121
    for (size_t i = 0; i < ret.size(); i++) {
      auto pos = i;
      char c = ret[pos];

      if (c == '$') {
122 123 124 125 126 127
        for (auto iter : template_var.Get()) {
          std::string keyword = iter.first;
          if (ret.substr(pos + 1, keyword.size()) == keyword) {
            found[keyword] = true;
            ret.replace(pos, keyword.size() + 1, iter.second);
            break;
128 129 130 131 132
          }
        }
      }
    }

133 134
    for (auto iter : template_var.Get()) {
      PADDLE_ENFORCE_NE(found.find(iter.first), found.end(),
135 136
                        platform::errors::PreconditionNotMet(
                            "Keyword %s in template is not set.", iter.first));
137 138
    }

139 140
    return EmitIndents(ret);
  }
141

142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
  std::string EmitIndents(std::string str) {
    std::string ret = str;
    int space_num = 0;
    auto space_char = ' ';
    for (size_t i = 0; i < ret.size(); i++) {
      auto pos = i;
      char c = ret[pos];
      if (c == '\n') {
        size_t next_pos = pos + 1;
        while (next_pos < ret.size() && ret[next_pos] == space_char) {
          next_pos++;
        }
        space_num = next_pos - pos - 1;
      }
      if (c == ';' && (pos + 1 < ret.size()) && ret[pos + 1] != '\n') {
        auto insert_pos = pos + 1;
        std::string insert_str = "\n" + std::string(space_num, space_char);
        ret.insert(insert_pos, insert_str);
        space_num = 0;
      }
    }

    return ret;
  }

 private:
  std::string template_str_;
};

171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192
// Returns a one-line description of `expr` of the form
//   Op(<op_type>), inputs:{<id>,<id>,...}, outputs:{<id>,<id>,...}
static std::string DebugString(const OperationExpression& expr) {
  std::stringstream ret;
  ret << "Op(" << expr.GetOpType() << "), inputs:{";
  // The getters return by value; fetch each vector once instead of copying it
  // again for every element printed.
  const std::vector<int> input_ids = expr.GetInputIds();
  for (size_t i = 0; i < input_ids.size(); ++i) {
    if (i != 0) {
      ret << ",";
    }
    ret << input_ids[i];
  }
  ret << "}, outputs:{";
  const std::vector<int> output_ids = expr.GetOutputIds();
  for (size_t i = 0; i < output_ids.size(); ++i) {
    if (i != 0) {
      ret << ",";
    }
    ret << output_ids[i];
  }
  ret << "}";
  return ret.str();
}

193
}  // namespace fusion_group
194 195 196
}  // namespace ir
}  // namespace framework
}  // namespace paddle