// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/ir/utils/ir_copy.h"

#include <map>
#include <memory>
#include <string>
#include <vector>

#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"

namespace cinn {
namespace optim {
using namespace ir;  // NOLINT

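// Deep-copies an IR expression tree node by node. Copied tensors and buffers
// are cached in tensor_map/buffer_map, so every reference to the same named
// tensor or buffer within one copy points at a single new node.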
struct IRCopyVisitor : public ir::IRVisitorRequireReImpl<Expr> {
  // Use maps to unify all the copied tensors and buffers.
  std::map<std::string, ir::_Tensor_*> tensor_map;
  std::map<std::string, ir::_Buffer_*> buffer_map;

  Expr Visit(const Expr* op) override {
    return IRVisitorRequireReImpl::Visit(op);
  }

 protected:
  // The visit methods for IR nodes follow the order defined in node.h

  Expr Visit(const ir::IntImm* op) override {
    return Expr(make_shared<IntImm>(op->type(), op->value));
  }
  Expr Visit(const ir::UIntImm* op) override {
    return Expr(make_shared<UIntImm>(op->type(), op->value));
  }
  Expr Visit(const ir::FloatImm* op) override {
    return Expr(make_shared<FloatImm>(op->type(), op->value));
  }
  Expr Visit(const ir::StringImm* op) override {
    return Expr(common::make_shared<StringImm>(op->value));
  }

  Expr Visit(const ir::Cast* op) override {
    auto v = Visit(&op->v());
    return Cast::Make(op->type(), v);
  }

  Expr Visit(const Select* op) override {
    auto condition = Visit(&op->condition);
    auto true_value = Visit(&op->true_value);
    auto false_value = Visit(&op->false_value);
    return Select::Make(condition, true_value, false_value);
  }

  Expr Visit(const IfThenElse* op) override {
    auto condition = Visit(&op->condition);
    auto true_case = Visit(&op->true_case);
    Expr false_case;
    if (op->false_case.defined()) false_case = Visit(&op->false_case);
    return IfThenElse::Make(condition, true_case, false_case);
  }

  Expr Visit(const Block* op) override {
    std::vector<Expr> stmts;
    for (auto& s : op->stmts) {
      stmts.push_back(Visit(&s));
    }
    return Block::Make(stmts);
  }

  Expr Visit(const Call* op) override {
    auto read_args = Visit(op->read_args);
    auto write_args = Visit(op->write_args);
    return Call::Make(op->type(),
                      op->name,
                      read_args,
                      write_args,
                      op->call_type,
                      FunctionRef(),
                      0,
                      op->attrs);
  }

  Expr Visit(const _Var_* op) override {
    auto* n = make_shared<_Var_>();

    n->name = op->name;
    n->is_reduce_axis = op->is_reduce_axis;
    n->set_type(op->type());

    if (op->lower_bound.defined()) {
      n->lower_bound = Visit(&op->lower_bound);
    }
    if (op->upper_bound.defined()) {
      n->upper_bound = Visit(&op->upper_bound);
    }

    return Expr(n);
  }

  Expr Visit(const Load* op) override {
    auto tensor = Visit(&op->tensor);
    std::vector<Expr> indices;
    for (auto& idx : op->indices) {
      indices.push_back(Visit(&idx));
    }
    return Load::Make(tensor, indices);
  }

  Expr Visit(const Store* op) override {
    auto tensor = Visit(&op->tensor);
    auto value = Visit(&op->value);
    std::vector<Expr> indices;
    for (auto& idx : op->indices) indices.push_back(Visit(&idx));

    return Store::Make(tensor, value, indices);
  }

  Expr Visit(const Alloc* op) override {
    auto extents = Visit(op->extents);
    Expr condition;
    Expr body;
    if (op->condition.defined()) condition = Visit(&op->condition);
    if (op->body.defined()) body = Visit(&op->body);

    return Alloc::Make(op->destination, op->type(), extents, condition, body);
  }

  Expr Visit(const Free* op) override { return Free::Make(op->destination); }

  Expr Visit(const _Buffer_* op) override {
    if (buffer_map.count(op->name)) {
      return buffer_map[op->name];
    }

    auto shape = Visit(op->shape);
    auto strides = Visit(op->strides);
    auto name = op->name;
    auto scope = op->scope;
    int data_alignment = op->data_alignment;
    auto elem_offset = Visit(&op->elem_offset);
    int offset_factor = op->offset_factor;
    Target target = op->target;

    auto new_node = _Buffer_::Make(name, shape);
    new_node->strides = strides;
    new_node->dtype = op->dtype;  // copy data element's type.
    new_node->name = name;
    new_node->scope = scope;
    new_node->data_alignment = data_alignment;
    new_node->elem_offset = elem_offset;
    new_node->offset_factor = offset_factor;
    new_node->target = target;
    new_node->memory_type = op->memory_type;
    new_node->set_type(op->type());
    op->CopyMeta(new_node.As<ir::_Buffer_>());

    buffer_map[op->name] = new_node->self();

    return Expr(ir::Buffer(new_node));
  }

  Expr Visit(const _Tensor_* op) override {
    if (tensor_map.count(op->name)) {
      return tensor_map[op->name];
    }

    auto shape = Visit(op->shape);
    auto domain = Visit(op->domain);
    auto buffer_expr = Expr(op->buffer);
    // TODO(Superjomn) copy the operation.
    auto operation = op->operation;
    auto name = op->name;
    auto tensor = make_shared<_Tensor_>();

    if (buffer_expr.defined()) {
      auto buffer = Visit(&buffer_expr);
      tensor->buffer = buffer.as_buffer_ref();
    }
    tensor->domain = domain;
    tensor->shape = shape;
    tensor->reduce_axis = op->reduce_axis;
    tensor->operation = operation;
    tensor->name = name;
    tensor->set_type(op->type());
    tensor->axis_ = op->axis_;

    tensor_map[tensor->name] = tensor;

    return tensor;
  }

  Expr Visit(const For* op) override {
    auto extent = Visit(&op->extent);
    auto min = Visit(&op->min);
    auto body = Visit(&op->body);

    return ir::For::Make(op->loop_var,
                         min,
                         extent,
                         op->for_type(),
                         op->device_api,
                         body,
                         op->vectorize_info(),
                         op->bind_info());
  }

  Expr Visit(const ir::PolyFor* op) override {
    auto init = Visit(&op->init);
    auto condition = Visit(&op->condition);
    auto inc = Visit(&op->inc);
    auto body = Visit(&op->body);
    auto expr = PolyFor::Make(op->iterator,
                              init,
                              condition,
                              inc,
                              op->for_type(),
                              op->device_api,
                              body,
                              op->vectorize_info(),
                              op->bind_info());
    return expr;
  }

  Expr Visit(const ir::_Module_* op) override {
    std::vector<Expr> buffers;
    std::vector<Expr> functions;
    std::vector<Expr> submodules;

    for (auto& expr : op->buffers) {
      buffers.push_back(Visit(&expr));
    }

    for (auto& expr : op->functions) {
      functions.push_back(Visit(&expr));
    }

    for (auto& expr : op->submodules) {
      submodules.push_back(Visit(&expr));
    }

    auto res = ir::_Module_::Make(op->name, op->target);
    res->buffers = buffers;
    res->functions = functions;
    res->submodules = submodules;

    return Expr(res);
  }

  Expr Visit(const _LoweredFunc_* op) override {
    auto func = make_shared<_LoweredFunc_>();

    func->name = op->name;
    func->args = op->args;
    func->body = Visit(&op->body);
    func->temp_bufs = op->temp_bufs;

    func->device_api = op->device_api;

    func->cuda_axis_info = op->cuda_axis_info;

    std::vector<Expr> alloc_output_buffer_exprs;
    std::vector<Expr> dealloc_output_buffer_exprs;
    std::vector<Expr> buffer_data_cast_exprs;
    std::vector<Expr> argument_prepare_exprs;

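// Deep-copy every expression in the named list field of the original function
// and move the result into the corresponding field of the new node.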
#define COPY_ADD_FIELD(field__)      \
  for (auto& expr : op->field__) {   \
    field__.push_back(Visit(&expr)); \
  }                                  \
  func->field__ = std::move(field__);

    COPY_ADD_FIELD(alloc_output_buffer_exprs);
    COPY_ADD_FIELD(dealloc_output_buffer_exprs);
    COPY_ADD_FIELD(buffer_data_cast_exprs);
    COPY_ADD_FIELD(argument_prepare_exprs);

    return func;
  }

  Expr Visit(const Let* op) override {
    auto value = Visit(&op->symbol);
    auto body = Visit(&op->body);

    return Let::Make(value, body);
  }

  Expr Visit(const Reduce* op) override {
    auto init = Visit(&op->init);
    auto body = Visit(&op->body);
    std::vector<Var> reduce_axis(op->reduce_axis.begin(),
                                 op->reduce_axis.end());
    return Reduce::Make(op->reduce_type, init, body, reduce_axis);
  }

  Expr Visit(const Ramp* op) override {
    auto base = Visit(&op->base);
    auto stride = Visit(&op->stride);
    int lanes = op->lanes;
    return Ramp::Make(base, stride, lanes);
  }

  Expr Visit(const Broadcast* op) override {
    auto value = Visit(&op->value);
    int lanes = op->lanes;
    CHECK(value.defined());
    CHECK(value.type().valid());

    auto* n = make_shared<Broadcast>();
    n->value = value;
    n->lanes = lanes;
    return Expr(n);
  }

  Expr Visit(const FracOp* op) override {
    auto a = Visit(&op->a());
    auto b = Visit(&op->b());
    CHECK(a.defined());
    CHECK(b.defined());

    auto* n = make_shared<FracOp>();
    n->a() = a;
    n->b() = b;
    return Expr(n);
  }

  Expr Visit(const Product* op) override {
    std::vector<Expr> operands;
    for (auto& v : op->operands()) {
      operands.push_back(Visit(&v));
    }
    return Product::Make(operands);
  }

  Expr Visit(const Sum* op) override {
    std::vector<Expr> operands;
    for (auto& v : op->operands()) {
      operands.push_back(Visit(&v));
    }
    return Sum::Make(operands);
  }

  Expr Visit(const ir::PrimitiveNode* op) override {
    std::vector<std::vector<Expr>> arguments;
    for (auto& args : op->arguments) {
      arguments.push_back(Visit(args));
    }

    auto n = common::make_shared<ir::PrimitiveNode>();
    n->name = op->name;
    n->attrs = op->attrs;  // attrs are PODs
    n->arguments = arguments;
    return Expr(n);
  }

  Expr Visit(const ir::_BufferRange_* op) {
    std::vector<Var> ranges;
    for (auto& range_var : op->ranges) {
      auto* var = range_var.As<_Var_>();
      ranges.push_back(Visit(var));
    }
    return ir::_BufferRange_::Make(Visit(&op->buffer), ranges);
  }

  Expr Visit(const ir::ScheduleBlock* op) {
    std::vector<Var> iter_vars;
    for (auto iter_var : op->iter_vars) {
      auto* var = iter_var.As<_Var_>();
      CHECK(var);
      iter_vars.push_back(Visit(var));
    }
    std::vector<Expr> read_buffers;
    for (auto buffer_range : op->read_buffers) {
      read_buffers.push_back(Visit(&buffer_range));
    }
    std::vector<Expr> write_buffers;
    for (auto buffer_range : op->write_buffers) {
      write_buffers.push_back(Visit(&buffer_range));
    }
    Expr res = ir::ScheduleBlock::Make(
        iter_vars, read_buffers, write_buffers, op->name, Visit(&op->body));
    res.As<ScheduleBlock>()->attrs = op->attrs;
    return res;
  }

  Expr Visit(const ir::ScheduleBlockRealize* op) {
    std::vector<Expr> iter_values;
    for (auto iter_value : op->iter_values) {
      iter_values.push_back(Visit(&iter_value));
    }
    return ir::ScheduleBlockRealize::Make(iter_values,
                                          Visit(&op->schedule_block));
  }

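// Declare a Visit overload for every intrinsic kind; the definitions follow
// after the struct, below.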
#define __(x__) Expr Visit(const ir::intrinsics::x__* op);
  INTRINSIC_KIND_FOR_EACH(__)
#undef __

  Expr Visit(const ir::IntrinsicOp* op) override {
    switch (op->getKind()) {
#define __(x__)                   \
  case ir::IntrinsicKind::k##x__: \
    return Visit(llvm::dyn_cast<ir::intrinsics::x__>(op));
      INTRINSIC_KIND_FOR_EACH(__)
#undef __
    }
  }

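// Binary and unary expression nodes are copied by recursively copying their
// operands and rebuilding the node with the corresponding Make().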
#define OP_BINARY_HANDLE(op__)                        \
  Expr Visit(const ir::op__* op) override {           \
    auto a = IRVisitorRequireReImpl::Visit(&op->a()); \
    auto b = IRVisitorRequireReImpl::Visit(&op->b()); \
    return op__::Make(a, b);                          \
  }
  NODETY_BINARY_OP_FOR_EACH(OP_BINARY_HANDLE)
#undef OP_BINARY_HANDLE

#define OP_UNARY_HANDLE(op__)                         \
  Expr Visit(const op__* op) override {               \
    auto v = IRVisitorRequireReImpl::Visit(&op->v()); \
    return op__::Make(v);                             \
  }
  NODETY_UNARY_OP_FOR_EACH(OP_UNARY_HANDLE)
#undef OP_UNARY_HANDLE

  std::vector<Expr> Visit(const std::vector<Expr>& vs) {
    std::vector<Expr> copied;
    for (auto& e : vs) {
      copied.push_back(Visit(&e));
    }
    return copied;
  }
};

Expr IRCopyVisitor::Visit(const ir::intrinsics::BufferGetDataHandle* op) {
  return intrinsics::BufferGetDataHandle::Make(Visit(&op->buffer));
}
Expr IRCopyVisitor::Visit(const ir::intrinsics::BufferGetDataConstHandle* op) {
  return intrinsics::BufferGetDataConstHandle::Make(Visit(&op->buffer));
}
Expr IRCopyVisitor::Visit(const ir::intrinsics::PodValueToX* op) {
  return intrinsics::PodValueToX::Make(Visit(&op->pod_value_ptr),
                                       op->GetOutputType(0));
}
Expr IRCopyVisitor::Visit(const ir::intrinsics::BufferCreate* op) {
  return intrinsics::BufferCreate::Make(Visit(&op->buffer));
}
Expr IRCopyVisitor::Visit(const ir::intrinsics::GetAddr* op) {
  return intrinsics::GetAddr::Make(Visit(&op->data));
}
Expr IRCopyVisitor::Visit(const ir::intrinsics::ArgsConstruct* op) {
  llvm::SmallVector<Expr, 7> args;
  for (auto& arg : op->args) {
    args.push_back(Visit(&arg));
  }
  return intrinsics::ArgsConstruct::Make(op->var, args);
}
Expr IRCopyVisitor::Visit(const ir::intrinsics::BuiltinIntrin* op) {
  return intrinsics::BuiltinIntrin::Make(
      op->name, op->args, op->id, op->arg_nums, op->type());
}

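// Public entry points. Each call constructs a fresh IRCopyVisitor, so tensors
// and buffers are only unified within a single IRCopy invocation.
//
// A minimal usage sketch (the expression names are illustrative):
//
//   ir::Expr body = ...;                   // some existing IR expression
//   ir::Expr cloned = optim::IRCopy(body);
//   // `cloned` can now be mutated without affecting `body`.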
Expr IRCopy(Expr x) {
  IRCopyVisitor visitor;
  auto copied = visitor.Visit(&x);
  return copied;
}

std::vector<Expr> IRCopy(const std::vector<Expr>& x) {
  std::vector<Expr> res;
  for (auto& i : x) {
    res.emplace_back(IRCopy(i));
  }
  return res;
}

ir::ModuleExpr IRCopy(const ir::ModuleExpr& x) {
  return ir::ModuleExpr(IRCopy(x.GetExprs()));
}

ir::LoweredFunc IRCopy(const ir::LoweredFunc& x) {
  ir::Expr copy_func_expr = IRCopy(static_cast<ir::Expr>(x));
  ir::_LoweredFunc_* copy_func_ptr = copy_func_expr.As<ir::_LoweredFunc_>();
  return ir::LoweredFunc(copy_func_ptr);
}

// TODO(zhhsplendid): make IRCopy of std::vector a template function
std::vector<ir::LoweredFunc> IRCopy(const std::vector<ir::LoweredFunc>& x) {
  std::vector<ir::LoweredFunc> res;
  for (const auto& i : x) {
    res.emplace_back(IRCopy(i));
  }
  return res;
}

}  // namespace optim
}  // namespace cinn