// Copyright (c) 2023 CINN Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include #include #include #include "absl/types/variant.h" #include "paddle/cinn/common/cas.h" #include "paddle/cinn/common/cinn_value.h" #include "paddle/cinn/common/common.h" #include "paddle/cinn/common/context.h" #include "paddle/cinn/common/ir_util.h" #include "paddle/cinn/common/macros.h" #include "paddle/cinn/common/target.h" #include "paddle/cinn/hlir/framework/node.h" #include "paddle/cinn/hlir/framework/op.h" #include "paddle/cinn/hlir/framework/op_strategy.h" #include "paddle/cinn/hlir/op/op_util.h" #include "paddle/cinn/hlir/pe/elementwise.h" #include "paddle/cinn/hlir/pe/ir_schedule_pe.h" #include "paddle/cinn/hlir/pe/nn.h" #include "paddle/cinn/hlir/pe/schedule.h" #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/ir_operators.h" #include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/lang/builtin.h" #include "paddle/cinn/lang/compute.h" #include "paddle/cinn/lang/packed_func.h" #include "paddle/cinn/poly/stage.h" #include "glog/logging.h" namespace cinn { namespace hlir { namespace op { using common::CINNValue; using common::CINNValuePack; std::shared_ptr StrategyForRandInt(const framework::NodeAttr &attrs, const std::vector &inputs, const std::vector &out_type, const std::vector> &output_shapes, const Target &target) { framework::CINNCompute randint_compute([=](lang::Args args, lang::RetValue *ret) { CHECK(attrs.attr_store.count("shape")); ir::Tensor shape_tensor; std::string tensor_name = "randint_out"; auto out = pe::Identity(shape_tensor, tensor_name).front(); auto stages = CreateStages({out}); std::vector res{CINNValue(out), CINNValue(stages)}; *ret = CINNValuePack{res}; }); auto strategy = std::make_shared(); strategy->AddImpl(randint_compute, GetInjectiveScheduleFunc(output_shapes, target), "strategy.randint.x86", 1); return strategy; } std::vector InferShapeForRandInt(const std::vector &inputs_shape, const framework::AttrMapType &attrs) { CHECK(attrs.count("shape")); auto shape = absl::get>(attrs.at("shape")); CHECK(!shape.empty()) << "shape attr is empty!"; return {shape}; } std::vector InferDtypeForRandInt(const std::vector &inputs_type, const framework::AttrMapType &attrs) { std::string dtype = "int32"; std::vector res{common::Str2Type(dtype)}; return res; } } // namespace op } // namespace hlir } // namespace cinn CINN_REGISTER_HELPER(randint_ops) { CINN_REGISTER_OP(randint) .describe("RandInt") .set_num_inputs(0) .set_num_outputs(1) .set_attr("CINNStrategy", cinn::hlir::op::StrategyForRandInt) .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForRandInt)) .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForRandInt)) .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); return true; }