/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include "base/expr_builder.h" namespace akg { air::Expr UTExprBuilder::IntImm(int64_t value, air::DataType dtype) { return air::IntImm::make(dtype, value); } air::Expr UTExprBuilder::UIntImm(uint64_t value, air::DataType dtype) { return air::ir::UIntImm::make(dtype, value); } air::Expr UTExprBuilder::BoolImm(bool value) { return air::ir::UIntImm::make(air::Bool(), value ? 1 : 0); } air::Array UTExprBuilder::CreateShape(const std::vector &shapes) { air::Array res; for (int32_t shape : shapes) { air::Integer imm = air::IntImm::make(air::Int(32), shape); res.push_back(imm); } return res; } air::Var UTExprBuilder::CreateVar(const std::string &name) { return air::Var(name); } air::Array UTExprBuilder::CreateVars(const std::vector &names) { air::Array vars; for (const std::string &name : names) { vars.push_back(std::move(CreateVar(name))); } return vars; } air::Region UTExprBuilder::CreateRegion(const std::vector &shapes) { air::Region region; for (int32_t shape : shapes) { region.push_back(CreateRange(0, shape)); } return region; } air::Region UTExprBuilder::CreateRegion(const air::Array &shapes) { air::Region region; for (const air::Expr &shape : shapes) { region.push_back(air::Range::make_by_min_extent(IntImm(0), shape)); } return region; } air::Range UTExprBuilder::CreateRange(int32_t min, int32_t max) { air::Integer imm_min = air::IntImm::make(air::Int(32), min); air::Integer imm_max = air::IntImm::make(air::Int(32), max); return air::Range(std::move(imm_min), std::move(imm_max)); } air::Operation UTExprBuilder::PlaceholderOpNode( const std::string &name, const std::vector &shapes, air::DataType dtype) { air::Array expr_shapes = CreateShape(shapes); return air::PlaceholderOpNode::make(name, expr_shapes, dtype); } air::Expr UTExprBuilder::TensorElement( const std::string &name, const std::vector &shapes, const std::vector &axis_names, air::DataType dtype) { return air::ir::Call::make( dtype, // type name, // name CreateVars(axis_names), // args air::ir::Call::Halide, // call_type PlaceholderOpNode(name, shapes, dtype), // func, 0); // value_index } air::Expr UTExprBuilder::ElementOf( const air::Operation &op, const std::vector &axis_names) { if (op->template IsInstance()) { return ElementOfPlaceholderOp(op, axis_names); } else { CHECK(false); return air::ir::Any::make(); } } air::Expr UTExprBuilder::ElementOfPlaceholderOp( const air::Operation &op, const std::vector &axis_names) { const air::PlaceholderOpNode *node = op.as(); CHECK(node); return air::ir::Call::make( node->dtype, node->name, CreateVars(axis_names), air::ir::Call::Halide, op, 0); } air::Expr UTExprBuilder::CreateCall( const air::ir::FunctionRef func, air::Array args, air::ir::Call::CallType call_type, int value_index) { air::DataType type = air::Float(16); const air::OperationNode *node_op = func.as(); CHECK(node_op); std::string name = node_op->name; const air::PlaceholderOpNode *node_placeholder = func.as(); if (node_placeholder != nullptr) { type = node_placeholder->dtype; } return air::ir::Call::make(type, name, args, call_type, func, value_index); } air::Tensor UTExprBuilder::CreateTensorByPlaceholder(const air::Operation op) { const air::PlaceholderOpNode *node = op.as(); CHECK(node); return air::TensorNode::make( node->shape, node->dtype, op, 0); } UTTensorElementHelper::UTTensorElementHelper(const std::vector &shapes, const std::string &axis_name_prefix) : shapes_(shapes), axis_name_prefix_(axis_name_prefix) { std::stringstream ss; for (size_t i = 0; i < shapes_.size(); i++) { ss << axis_name_prefix_ << i; axis_names_.push_back(ss.str()); ss.str(""); } } air::Expr UTTensorElementHelper::Elem(const std::string &name, uint32_t dim, air::DataType dtype) const { uint32_t start = shapes_.size() - dim; return UTExprBuilder::TensorElement( name, std::vector(shapes_.begin() + start, shapes_.end()), std::vector(axis_names_.begin() + start, axis_names_.end()), dtype); } } // namespace akg