// Copyright (c) 2021 CINN Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "paddle/cinn/common/common.h" #include "paddle/cinn/ir/buffer.h" #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_printer.h" #include "paddle/cinn/ir/operation.h" #include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/runtime/intrinsic.h" namespace cinn { namespace lang { using ir::Expr; /** * Placeholder * @tparam T */ template class Placeholder { public: Placeholder(const std::string &name, const std::vector &shape); Placeholder(const std::string &name, const std::vector &shape); //! Get a slice. // @{ Expr operator()(Expr a) const { return Call({a}); } Expr operator()(Expr a, Expr b) const { return Call({a, b}); } Expr operator()(Expr a, Expr b, Expr c) const { return Call({a, b, c}); } Expr operator()(Expr a, Expr b, Expr c, Expr d) const { return Call({a, b, c, d}); } Expr operator()(const std::vector &indices) const; // @} Type type() const { return tensor_->type(); } operator ir::Tensor() { return tensor_; } operator ir::Expr() { return Expr(tensor_); } ir::Tensor &operator->() { return tensor_; } const ir::Tensor &operator->() const { return tensor_; } ir::Tensor tensor() const { return tensor_; } private: Expr Call(const std::vector &indices) const; void Init(const std::string &name, const std::vector &shape); ir::Tensor tensor_; }; template Expr Placeholder::operator()(const std::vector &indices) const { return tensor_(indices); } template Expr Placeholder::Call(const std::vector &indices) const { return tensor_(indices); } template Placeholder::Placeholder(const std::string &name, const std::vector &shape) { std::vector _shape; for (int v : shape) _shape.push_back(Expr(v)); Init(name, _shape); } template Placeholder::Placeholder(const std::string &name, const std::vector &shape) { Init(name, shape); } ir::Tensor CreatePlaceHolder(const std::vector &shape, Type type, const std::string &name); ir::Tensor CreatePlaceHolder(const std::vector &shape, Type type, const std::string &name); /// ------- details ------- template void Placeholder::Init(const std::string &name, const std::vector &shape) { ir::Var buffer_ptr(Context::Global().NewName("buffer")); buffer_ptr->set_type(type_of()); std::vector strides(shape.size(), Expr(1)); Expr offset(0); std::vector axis; for (int i = 0; i < shape.size(); i++) axis.emplace_back(common::axis_name(i)); auto op = ir::PlaceholderOp::Make(name, shape, type_of()); tensor_ = ir::Tensor(name, type_of(), shape, shape, op, {}); Buffer buffer(tensor_->type()); tensor_->Bind(buffer); } } // namespace lang } // namespace cinn