/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
14 15

#include "paddle/fluid/framework/ir/fusion_group/code_generator.h"
16
#include <sstream>
17
#include <unordered_set>
18
#include "paddle/fluid/framework/ir/fusion_group/code_generator_helper.h"
19
#include "paddle/fluid/framework/ir/fusion_group/cuda_resources.h"
20
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
21

22 23 24
namespace paddle {
namespace framework {
namespace ir {
25
namespace fusion_group {
26

27
// Maps the data type of the first variable node in `nodes` to the type name
// used in the generated code: "float" for FP32, "double" for FP64 and
// "__half" for FP16. An empty string is returned when `nodes` holds no
// variable node, or when that node has any other data type.
std::string ExtractDataType(const std::vector<Node*>& nodes) {
  for (const auto* node : nodes) {
    const bool is_var_node = node && node->IsVar() && node->Var();
    if (!is_var_node) {
      continue;
    }
    // Only the first variable node is inspected: the data type of all
    // inputs/outputs must be the same, which is checked when detecting the
    // subgraph.
    const auto dtype = node->Var()->GetDataType();
    if (dtype == proto::VarType::FP32) {
      return "float";
    }
    if (dtype == proto::VarType::FP64) {
      return "double";
    }
    if (dtype == proto::VarType::FP16) {
      return "__half";
    }
    return "";
  }
  return "";
}

48 49 50 51
// Sets up the list of code templates. Only elementwise operations are
// supported now, so a single slot (index 0) is filled with the template built
// from cuda_kernel_template_1d.
CodeGenerator::CodeGenerator() {
  code_templates_.resize(1);
  code_templates_[0] = CodeTemplate(cuda_kernel_template_1d);
}

56 57
// Generates the source code for `subgraph`: its operator nodes are converted
// to expressions first, and the expressions are then emitted under the
// function name reported by the subgraph.
std::string CodeGenerator::Generate(SubGraph* subgraph) {
  const std::vector<OperationExpression> converted =
      ConvertToExpressions(subgraph);
  const std::string kernel_name = subgraph->GetFuncName();
  return Generate(kernel_name, converted);
}

61 62 63 64 65 66 67 68 69 70
static bool HasInput(Node* n, std::string name) {
  PADDLE_ENFORCE_EQ(n && n->IsOp() && n->Op(), true,
                    platform::errors::InvalidArgument(
                        "Expected node %p to be an operator node.", n));
  std::vector<std::string> input_names = n->Op()->InputNames();
  std::unordered_set<std::string> input_names_set(input_names.begin(),
                                                  input_names.end());
  return input_names_set.find(name) != input_names_set.end();
}

71 72 73
// Converts every operator node of `subgraph` (visited in SortedNodes() order)
// into an OperationExpression.
//
// For each operator:
//  - input ids follow the order of the input names registered in
//    OperationMap (e.g. X, Y in forward operations; X, Y, Out, out@GRAD in
//    backward operations). An input that the operator does not declare, or
//    that has no argument, is represented by the placeholder id -1.
//  - output ids follow the order of the registered output names (e.g. dx, dy
//    in backward operations); only the first argument of each output is used.
//  - every output id is flagged as intermediate or not, depending on whether
//    its variable is one of the subgraph's intermediate output nodes.
//
// Raises InvalidArgument when an input/output argument has no encoded id, or
// when a registered output of an operator has no argument at all.
std::vector<OperationExpression> CodeGenerator::ConvertToExpressions(
    SubGraph* subgraph) {
  std::unordered_map<std::string, int> var_ids = EncodeVarNodes(subgraph);

  // Collect the names of the intermediate outputs once, so that each output
  // below is classified with a hash lookup instead of a scan over all nodes.
  std::vector<Node*> intermediate_out_nodes =
      subgraph->GetIntermediateOutVarNodes();
  std::unordered_set<std::string> intermediate_out_names;
  for (auto* n : intermediate_out_nodes) {
    intermediate_out_names.insert(n->Name());
  }

  std::vector<OperationExpression> expressions;
  for (auto* node : subgraph->SortedNodes()) {
    if (!(node && node->IsOp() && node->Op())) {
      continue;
    }
    auto* op = node->Op();
    AttributeMap attr = *(op->MutableAttrMap());

    // Look the operation up once; both its input and output names are needed.
    const auto& operation = OperationMap::Instance().Get(op->Type());

    std::vector<int> input_ids;
    for (const auto& name : operation.input_names) {
      // Some input vars are not used in grad ops, such as
      // "elementwise_add_grad", where "X", "Y" and "Out" are not used.
      const bool is_set = HasInput(node, name) && !op->Input(name).empty();
      if (!is_set) {
        input_ids.push_back(-1);
        continue;
      }
      for (const auto& arg_name : op->Input(name)) {
        auto id_iter = var_ids.find(arg_name);
        PADDLE_ENFORCE_NE(
            id_iter, var_ids.end(),
            platform::errors::InvalidArgument(
                "Input(%s) of operation %s is not set.", name, op->Type()));
        input_ids.push_back(id_iter->second);
      }
    }

    std::vector<int> output_ids;
    std::unordered_map<int, bool> intermediate_state;
    for (const auto& name : operation.output_names) {
      const auto& out_args = op->Output(name);
      // Indexing an empty argument list would be undefined behaviour, so
      // report it as an unset output instead.
      PADDLE_ENFORCE_EQ(
          out_args.empty(), false,
          platform::errors::InvalidArgument(
              "Output(%s) of operation %s is not set.", name, op->Type()));
      auto id_iter = var_ids.find(out_args[0]);
      PADDLE_ENFORCE_NE(
          id_iter, var_ids.end(),
          platform::errors::InvalidArgument(
              "Output(%s) of operation %s is not set.", name, op->Type()));
      const int out_id = id_iter->second;
      output_ids.push_back(out_id);
      intermediate_state[out_id] =
          intermediate_out_names.count(out_args[0]) > 0;
    }

    std::string lhs_type = ExtractDataType(node->outputs);
    std::string rhs_type = ExtractDataType(node->inputs);
    auto expression =
        OperationExpression(node->Name(), input_ids, output_ids, rhs_type,
                            lhs_type, intermediate_state);
    expression.SetAttr(attr);
    expressions.push_back(expression);
  }
  return expressions;
}

141 142
// In order to get the right result of expression, we need to calculate and
// store the expression as suffix Expressions using vector.
143
std::string CodeGenerator::Generate(
144
    std::string func_name,
145
    const std::vector<OperationExpression>& expressions) {
146
  // TODO(liuyiqun): Check whether all expressions are elementwise operations.
147 148
  std::set<int> input_ids = std::move(DistilInputIds(expressions));
  std::set<int> output_ids = std::move(DistilOutputIds(expressions));
149 150
  std::set<int> intermediate_ids =
      std::move(DistilIntermediateIds(expressions));
151 152
  std::unordered_map<int, std::string> dtypes =
      std::move(DistilDtypes(expressions));
153 154
  TemplateVariable template_var;
  template_var.Add("func_name", func_name);
155 156
  template_var.Add("parameters", EmitParameters(input_ids, output_ids,
                                                intermediate_ids, dtypes));
157
  template_var.Add("compute_body",
158 159
                   EmitComputeBody(expressions, input_ids, output_ids,
                                   intermediate_ids, dtypes));
160 161 162 163 164 165 166

  std::set<std::string> all_dtype;
  for (const auto& type : dtypes) {
    all_dtype.insert(type.second);
  }
  std::string predefined_cuda_functions = "";
  if (all_dtype.find("float") != all_dtype.end() &&
167
      all_dtype.find("__half") == all_dtype.end()) {
168 169 170 171 172
    predefined_cuda_functions += predefined_cuda_functions_fp32;
  }
  if (all_dtype.find("double") != all_dtype.end()) {
    predefined_cuda_functions += predefined_cuda_functions_fp64;
  }
173
  if (all_dtype.find("__half") != all_dtype.end()) {
174
    predefined_cuda_functions += predefined_cuda_functions_fp16;
175
  }
176 177 178
  return predefined_cuda_functions + code_templates_[0].Format(template_var);
}

179 180
std::set<int> CodeGenerator::DistilInputIds(
    const std::vector<OperationExpression>& expressions) {
181
  std::set<int> input_ids;
182
  // Use std::set to remove the reptead id and get a ordered list.
183 184
  for (size_t i = 0; i < expressions.size(); i++) {
    for (auto id : expressions[i].GetInputIds()) {
185 186 187
      if (id >= 0) {
        input_ids.insert(id);
      }
188
    }
189 190 191 192 193 194 195 196 197
  }
  return input_ids;
}

std::set<int> CodeGenerator::DistilOutputIds(
    const std::vector<OperationExpression>& expressions) {
  std::set<int> output_ids;
  // Use std::set to remove the reptead id and get a ordered list.
  for (size_t i = 0; i < expressions.size(); i++) {
198 199 200 201
    for (auto id : expressions[i].GetOutputIds()) {
      output_ids.insert(id);
    }
  }
202 203 204
  return output_ids;
}

205 206 207 208 209 210 211 212 213 214 215 216 217
std::set<int> CodeGenerator::DistilIntermediateIds(
    const std::vector<OperationExpression>& expressions) {
  std::set<int> intermediate_ids;
  // Use std::set to remove the reptead id and get a ordered list.
  for (size_t i = 0; i < expressions.size(); i++) {
    for (auto id : expressions[i].GetOutputIds()) {
      auto intermediate_state = expressions[i].GetIntermediateState();
      if (intermediate_state[id]) intermediate_ids.insert(id);
    }
  }
  return intermediate_ids;
}

218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247
// Builds the map from variable id to data type name.
//
// Every input id of an expression takes the expression's RHS type and every
// output id takes its LHS type. Negative input ids are skipped: they are
// placeholders for inputs that are not set (see ConvertToExpressions and
// DistilInputIds), not variables, so they must neither get a data type nor
// take part in the consistency check.
//
// Raises PreconditionNotMet when one id is seen with two different data
// types.
std::unordered_map<int, std::string> CodeGenerator::DistilDtypes(
    const std::vector<OperationExpression>& expressions) {
  std::unordered_map<int, std::string> dtypes;
  for (const auto& expression : expressions) {
    const auto rhs_type = expression.GetRHSType();
    for (auto id : expression.GetInputIds()) {
      if (id < 0) {
        continue;
      }
      auto dtype_iter = dtypes.find(id);
      if (dtype_iter == dtypes.end()) {
        dtypes[id] = rhs_type;
      } else {
        PADDLE_ENFORCE_EQ(
            dtype_iter->second, rhs_type,
            platform::errors::PreconditionNotMet(
                "In fusion group, Same Node id must have same date type"));
      }
    }

    const auto lhs_type = expression.GetLHSType();
    for (auto id : expression.GetOutputIds()) {
      auto dtype_iter = dtypes.find(id);
      if (dtype_iter == dtypes.end()) {
        dtypes[id] = lhs_type;
      } else {
        PADDLE_ENFORCE_EQ(
            dtype_iter->second, lhs_type,
            platform::errors::PreconditionNotMet(
                "In fusion group, Same Node id must have same date type"));
      }
    }
  }
  return dtypes;
}

248
// we get the parameter list code for the expression information
249 250
std::string CodeGenerator::EmitParameters(
    const std::set<int>& input_ids, const std::set<int>& output_ids,
251
    const std::set<int>& intermediate_ids,
252
    const std::unordered_map<int, std::string>& dtypes) const {
253 254
  std::stringstream ret;
  ret << "int N, ";
255 256 257

  // If a id is in the input and output list at the same time, then remove it
  // from the input list.
258 259
  for (auto id : input_ids) {
    if (output_ids.find(id) == output_ids.end()) {
260 261
      ret << "const " << dtypes.at(id) << "* __restrict__ " << ArgName(id)
          << ", ";
262 263 264
    }
  }

265
  size_t index = 0;
266
  std::vector<std::string> output_args;
267
  for (auto id : output_ids) {
268
    if (intermediate_ids.find(id) == intermediate_ids.end()) {
269 270 271 272 273 274 275 276
      std::string args_str = dtypes.at(id) + "* " + ArgName(id);
      output_args.push_back(args_str);
    }
  }
  for (auto args : output_args) {
    ret << args;
    if (index != output_args.size() - 1) {
      ret << ", ";
277
    }
278
    index++;
279 280
  }
  return ret.str();
281
}
282 283

std::string CodeGenerator::EmitComputeBody(
284 285
    const std::vector<OperationExpression>& expressions,
    const std::set<int>& input_ids, const std::set<int>& output_ids,
286
    const std::set<int>& intermediate_ids,
287
    const std::unordered_map<int, std::string>& dtypes) const {
288 289
  std::ostringstream compute;
  std::unordered_set<int> used;
290
  for (size_t i = 0; i < expressions.size(); i++) {
291
    VLOG(3) << DebugString(expressions[i]);
292
    compute << expressions[i].GetExpression(&used);
293
  }
294 295 296 297 298 299

  // Load input to temporal variables.
  std::ostringstream load;
  for (auto id : input_ids) {
    if (output_ids.find(id) == output_ids.end() &&
        used.find(id) != used.end()) {
300 301
      load << dtypes.at(id) << " " << TmpName(id) << " = "
           << "__ldg(&" << VarName(id) << ")"
302
           << ";";
303 304 305 306 307
    }
  }
  // Store temporal variables to memory.
  std::ostringstream store;
  for (auto id : output_ids) {
308 309 310
    if (intermediate_ids.find(id) == intermediate_ids.end()) {
      store << VarName(id) << " = " << TmpName(id) << ";";
    }
311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329
  }

  return load.str() + compute.str() + store.str();
}

// Assigns a consecutive integer id, starting from 0, to every distinct
// variable name of `subgraph`: the input variables are numbered first, then
// the output variables. A name that already has an id keeps it.
std::unordered_map<std::string, int> CodeGenerator::EncodeVarNodes(
    SubGraph* subgraph) {
  const auto& inputs = subgraph->GetInputVarNodes();
  const auto& outputs = subgraph->GetOutputVarNodes();

  std::unordered_map<std::string, int> var_ids;
  int next_id = 0;

  // Numbering input vars.
  for (auto* in : inputs) {
    VLOG(3) << "Encoding input names:" << in->Name() << ", id:" << next_id;
    // emplace() only inserts unseen names; the counter advances just then.
    if (var_ids.emplace(in->Name(), next_id).second) {
      ++next_id;
    }
  }

  // Numbering output vars.
  for (auto* out : outputs) {
    VLOG(3) << "Ecoding output names:" << out->Name() << ", id:" << next_id;
    if (var_ids.emplace(out->Name(), next_id).second) {
      ++next_id;
    }
  }
  return var_ids;
}

}  // namespace fusion_group
342 343 344
}  // namespace ir
}  // namespace framework
}  // namespace paddle