diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index acce7781a23e911609507e501c1307c710980a95..328b7fc74eb190c17d2c708dc7beb927ba78ec01 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1049,13 +1049,44 @@ All parameter, weight, gradient are variables in Paddle. } return ret_values; }); - m.def("get_all_op_names", []() { - std::vector op_names; - for (auto &iter : OpInfoMap::Instance().map()) { - op_names.emplace_back(iter.first); - } - return op_names; - }); + m.def( + "get_all_op_names", + [](const std::string &lib) { + std::vector op_names; + for (auto &iter : OpInfoMap::Instance().map()) { + op_names.emplace_back(iter.first); + } + if (lib == "phi") { + std::vector ops_with_phi_kernel; + for (const auto &op_name : op_names) { + if (phi::KernelFactory::Instance().HasCompatiblePhiKernel( + op_name)) { + ops_with_phi_kernel.emplace_back(op_name); + } + } + return ops_with_phi_kernel; + } else if (lib == "fluid") { + std::vector ops_with_fluid_kernel; + auto all_fluid_op_kernels = + paddle::framework::OperatorWithKernel::AllOpKernels(); + for (const auto &op_name : op_names) { + if (all_fluid_op_kernels.find(op_name) != + all_fluid_op_kernels.end()) { + ops_with_fluid_kernel.emplace_back(op_name); + } + } + return ops_with_fluid_kernel; + } else { + return op_names; + } + }, + py::arg("lib") = "all", + R"DOC( + Return the operator names in paddle. + + Args: + lib[string]: the library contains corresponding OpKernel, could be 'phi', 'fluid' and 'all'. Default value is 'all'. + )DOC"); m.def("get_op_attrs_default_value", [](py::bytes byte_name) -> paddle::framework::AttributeMap { std::string op_type = byte_name; diff --git a/python/paddle/fluid/tests/unittests/test_get_all_registered_op_kernels.py b/python/paddle/fluid/tests/unittests/test_get_all_op_or_kernel_names.py similarity index 76% rename from python/paddle/fluid/tests/unittests/test_get_all_registered_op_kernels.py rename to python/paddle/fluid/tests/unittests/test_get_all_op_or_kernel_names.py index a429717bdaf37b3724820d3e074c38a216634cdf..88c0c3700ea239fdae74a441dfd56452714565d5 100644 --- a/python/paddle/fluid/tests/unittests/test_get_all_registered_op_kernels.py +++ b/python/paddle/fluid/tests/unittests/test_get_all_op_or_kernel_names.py @@ -39,5 +39,19 @@ class TestGetAllRegisteredOpKernels(unittest.TestCase): self.assertTrue(core._get_all_register_op_kernels()['sign']) +class TestGetAllOpNames(unittest.TestCase): + + def test_get_all_op_names(self): + all_op_names = core.get_all_op_names() + all_op_with_phi_kernels = core.get_all_op_names("phi") + all_op_with_fluid_kernels = core.get_all_op_names("fluid") + + self.assertTrue( + len(all_op_names) > len( + set(all_op_with_phi_kernels) | set(all_op_with_fluid_kernels))) + self.assertTrue("scale" in all_op_with_phi_kernels) + self.assertTrue("scale" in all_op_with_phi_kernels) + + if __name__ == '__main__': unittest.main()