// Copyright (c) 2021 CINN Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/cinn/ir/operation.h" #include #include "paddle/cinn/common/common.h" namespace cinn { namespace ir { Operation PlaceholderOp::Make(const std::string &name, const std::vector &shape, Type dtype) { auto n = make_shared(); n->name = name; n->shape = shape; n->set_type(dtype); return Operation(n); } const char *PlaceholderOp::func_type() const { return "placeholder_op"; } const char *ComputeOp::func_type() const { return "compute_op"; } Operation ComputeOp::Make(const std::string &name, ComputeOp::handle_t handle, const std::vector &shape, const std::vector &domain, const std::vector &reduce_axis, const std::map &attrs, const std::string &tag) { auto n = make_shared(); n->name = name; n->producer_fn = handle; n->shape = domain; n->reduce_axis = reduce_axis; n->tag = tag; n->attrs = attrs; auto axis = common::GenDefaultAxis(domain.size()); std::vector _axis; for (auto &x : axis) _axis.push_back(x); n->body = {handle(_axis)}; n->reduce_axis = reduce_axis; return Operation(n); } Operation CallOp::Make(const std::string &call_target, Expr call_op) { auto n = make_shared(); n->call_expr = call_op; return Operation(n); } Operation PrecedingViewOp::Make(const Tensor &tensor, int preceding_axis) { return Operation(); } const char *PrecedingViewOp::func_type() const { return PrecedingViewOp::__func_type__; } const char *CallOp::func_type() const { return __func_type__; } const char *ComputeOp::__func_type__ = "compute_op"; const char *PlaceholderOp::__func_type__ = "placeholder_op"; const char *CallOp::__func_type__ = "call_op"; const std::string &CallOp::target() const { auto *call = call_expr.As(); CHECK(call); return call->name; } std::vector &CallOp::write_args() { auto *call = call_expr.As(); CHECK(call); return call->write_args; } std::vector &CallOp::read_args() { auto *call = call_expr.As(); CHECK(call); return call->read_args; } const std::vector &CallOp::write_args() const { auto *call = call_expr.As(); CHECK(call); return call->write_args; } const std::vector &CallOp::read_args() const { auto *call = call_expr.As(); CHECK(call); return call->read_args; } std::vector CallOp::args() const { std::vector args; auto &rargs = read_args(); auto &wargs = write_args(); args.insert(std::end(args), rargs.begin(), rargs.end()); args.insert(std::end(args), wargs.begin(), wargs.end()); return args; } const char *PrecedingViewOp::__func_type__ = "preceding_view_op"; const char *BufferShareOp::__func_type__ = "buffer_share_op"; const char *BufferShareOp::func_type() const { return __func_type__; } } // namespace ir } // namespace cinn