From 03b875a85119bd8f19246c0e7f1ac49c5a33e1c3 Mon Sep 17 00:00:00 2001 From: Shang Zhizhou Date: Fri, 18 Feb 2022 10:33:30 +0800 Subject: [PATCH] add tool: print kernel signaturs (#39670) * add tool: print kernel signaturs * fix windows compile --- paddle/fluid/pybind/CMakeLists.txt | 6 ++ .../pybind/kernel_signature_generator.cc | 72 +++++++++++++++++++ 2 files changed, 78 insertions(+) create mode 100644 paddle/fluid/pybind/kernel_signature_generator.cc diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt index b1fe9f99b5..3453cff30f 100644 --- a/paddle/fluid/pybind/CMakeLists.txt +++ b/paddle/fluid/pybind/CMakeLists.txt @@ -137,6 +137,11 @@ if(WITH_PYTHON) target_link_libraries(op_function_generator ${OP_FUNCTION_GENERETOR_DEPS}) add_executable(eager_op_function_generator eager_op_function_generator.cc) target_link_libraries(eager_op_function_generator ${OP_FUNCTION_GENERETOR_DEPS}) + add_executable(kernel_signature_generator kernel_signature_generator.cc) + target_link_libraries(kernel_signature_generator ${OP_FUNCTION_GENERETOR_DEPS}) + if(WIN32) + target_link_libraries(kernel_signature_generator shlwapi.lib) + endif() get_property (os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) target_link_libraries(op_function_generator ${os_dependency_modules}) @@ -144,6 +149,7 @@ if(WITH_PYTHON) if(WITH_ROCM) target_link_libraries(op_function_generator ${ROCM_HIPRTC_LIB}) target_link_libraries(eager_op_function_generator ${ROCM_HIPRTC_LIB}) + target_link_libraries(kernel_signature_generator ${ROCM_HIPRTC_LIB}) endif() set(impl_file ${CMAKE_SOURCE_DIR}/paddle/fluid/pybind/op_function_impl.h) diff --git a/paddle/fluid/pybind/kernel_signature_generator.cc b/paddle/fluid/pybind/kernel_signature_generator.cc new file mode 100644 index 0000000000..617525ba16 --- /dev/null +++ b/paddle/fluid/pybind/kernel_signature_generator.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/pten_utils.h" +#include "paddle/fluid/pybind/pybind.h" // NOLINT +#include "paddle/pten/core/compat/op_utils.h" +#include "paddle/pten/core/kernel_factory.h" +#include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/kernels/declarations.h" + +// print names of kernel function params with json format: +// { +// "norm":{ +// "inputs":[ +// "X" +// ], +// "attrs":[ +// "axis", +// "epsilon", +// "is_test" +// ], +// "outputs":[ +// "Norm", +// "Out" +// ] +// }, +// ... +// } +int main(int argc, char **argv) { + paddle::framework::InitDefaultKernelSignatureMap(); + auto &kernel_signature_map = pten::DefaultKernelSignatureMap::Instance(); + auto &kernel_factory = pten::KernelFactory::Instance(); + std::cout << "{"; + for (const auto &op_kernel_pair : kernel_factory.kernels()) { + if (kernel_signature_map.Has(op_kernel_pair.first)) { + std::cout << "\"" << op_kernel_pair.first << "\":{"; + auto &args = kernel_signature_map.Get(op_kernel_pair.first).args; + std::cout << "\"inputs\":["; + for (auto name : std::get<0>(args)) { + std::cout << "\"" << name << "\","; + } + if (std::get<0>(args).size() > 0) std::cout << "\b"; + std::cout << "],\"attrs\":["; + for (auto name : std::get<1>(args)) { + std::cout << "\"" << name << "\","; + } + if (std::get<1>(args).size() > 0) std::cout << "\b"; + std::cout << "],\"outputs\":["; + for (auto name : std::get<2>(args)) { + std::cout << "\"" << name << "\","; + } + if (std::get<2>(args).size() > 0) std::cout << "\b"; + std::cout << "]},"; + } + } + std::cout << "\b}" << std::endl; + return 0; +} -- GitLab