diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc
index 04cc99075bd92a15d5e354383edb16d310494beb..ef53f1d496a3dd0cf987b32de8c07d039883a649 100644
--- a/paddle/fluid/ir_adaptor/translator/op_translator.cc
+++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -1379,6 +1379,19 @@
   }
 };
 
+struct GradAddOpTranscriber : public ElementwiseTranscriber {
+  ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
+    const std::string& target_op_name = "pd.add";
+    const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
+    if (!op_info) {
+      IR_THROW(
+          "Op grad_add should have corresponding OpInfo pd.add");
+    }
+
+    return op_info;
+  }
+};
+
 struct ElementwiseGradTranscriber : public OpTranscriber {
   void RecordOpResultMapping(ir::IrContext* ctx,
                              TranslationContext* param_map,
@@ -1450,6 +1463,7 @@ OpTranslator::OpTranslator() {
   special_handlers["feed"] = FeedOpTranscriber();
   special_handlers["data"] = DataOpTranscriber();
   special_handlers["fetch_v2"] = FetchOpTranscriber();
+  special_handlers["grad_add"] = GradAddOpTranscriber();
   special_handlers["increment"] = IncrementOpTranscriber();
   special_handlers["lookup_table_v2"] = EmbeddingOpTranscriber();
   special_handlers["lookup_table_v2_grad"] = EmbeddingGradOpTranscriber();
diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml
index d9c4bd1fbaf093c19320f6767f03ff3c1688b795..0f7be38a74386c0cc7c96535429eec65801bdeed 100755
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -1230,6 +1230,10 @@
     {pre_nms_top_n : pre_nms_topN, post_nms_top_n : post_nms_topN}
 
 - op : grad_add
+  inputs :
+    {x : X, y : Y}
+  outputs :
+    {out : Out}
   extra :
     attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool use_quantizer = false,
              float Scale_x = 1.0f, float Scale_y = 1.0f, float Scale_out = 1.0f]
diff --git a/test/ir/new_ir/test_special_op_translator.py b/test/ir/new_ir/test_special_op_translator.py
index ea4837eb4789aa9df247f19bbe28ac1f9e3ca247..4cc11fd9c28189c72c25614d3971b622ffd7e3cd 100644
--- a/test/ir/new_ir/test_special_op_translator.py
+++ b/test/ir/new_ir/test_special_op_translator.py
@@ -274,5 +274,31 @@ class TestIndexPutOpTranscriber(unittest.TestCase):
         _ = ir.translate_to_new_ir(main_program.desc)
 
 
+class TestGradAddOpTranscriber(unittest.TestCase):
+    def test_op(self):
+        place = core.Place()
+        place.set_place(paddle.CPUPlace())
+        new_scope = paddle.static.Scope()
+        main_program = paddle.static.Program()
+        with paddle.static.scope_guard(new_scope):
+            with paddle.static.program_guard(main_program):
+                x_data = np.random.rand(100, 2, 3)
+                y_data = np.random.rand(100, 1, 1)
+                x = paddle.to_tensor(x_data, dtype='float32')
+                x.stop_gradient = False
+                y = paddle.to_tensor(y_data, dtype='float32')
+
+                helper = LayerHelper('grad_add')
+                out = helper.create_variable_for_type_inference("float")
+                helper.append_op(
+                    type="grad_add",
+                    inputs={"X": x, "Y": y},
+                    outputs={"Out": out},
+                    attrs={"axis": -1},
+                )
+
+                _ = ir.translate_to_new_ir(main_program.desc)
+
+
 if __name__ == "__main__":
     unittest.main()
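
Note on the mapping: grad_add is numerically just a broadcasting elementwise add (it exists as a separate op type mainly so gradient accumulation can be handled as its own op), which is why the transcriber can lower it straight to pd.add. Below is a minimal NumPy sketch of the expected numerics, mirroring the shapes used in the new unit test; the variable names are illustrative only and not part of the patch.

    import numpy as np

    # Shapes mirror TestGradAddOpTranscriber: y broadcasts against x.
    x = np.random.rand(100, 2, 3).astype("float32")
    y = np.random.rand(100, 1, 1).astype("float32")

    # grad_add accumulates two gradients; numerically it is a plain
    # broadcasting elementwise add, i.e. the same result pd.add produces.
    out = x + y
    assert out.shape == (100, 2, 3)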