// Copyright (c) 2023 CINN Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include "paddle/cinn/common/common.h" #include "paddle/cinn/common/macros.h" #include "paddle/cinn/hlir/framework/node.h" #include "paddle/cinn/hlir/framework/op.h" #include "paddle/cinn/hlir/framework/op_strategy.h" #include "paddle/cinn/hlir/op/op_util.h" #include "paddle/cinn/hlir/pe/elementwise.h" #include "paddle/cinn/hlir/pe/ir_schedule_pe.h" #include "paddle/cinn/ir/ir.h" #include "paddle/cinn/ir/ir_base.h" #include "paddle/cinn/ir/tensor.h" #include "paddle/cinn/lang/builtin.h" #include "paddle/cinn/lang/compute.h" namespace cinn { namespace hlir { namespace op { using common::CINNValue; using common::CINNValuePack; std::shared_ptr StrategyForAssertTrue(const framework::NodeAttr &attrs, const std::vector &inputs, const std::vector &out_type, const std::vector> &output_shapes, const Target &target) { framework::CINNCompute assert_true_compute([=](lang::Args args, lang::RetValue *ret) { CHECK(!args.empty()) << "The input argument of assert_true is empty! Please check."; CINNValuePack pack_args = args[0]; CHECK_GE(pack_args.size(), 2U) << "Two input tensors are required for the computation of assert_true."; Expr a_expr = pack_args[0]; Expr b_expr = pack_args[1]; ir::Tensor a = a_expr.as_tensor_ref(); ir::Tensor b = b_expr.as_tensor_ref(); std::string tensor_name = "assert_true_out"; auto out = pe::Identity(b, tensor_name).front(); auto stages = CreateStages({out}); std::vector res{CINNValue(out), CINNValue(stages)}; *ret = CINNValuePack{res}; }); auto strategy = std::make_shared(); strategy->AddImpl( assert_true_compute, GetElementwiseScheduleFunc(output_shapes, target), "strategy.assert_true.x86", 1); return strategy; } std::vector InferShapeForAssertTrue(const std::vector &inputs_shape, const framework::AttrMapType &attrs) { return inputs_shape; } std::vector InferDtypeForAssertTrue(const std::vector &inputs_type, const framework::AttrMapType &attrs) { return inputs_type; } } // namespace op } // namespace hlir } // namespace cinn CINN_REGISTER_HELPER(assert_true_ops) { CINN_REGISTER_OP(assert_true) .describe("AssertTrue") .set_num_inputs(1) .set_num_outputs(1) .set_attr("CINNStrategy", cinn::hlir::op::StrategyForAssertTrue) .set_attr("infershape", MakeOpFunction(cinn::hlir::op::InferShapeForAssertTrue)) .set_attr("inferdtype", MakeOpFunction(cinn::hlir::op::InferDtypeForAssertTrue)) .set_attr("OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible) .set_support_level(4); return true; }