未验证 提交 1692af99 编写于 作者: K kangguangli 提交者: GitHub

[NewIR] add_arg_mapping_for_fetch (#56752)

* add_arg_mapping_for_fetch

* fix

* fix
上级 ade51aa5
...@@ -128,6 +128,7 @@ def OpNameNormalizerInitialization( ...@@ -128,6 +128,7 @@ def OpNameNormalizerInitialization(
# special mapping list # special mapping list
op_arg_name_mappings["set_value_grad"]["values_grad"] = "ValueTensor@GRAD" op_arg_name_mappings["set_value_grad"]["values_grad"] = "ValueTensor@GRAD"
op_arg_name_mappings["fetch"] = {"x": "X"}
op_name_normailzer_template = env.get_template("op_compat_info.cc.j2") op_name_normailzer_template = env.get_template("op_compat_info.cc.j2")
with open(output_source_file, 'wt') as f: with open(output_source_file, 'wt') as f:
......
...@@ -1699,6 +1699,7 @@ OpTranslator::OpTranslator() { ...@@ -1699,6 +1699,7 @@ OpTranslator::OpTranslator() {
special_handlers["cast"] = CastOpTranscriber(); special_handlers["cast"] = CastOpTranscriber();
special_handlers["feed"] = FeedOpTranscriber(); special_handlers["feed"] = FeedOpTranscriber();
special_handlers["data"] = DataOpTranscriber(); special_handlers["data"] = DataOpTranscriber();
special_handlers["fetch"] = FetchOpTranscriber();
special_handlers["fetch_v2"] = FetchOpTranscriber(); special_handlers["fetch_v2"] = FetchOpTranscriber();
special_handlers["fill_constant"] = FillConstantTranscriber(); special_handlers["fill_constant"] = FillConstantTranscriber();
special_handlers["grad_add"] = GradAddOpTranscriber(); special_handlers["grad_add"] = GradAddOpTranscriber();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册