diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index dcd9bddb0612c8847e8d1edd13609b3b97a30d64..1b76aae24646051f700a6c8af03674e17b309573 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -1176,6 +1176,25 @@ struct AddNOpTranscriber : public OpTranscriber { } }; +struct TrilAndTriuOpTranscriber : public OpTranscriber { + ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override { + bool lower = PADDLE_GET_CONST(bool, op_desc.GetAttr("lower")); + std::string target_op_name = ""; + if (lower) { + target_op_name = "pd.tril"; + } else { + target_op_name = "pd.triu"; + } + const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name); + if (!op_info) { + IR_THROW( + "Op tril_triu should have corresponding OpInfo pd.tril or pd.triu."); + } + + return op_info; + } +}; + ir::OpResult TranslateNumClassesForOneHot(ir::IrContext* ctx, TranslationContext* param_map, const OpDesc& op_desc, @@ -1473,6 +1492,7 @@ OpTranslator::OpTranslator() { special_handlers["shadow_output"] = ShadowOutputOpTranscriber(); special_handlers["split"] = SplitOpTranscriber(); special_handlers["sum"] = AddNOpTranscriber(); + special_handlers["tril_triu"] = TrilAndTriuOpTranscriber(); // special handler for elementwise ops with axis != -1 // note(lyk): maybe we should do this by a pass, which seems more reasonable