From 8d4f26139a9e57b47b1dd7ba9666c06fdeab0531 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Tue, 6 Sep 2022 10:32:01 +0800 Subject: [PATCH] add op count by lib method (#45680) --- paddle/fluid/pybind/pybind.cc | 45 ++++++++++++++++--- ....py => test_get_all_op_or_kernel_names.py} | 14 ++++++ 2 files changed, 52 insertions(+), 7 deletions(-) rename python/paddle/fluid/tests/unittests/{test_get_all_registered_op_kernels.py => test_get_all_op_or_kernel_names.py} (76%) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index acce7781a23..328b7fc74eb 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -1049,13 +1049,44 @@ All parameter, weight, gradient are variables in Paddle. } return ret_values; }); - m.def("get_all_op_names", []() { - std::vector op_names; - for (auto &iter : OpInfoMap::Instance().map()) { - op_names.emplace_back(iter.first); - } - return op_names; - }); + m.def( + "get_all_op_names", + [](const std::string &lib) { + std::vector op_names; + for (auto &iter : OpInfoMap::Instance().map()) { + op_names.emplace_back(iter.first); + } + if (lib == "phi") { + std::vector ops_with_phi_kernel; + for (const auto &op_name : op_names) { + if (phi::KernelFactory::Instance().HasCompatiblePhiKernel( + op_name)) { + ops_with_phi_kernel.emplace_back(op_name); + } + } + return ops_with_phi_kernel; + } else if (lib == "fluid") { + std::vector ops_with_fluid_kernel; + auto all_fluid_op_kernels = + paddle::framework::OperatorWithKernel::AllOpKernels(); + for (const auto &op_name : op_names) { + if (all_fluid_op_kernels.find(op_name) != + all_fluid_op_kernels.end()) { + ops_with_fluid_kernel.emplace_back(op_name); + } + } + return ops_with_fluid_kernel; + } else { + return op_names; + } + }, + py::arg("lib") = "all", + R"DOC( + Return the operator names in paddle. + + Args: + lib[string]: the library contains corresponding OpKernel, could be 'phi', 'fluid' and 'all'. Default value is 'all'. + )DOC"); m.def("get_op_attrs_default_value", [](py::bytes byte_name) -> paddle::framework::AttributeMap { std::string op_type = byte_name; diff --git a/python/paddle/fluid/tests/unittests/test_get_all_registered_op_kernels.py b/python/paddle/fluid/tests/unittests/test_get_all_op_or_kernel_names.py similarity index 76% rename from python/paddle/fluid/tests/unittests/test_get_all_registered_op_kernels.py rename to python/paddle/fluid/tests/unittests/test_get_all_op_or_kernel_names.py index a429717bdaf..88c0c3700ea 100644 --- a/python/paddle/fluid/tests/unittests/test_get_all_registered_op_kernels.py +++ b/python/paddle/fluid/tests/unittests/test_get_all_op_or_kernel_names.py @@ -39,5 +39,19 @@ class TestGetAllRegisteredOpKernels(unittest.TestCase): self.assertTrue(core._get_all_register_op_kernels()['sign']) +class TestGetAllOpNames(unittest.TestCase): + + def test_get_all_op_names(self): + all_op_names = core.get_all_op_names() + all_op_with_phi_kernels = core.get_all_op_names("phi") + all_op_with_fluid_kernels = core.get_all_op_names("fluid") + + self.assertTrue( + len(all_op_names) > len( + set(all_op_with_phi_kernels) | set(all_op_with_fluid_kernels))) + self.assertTrue("scale" in all_op_with_phi_kernels) + self.assertTrue("scale" in all_op_with_phi_kernels) + + if __name__ == '__main__': unittest.main() -- GitLab