/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include "base/expr_builder.h" namespace akg { ktvm::Array UTExprBuilder::CreateShape(const std::vector &shapes) { ktvm::Array res; for (int32_t shape : shapes) { ktvm::Integer imm = ktvm::IntImm::make(ktvm::Int(32), shape); res.push_back(imm); } return res; } ktvm::Var UTExprBuilder::CreateVar(const std::string &name) { return ktvm::Var(name); } ktvm::Array UTExprBuilder::CreateVars(const std::vector &names) { ktvm::Array vars; for (const std::string &name : names) { vars.push_back(std::move(CreateVar(name))); } return vars; } ktvm::Operation UTExprBuilder::PlaceholderOpNode( const std::string &name, const std::vector &shapes, ktvm::DataType dtype) { ktvm::Array expr_shapes = CreateShape(shapes); return ktvm::PlaceholderOpNode::make(name, expr_shapes, dtype); } ktvm::Expr UTExprBuilder::TensorElement( const std::string &name, const std::vector &shapes, const std::vector &axis_names, ktvm::DataType dtype) { return ktvm::ir::Call::make( dtype, // type name, // name CreateVars(axis_names), // args ktvm::ir::Call::Halide, // call_type PlaceholderOpNode(name, shapes, dtype), // func, 0); // value_index } UTTensorElementHelper::UTTensorElementHelper(const std::vector &shapes, const std::string &axis_name_prefix) : shapes_(shapes), axis_name_prefix_(axis_name_prefix) { std::stringstream ss; for (size_t i = 0; i < shapes_.size(); i++) { ss << axis_name_prefix_ << i; axis_names_.push_back(ss.str()); ss.str(""); } } ktvm::Expr UTTensorElementHelper::Elem(const std::string &name, uint32_t dim, ktvm::DataType dtype) const { uint32_t start = shapes_.size() - dim; return UTExprBuilder::TensorElement( name, std::vector(shapes_.begin() + start, shapes_.end()), std::vector(axis_names_.begin() + start, axis_names_.end()), dtype); } } // namespace akg