未验证 提交 17f21508 编写于 作者: W wanghuancoder 提交者: GitHub

[IR] add TrilAndTriuOpTranscriber for ir translator (#56308)

* add TrilAndTriuOpTranscriber for ir translator

* refine
上级 58708a00
...@@ -1176,6 +1176,25 @@ struct AddNOpTranscriber : public OpTranscriber { ...@@ -1176,6 +1176,25 @@ struct AddNOpTranscriber : public OpTranscriber {
} }
}; };
struct TrilAndTriuOpTranscriber : public OpTranscriber {
ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
bool lower = PADDLE_GET_CONST(bool, op_desc.GetAttr("lower"));
std::string target_op_name = "";
if (lower) {
target_op_name = "pd.tril";
} else {
target_op_name = "pd.triu";
}
const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
if (!op_info) {
IR_THROW(
"Op tril_triu should have corresponding OpInfo pd.tril or pd.triu.");
}
return op_info;
}
};
ir::OpResult TranslateNumClassesForOneHot(ir::IrContext* ctx, ir::OpResult TranslateNumClassesForOneHot(ir::IrContext* ctx,
TranslationContext* param_map, TranslationContext* param_map,
const OpDesc& op_desc, const OpDesc& op_desc,
...@@ -1473,6 +1492,7 @@ OpTranslator::OpTranslator() { ...@@ -1473,6 +1492,7 @@ OpTranslator::OpTranslator() {
special_handlers["shadow_output"] = ShadowOutputOpTranscriber(); special_handlers["shadow_output"] = ShadowOutputOpTranscriber();
special_handlers["split"] = SplitOpTranscriber(); special_handlers["split"] = SplitOpTranscriber();
special_handlers["sum"] = AddNOpTranscriber(); special_handlers["sum"] = AddNOpTranscriber();
special_handlers["tril_triu"] = TrilAndTriuOpTranscriber();
// special handler for elementwise ops with axis != -1 // special handler for elementwise ops with axis != -1
// note(lyk): maybe we should do this by a pass, which seems more reasonable // note(lyk): maybe we should do this by a pass, which seems more reasonable
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册