diff --git a/paddle/fluid/ir_adaptor/translator/op_translator.cc b/paddle/fluid/ir_adaptor/translator/op_translator.cc index ee2f66692eda8657142dd35a94e9c5e8a21b323a..326413818622ed32fe8c64dd23139965c406b984 100644 --- a/paddle/fluid/ir_adaptor/translator/op_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/op_translator.cc @@ -60,12 +60,14 @@ using OpAttributeInfo = dialect::OpAttributeInfo; using OpAttributeInfoList = std::vector; using OpOutputInfo = dialect::OpOutputInfo; using OpOutputInfoList = std::vector; -using InputHandleFn = std::function; +using InputHandlerFn = std::function; +using AttributeHandlerFn = std::function; constexpr char kTargetDialectPrefix[] = "pd."; constexpr char kEmptyVarName[] = "@EMPTY@"; @@ -291,7 +293,12 @@ struct OpTranscriber { const OpOutputMapping& arg_to_idx); public: - virtual InputHandleFn GetSpecialInputHandlers(std::string input_name) { + virtual InputHandlerFn GetSpecialInputHandlers( + const std::string& input_name) { + return nullptr; + } + virtual AttributeHandlerFn GetSpecialAttributeHandlers( + const std::string& input_name) { return nullptr; } }; @@ -558,6 +565,12 @@ ir::AttributeMap OpTranscriber::TranslateOpAttribute( ir::AttributeMap attribute_map = {}; for (const auto& info : op_attr_infos) { + if (auto handler = this->GetSpecialAttributeHandlers(info.name)) { + auto new_attr = handler(ctx, op_desc, info); + attribute_map[info.name] = new_attr; + continue; + } + auto legacy_attr_name = op_normalizer.GetLegacyAttrName(op_desc.Type(), info.name); VLOG(10) << "[op: " << op_desc.Type() @@ -885,7 +898,8 @@ ir::OpResult TranslateDropOutStateIn(ir::IrContext* ctx, // `rnn` has an aditional input in dynamic graph struct RnnOpTranscriber : public OpTranscriber { - InputHandleFn GetSpecialInputHandlers(std::string input_name) override { + InputHandlerFn GetSpecialInputHandlers( + const std::string& input_name) override { if (input_name != "dropout_state_in") { return nullptr; } @@ -1207,7 +1221,8 @@ ir::OpResult TranslateNumClassesForOneHot(ir::IrContext* ctx, } struct OneHotTranscriber : public OpTranscriber { - InputHandleFn GetSpecialInputHandlers(std::string input_name) override { + InputHandlerFn GetSpecialInputHandlers( + const std::string& input_name) override { if (input_name != "num_classes") { return nullptr; } @@ -1215,21 +1230,53 @@ struct OneHotTranscriber : public OpTranscriber { }; }; +ir::Attribute TranslateReduceAll(ir::IrContext* ctx, + const OpDesc& op_desc, + const OpAttributeInfo& attr_info) { + bool reduce_all = false; + if (op_desc.HasAttr("reduce_all")) { + reduce_all = paddle::get(op_desc.GetAttr("reduce_all")); + } + + if (reduce_all) { + return ir::ArrayAttribute::get(ctx, std::vector{}); + } + + auto& attribute_translator = AttributeTranslator::instance(); + auto& op_normalizer = OpNameNormalizer::instance(); + auto legacy_attr_name = + op_normalizer.GetLegacyAttrName(op_desc.Type(), attr_info.name); + paddle::framework::Attribute dims = op_desc.GetAttr(legacy_attr_name); + return attribute_translator(attr_info.type_name, dims); +} + +struct ReduceOpTranscriber : public OpTranscriber { + AttributeHandlerFn GetSpecialAttributeHandlers( + const std::string& input_name) override { + if (input_name != "axis") { + return nullptr; + } + return TranslateReduceAll; + } +}; + OpTranslator::OpTranslator() { general_handler = OpTranscriber(); + special_handlers["add_n"] = AddNOpTranscriber(); + special_handlers["assign_value"] = AssignValueOpTranscriber(); + special_handlers["cast"] = CastOpTranscriber(); special_handlers["feed"] = FeedOpTranscriber(); special_handlers["feed_with_place"] = FeedWithPlaceOpTranscriber(); special_handlers["fetch_v2"] = FetchOpTranscriber(); - special_handlers["cast"] = CastOpTranscriber(); - special_handlers["split"] = SplitOpTranscriber(); + special_handlers["increment"] = IncrementOpTranscriber(); special_handlers["lookup_table_v2"] = EmbeddingOpTranscriber(); special_handlers["lookup_table_v2_grad"] = EmbeddingGradOpTranscriber(); - special_handlers["assign_value"] = AssignValueOpTranscriber(); - special_handlers["increment"] = IncrementOpTranscriber(); + special_handlers["one_hot_v2"] = OneHotTranscriber(); + special_handlers["reduce_all"] = ReduceOpTranscriber(); + special_handlers["reduce_any"] = ReduceOpTranscriber(); special_handlers["rnn"] = RnnOpTranscriber(); special_handlers["shaddow_output"] = ShaddowOutputOpTranscriber(); - special_handlers["one_hot_v2"] = OneHotTranscriber(); - special_handlers["add_n"] = AddNOpTranscriber(); + special_handlers["split"] = SplitOpTranscriber(); special_handlers["sum"] = AddNOpTranscriber(); } diff --git a/test/ir/new_ir/test_special_op_translator.py b/test/ir/new_ir/test_special_op_translator.py index 68f2e02648ba9db420442d5dc8701870e789605c..7f9f00527e9f2facd4e8c1080627448d4a9cff42 100644 --- a/test/ir/new_ir/test_special_op_translator.py +++ b/test/ir/new_ir/test_special_op_translator.py @@ -160,5 +160,39 @@ class TestOneHotOpTranscriber(unittest.TestCase): _ = ir.translate_to_new_ir(main_program.desc) +class TestReduceOpTranscriber(unittest.TestCase): + def test_reduce_all(self): + place = core.Place() + place.set_place(paddle.CPUPlace()) + exe = paddle.static.Executor(place) + + new_scope = paddle.static.Scope() + main_program = paddle.static.Program() + with paddle.static.scope_guard(new_scope): + with paddle.static.program_guard(main_program): + arr = np.ones([2, 2], dtype="float32") + x = paddle.to_tensor(arr, dtype='int32') + out1 = paddle.all(x) + + out = exe.run(main_program, {}, fetch_list=[out1.name]) + np.testing.assert_array_equal(out[0], np.all(arr)) + + def test_with_axis(self): + place = core.Place() + place.set_place(paddle.CPUPlace()) + exe = paddle.static.Executor(place) + + new_scope = paddle.static.Scope() + main_program = paddle.static.Program() + with paddle.static.scope_guard(new_scope): + with paddle.static.program_guard(main_program): + arr = np.ones([2, 2], dtype="float32") + x = paddle.to_tensor(arr, dtype='int32') + out1 = paddle.all(x, axis=0) + + out = exe.run(main_program, {}, fetch_list=[out1.name]) + np.testing.assert_array_equal(out[0], np.all(arr, axis=0)) + + if __name__ == "__main__": unittest.main() diff --git a/test/white_list/new_ir_op_test_white_list b/test/white_list/new_ir_op_test_white_list index 953a188451279591b1472b3debcc6da53843acae..8a48702edcb5fb2bccb62c9730fc6047824e2d8c 100644 --- a/test/white_list/new_ir_op_test_white_list +++ b/test/white_list/new_ir_op_test_white_list @@ -139,6 +139,7 @@ test_prior_box_op test_psroi_pool_op test_put_along_axis_op test_range +test_reduce_op test_reverse_op test_roi_align_op test_roi_pool_op