From 17f21508e7199b8f1e294a31260a4c2b8612b483 Mon Sep 17 00:00:00 2001 From: wanghuancoder Date: Wed, 16 Aug 2023 14:39:42 +0800 Subject: [PATCH] [IR] add TrilAndTriuOpTranscriber for ir translator (#56308) * add TrilAndTriuOpTranscriber for ir translator * refine --- .../ir_adaptor/translator/op_translator.cc | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index dcd9bddb061..1b76aae2464 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -1176,6 +1176,25 @@ struct AddNOpTranscriber : public OpTranscriber { } }; +struct TrilAndTriuOpTranscriber : public OpTranscriber { + ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override { + bool lower = PADDLE_GET_CONST(bool, op_desc.GetAttr("lower")); + std::string target_op_name = ""; + if (lower) { + target_op_name = "pd.tril"; + } else { + target_op_name = "pd.triu"; + } + const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); + if (!op_info) { + IR_THROW( + "Op tril_triu should have corresponding OpInfo pd.tril or pd.triu."); + } + + return op_info; + } +}; + ir::OpResult TranslateNumClassesForOneHot(ir::IrContext* ctx, TranslationContext* param_map, const OpDesc& op_desc, @@ -1473,6 +1492,7 @@ OpTranslator::OpTranslator() { special_handlers["shadow_output"] = ShadowOutputOpTranscriber(); special_handlers["split"] = SplitOpTranscriber(); special_handlers["sum"] = AddNOpTranscriber(); + special_handlers["tril_triu"] = TrilAndTriuOpTranscriber(); // special handler for elementwise ops with axis != -1 // note(lyk): maybe we should do this by a pass, which seems more reasonable -- GitLab