未验证 提交 98244a9a 编写于 作者: Z zyfncg 提交者: GitHub

Support intermediate for Sparse API (#40840)

* support intermediate for saprse api

* close intermediate in yaml

* fix dygraph_api dep for eager
上级 c12f7d48
set(eager_deps phi_api hook_utils tensor_utils utils global_utils backward phi_tensor tracer layer autograd_meta grad_node_info grad_tensor_holder accumulation_node custom_operator_node) set(eager_deps phi_api phi_dygraph_api hook_utils tensor_utils utils global_utils backward phi_tensor tracer layer autograd_meta grad_node_info grad_tensor_holder accumulation_node custom_operator_node)
set(fluid_deps tracer layer proto_desc operator op_registry variable_helper memcpy) set(fluid_deps tracer layer proto_desc operator op_registry variable_helper memcpy)
set(generated_deps final_dygraph_function final_dygraph_node dygraph_function dygraph_node) set(generated_deps final_dygraph_function final_dygraph_node dygraph_function dygraph_node)
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable.h"
// Phi deps // Phi deps
#include "paddle/phi/api/include/tensor.h" #include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/lib/api_declare.h"
#include "paddle/phi/api/lib/utils/tensor_utils.h" #include "paddle/phi/api/lib/utils/tensor_utils.h"
#include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/convert_utils.h"
/** /**
......
...@@ -36,7 +36,6 @@ limitations under the License. */ ...@@ -36,7 +36,6 @@ limitations under the License. */
#include "paddle/fluid/platform/dynload/dynamic_loader.h" #include "paddle/fluid/platform/dynload/dynamic_loader.h"
#include "paddle/fluid/string/string_helper.h" #include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/api/all.h" #include "paddle/phi/api/all.h"
#include "paddle/phi/api/lib/api_declare.h"
#include "paddle/phi/api/lib/ext_compat_utils.h" #include "paddle/phi/api/lib/ext_compat_utils.h"
#include "paddle/phi/api/lib/utils/tensor_utils.h" #include "paddle/phi/api/lib/utils/tensor_utils.h"
#include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/compat/convert_utils.h"
......
...@@ -17,12 +17,8 @@ set(api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api_gen.py) ...@@ -17,12 +17,8 @@ set(api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api_gen.py)
set(api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api.yaml) set(api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api.yaml)
set(api_header_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/include/api.h) set(api_header_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/include/api.h)
set(api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/api.cc) set(api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/api.cc)
set(dygraph_api_header_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/dygraph_api.h)
set(dygraph_api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/dygraph_api.cc)
set(api_header_file_tmp ${api_header_file}.tmp) set(api_header_file_tmp ${api_header_file}.tmp)
set(api_source_file_tmp ${api_source_file}.tmp) set(api_source_file_tmp ${api_source_file}.tmp)
set(dygraph_api_header_file_tmp ${dygraph_api_header_file}.tmp)
set(dygraph_api_source_file_tmp ${dygraph_api_source_file}.tmp)
# backward api file # backward api file
set(bw_api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/backward_api_gen.py) set(bw_api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/backward_api_gen.py)
...@@ -32,6 +28,13 @@ set(bw_api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/backward_api.cc) ...@@ -32,6 +28,13 @@ set(bw_api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/backward_api.cc)
set(bw_api_header_file_tmp ${bw_api_header_file}.tmp) set(bw_api_header_file_tmp ${bw_api_header_file}.tmp)
set(bw_api_source_file_tmp ${bw_api_source_file}.tmp) set(bw_api_source_file_tmp ${bw_api_source_file}.tmp)
# dygraph(intermediate) api file
set(im_api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/intermediate_api_gen.py)
set(dygraph_api_header_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/dygraph_api.h)
set(dygraph_api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/dygraph_api.cc)
set(dygraph_api_header_file_tmp ${dygraph_api_header_file}.tmp)
set(dygraph_api_source_file_tmp ${dygraph_api_source_file}.tmp)
# sparse api file # sparse api file
set(sparse_api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_api_gen.py) set(sparse_api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_api_gen.py)
set(sparse_api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_api.yaml) set(sparse_api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_api.yaml)
...@@ -48,14 +51,6 @@ set(sparse_bw_api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/sparse_bw_a ...@@ -48,14 +51,6 @@ set(sparse_bw_api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/sparse_bw_a
set(sparse_bw_api_header_file_tmp ${sparse_bw_api_header_file}.tmp) set(sparse_bw_api_header_file_tmp ${sparse_bw_api_header_file}.tmp)
set(sparse_bw_api_source_file_tmp ${sparse_bw_api_source_file}.tmp) set(sparse_bw_api_source_file_tmp ${sparse_bw_api_source_file}.tmp)
# sparse bw api file
set(sparse_bw_api_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_bw_api_gen.py)
set(sparse_bw_api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/sparse_bw_api.yaml)
set(sparse_bw_api_header_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/backward/sparse_bw_api.h)
set(sparse_bw_api_source_file ${CMAKE_SOURCE_DIR}/paddle/phi/api/lib/sparse_bw_api.cc)
set(sparse_bw_api_header_file_tmp ${sparse_bw_api_header_file}.tmp)
set(sparse_bw_api_source_file_tmp ${sparse_bw_api_source_file}.tmp)
# wrapped infermeta file # wrapped infermeta file
set(wrapped_infermeta_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/wrapped_infermeta_gen.py) set(wrapped_infermeta_gen_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/wrapped_infermeta_gen.py)
set(api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api.yaml) set(api_yaml_file ${CMAKE_SOURCE_DIR}/python/paddle/utils/code_gen/api.yaml)
...@@ -68,18 +63,14 @@ endif() ...@@ -68,18 +63,14 @@ endif()
# generate forward api # generate forward api
add_custom_command( add_custom_command(
OUTPUT ${api_header_file} ${api_source_file} ${dygraph_api_header_file} ${dygraph_api_source_file} OUTPUT ${api_header_file} ${api_source_file}
COMMAND ${PYTHON_EXECUTABLE} -m pip install pyyaml COMMAND ${PYTHON_EXECUTABLE} -m pip install pyyaml
COMMAND ${PYTHON_EXECUTABLE} ${api_gen_file} COMMAND ${PYTHON_EXECUTABLE} ${api_gen_file}
--api_yaml_path ${api_yaml_file} --api_yaml_path ${api_yaml_file}
--api_header_path ${api_header_file_tmp} --api_header_path ${api_header_file_tmp}
--api_source_path ${api_source_file_tmp} --api_source_path ${api_source_file_tmp}
--dygraph_api_header_path ${dygraph_api_header_file_tmp}
--dygraph_api_source_path ${dygraph_api_source_file_tmp}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_header_file_tmp} ${api_header_file} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_header_file_tmp} ${api_header_file}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_source_file_tmp} ${api_source_file} COMMAND ${CMAKE_COMMAND} -E copy_if_different ${api_source_file_tmp} ${api_source_file}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dygraph_api_header_file_tmp} ${dygraph_api_header_file}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dygraph_api_source_file_tmp} ${dygraph_api_source_file}
COMMENT "copy_if_different ${api_header_file} ${api_source_file}" COMMENT "copy_if_different ${api_header_file} ${api_source_file}"
DEPENDS ${api_yaml_file} ${api_gen_file} ${api_gen_base} DEPENDS ${api_yaml_file} ${api_gen_file} ${api_gen_base}
VERBATIM) VERBATIM)
...@@ -123,6 +114,19 @@ add_custom_command( ...@@ -123,6 +114,19 @@ add_custom_command(
DEPENDS ${sparse_bw_api_yaml_file} ${sparse_bw_api_gen_file} ${api_gen_base} ${api_gen_file} ${sparse_api_gen_file} ${bw_api_gen_file} DEPENDS ${sparse_bw_api_yaml_file} ${sparse_bw_api_gen_file} ${api_gen_base} ${api_gen_file} ${sparse_api_gen_file} ${bw_api_gen_file}
VERBATIM) VERBATIM)
# generate dygraph(intermediate) api
add_custom_command(
OUTPUT ${dygraph_api_header_file} ${dygraph_api_source_file}
COMMAND ${PYTHON_EXECUTABLE} ${im_api_gen_file}
--api_yaml_path ${api_yaml_file}
--sparse_api_yaml_path ${sparse_api_yaml_file}
--dygraph_api_header_path ${dygraph_api_header_file_tmp}
--dygraph_api_source_path ${dygraph_api_source_file_tmp}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dygraph_api_header_file_tmp} ${dygraph_api_header_file}
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${dygraph_api_source_file_tmp} ${dygraph_api_source_file}
DEPENDS ${api_yaml_file} ${sparse_api_yaml_file} ${im_api_gen_file} ${api_gen_base} ${api_gen_file}
VERBATIM)
# generate wrapped infermeta # generate wrapped infermeta
add_custom_command( add_custom_command(
OUTPUT ${wrapped_infermeta_header_file} ${wrapped_infermeta_source_file} OUTPUT ${wrapped_infermeta_header_file} ${wrapped_infermeta_source_file}
...@@ -144,9 +148,9 @@ cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kerne ...@@ -144,9 +148,9 @@ cc_library(api_custom_impl SRCS api_custom_impl.cc DEPS phi_tensor_raw phi kerne
cc_library(sparse_api_custom_impl SRCS sparse_api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform) cc_library(sparse_api_custom_impl SRCS sparse_api_custom_impl.cc DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform)
cc_library(phi_function_api SRCS ${api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform api_custom_impl) cc_library(phi_function_api SRCS ${api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform api_custom_impl)
cc_library(phi_dygraph_api SRCS ${dygraph_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform)
cc_library(phi_bw_function_api SRCS ${bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform phi_function_api api_custom_impl) cc_library(phi_bw_function_api SRCS ${bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils backward_infermeta phi_data_transform phi_function_api api_custom_impl)
cc_library(sparse_api SRCS ${sparse_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api_custom_impl) cc_library(sparse_api SRCS ${sparse_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api_custom_impl)
cc_library(sparse_bw_api SRCS ${sparse_bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api sparse_api_custom_impl) cc_library(sparse_bw_api SRCS ${sparse_bw_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils sparse_api sparse_api_custom_impl)
cc_library(phi_dygraph_api SRCS ${dygraph_api_source_file} DEPS phi_tensor_raw phi kernel_dispatch api_gen_utils phi_data_transform phi_function_api sparse_api)
cc_library(phi_tensor SRCS tensor_method.cc DEPS phi_tensor_raw phi_function_api api_gen_utils kernel_dispatch infermeta sparse_api) cc_library(phi_tensor SRCS tensor_method.cc DEPS phi_tensor_raw phi_function_api api_gen_utils kernel_dispatch infermeta sparse_api)
...@@ -15,7 +15,6 @@ limitations under the License. */ ...@@ -15,7 +15,6 @@ limitations under the License. */
#include "paddle/phi/api/lib/api_custom_impl.h" #include "paddle/phi/api/lib/api_custom_impl.h"
#include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/api_registry.h"
#include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/utils/storage.h" #include "paddle/phi/api/lib/utils/storage.h"
......
...@@ -17,5 +17,5 @@ limitations under the License. */ ...@@ -17,5 +17,5 @@ limitations under the License. */
// api symbols declare, remove in the future // api symbols declare, remove in the future
#include "paddle/phi/api/lib/api_registry.h" #include "paddle/phi/api/lib/api_registry.h"
PD_DECLARE_API(Math); // PD_DECLARE_API(Math);
PD_DECLARE_API(SparseApi); // PD_DECLARE_API(SparseApi);
...@@ -16,7 +16,6 @@ limitations under the License. */ ...@@ -16,7 +16,6 @@ limitations under the License. */
#include <memory> #include <memory>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/phi/api/lib/api_registry.h"
#include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/utils/storage.h" #include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
...@@ -210,5 +209,3 @@ Tensor to_dense_impl(const Tensor& x) { ...@@ -210,5 +209,3 @@ Tensor to_dense_impl(const Tensor& x) {
} // namespace sparse } // namespace sparse
} // namespace experimental } // namespace experimental
} // namespace paddle } // namespace paddle
PD_REGISTER_API(SparseApi);
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/phi/api/include/tensor.h" #include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/api/lib/api_registry.h"
#include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/utils/allocator.h" #include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/api/lib/utils/storage.h" #include "paddle/phi/api/lib/utils/storage.h"
......
...@@ -137,7 +137,6 @@ def source_include(header_file_path): ...@@ -137,7 +137,6 @@ def source_include(header_file_path):
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/phi/api/lib/api_custom_impl.h" #include "paddle/phi/api/lib/api_custom_impl.h"
#include "paddle/phi/api/lib/api_registry.h"
#include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/kernel_dispatch.h"
...@@ -153,12 +152,6 @@ def source_include(header_file_path): ...@@ -153,12 +152,6 @@ def source_include(header_file_path):
""" """
def api_register():
return """
PD_REGISTER_API(Math);
"""
def api_namespace(): def api_namespace():
return (""" return ("""
namespace paddle { namespace paddle {
...@@ -171,15 +164,12 @@ namespace experimental { ...@@ -171,15 +164,12 @@ namespace experimental {
""") """)
def generate_api(api_yaml_path, header_file_path, source_file_path, def generate_api(api_yaml_path, header_file_path, source_file_path):
dygraph_header_file_path, dygraph_source_file_path):
with open(api_yaml_path, 'r') as f: with open(api_yaml_path, 'r') as f:
apis = yaml.load(f, Loader=yaml.FullLoader) apis = yaml.load(f, Loader=yaml.FullLoader)
header_file = open(header_file_path, 'w') header_file = open(header_file_path, 'w')
source_file = open(source_file_path, 'w') source_file = open(source_file_path, 'w')
dygraph_header_file = open(dygraph_header_file_path, 'w')
dygraph_source_file = open(dygraph_source_file_path, 'w')
namespace = api_namespace() namespace = api_namespace()
...@@ -191,41 +181,20 @@ def generate_api(api_yaml_path, header_file_path, source_file_path, ...@@ -191,41 +181,20 @@ def generate_api(api_yaml_path, header_file_path, source_file_path,
source_file.write(source_include(include_header_file)) source_file.write(source_include(include_header_file))
source_file.write(namespace[0]) source_file.write(namespace[0])
dygraph_header_file.write("#pragma once\n")
dygraph_header_file.write(header_include())
dygraph_header_file.write(namespace[0])
dygraph_include_header_file = "paddle/phi/api/lib/dygraph_api.h"
dygraph_source_file.write(source_include(dygraph_include_header_file))
dygraph_source_file.write(namespace[0])
for api in apis: for api in apis:
foward_api = ForwardAPI(api) foward_api = ForwardAPI(api)
if foward_api.is_dygraph_api: if foward_api.is_dygraph_api:
dygraph_header_file.write(foward_api.gene_api_declaration())
dygraph_source_file.write(foward_api.gene_api_code())
foward_api.is_dygraph_api = False foward_api.is_dygraph_api = False
header_file.write(foward_api.gene_api_declaration())
source_file.write(foward_api.gene_api_code())
else:
header_file.write(foward_api.gene_api_declaration()) header_file.write(foward_api.gene_api_declaration())
source_file.write(foward_api.gene_api_code()) source_file.write(foward_api.gene_api_code())
header_file.write(namespace[1]) header_file.write(namespace[1])
source_file.write(namespace[1]) source_file.write(namespace[1])
dygraph_header_file.write(namespace[1])
dygraph_source_file.write(namespace[1])
source_file.write(api_register())
header_file.close() header_file.close()
source_file.close() source_file.close()
dygraph_header_file.close()
dygraph_source_file.close()
def main(): def main():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
...@@ -245,26 +214,13 @@ def main(): ...@@ -245,26 +214,13 @@ def main():
help='output of generated api source code file', help='output of generated api source code file',
default='paddle/phi/api/lib/api.cc') default='paddle/phi/api/lib/api.cc')
parser.add_argument(
'--dygraph_api_header_path',
help='output of generated dygraph api header code file',
default='paddle/phi/api/lib/dygraph_api.h')
parser.add_argument(
'--dygraph_api_source_path',
help='output of generated dygraph api source code file',
default='paddle/phi/api/lib/dygraph_api.cc')
options = parser.parse_args() options = parser.parse_args()
api_yaml_path = options.api_yaml_path api_yaml_path = options.api_yaml_path
header_file_path = options.api_header_path header_file_path = options.api_header_path
source_file_path = options.api_source_path source_file_path = options.api_source_path
dygraph_header_file_path = options.dygraph_api_header_path
dygraph_source_file_path = options.dygraph_api_source_path
generate_api(api_yaml_path, header_file_path, source_file_path, generate_api(api_yaml_path, header_file_path, source_file_path)
dygraph_header_file_path, dygraph_source_file_path)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -146,7 +146,6 @@ def source_include(header_file_path): ...@@ -146,7 +146,6 @@ def source_include(header_file_path):
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/phi/api/lib/api_custom_impl.h" #include "paddle/phi/api/lib/api_custom_impl.h"
#include "paddle/phi/api/lib/api_registry.h"
#include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/kernel_dispatch.h"
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import yaml
import argparse
import re
from api_gen import ForwardAPI
from sparse_api_gen import SparseAPI
def header_include():
return """
#include <tuple>
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/common/scalar_array.h"
#include "paddle/utils/optional.h"
"""
def source_include(header_file_path):
return f"""#include "{header_file_path}"
#include <memory>
#include "glog/logging.h"
#include "paddle/phi/api/lib/api_custom_impl.h"
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/sparse_api_custom_impl.h"
#include "paddle/phi/api/lib/utils/storage.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
"""
def api_namespace():
return ("""
namespace paddle {
namespace experimental {
""", """
} // namespace experimental
} // namespace paddle
""")
def sparse_namespace():
return ("""
namespace sparse {
""", """
} // namespace sparse
""")
def generate_intermediate_api(api_yaml_path, sparse_api_yaml_path,
dygraph_header_file_path,
dygraph_source_file_path):
dygraph_header_file = open(dygraph_header_file_path, 'w')
dygraph_source_file = open(dygraph_source_file_path, 'w')
namespace = api_namespace()
sparse_namespace_pair = sparse_namespace()
dygraph_header_file.write("#pragma once\n")
dygraph_header_file.write(header_include())
dygraph_header_file.write(namespace[0])
dygraph_include_header_file = "paddle/phi/api/lib/dygraph_api.h"
dygraph_source_file.write(source_include(dygraph_include_header_file))
dygraph_source_file.write(namespace[0])
with open(api_yaml_path, 'r') as f:
apis = yaml.load(f, Loader=yaml.FullLoader)
for api in apis:
foward_api = ForwardAPI(api)
if foward_api.is_dygraph_api:
dygraph_header_file.write(foward_api.gene_api_declaration())
dygraph_source_file.write(foward_api.gene_api_code())
dygraph_header_file.write(sparse_namespace_pair[0])
dygraph_source_file.write(sparse_namespace_pair[0])
with open(sparse_api_yaml_path, 'r') as f:
sparse_apis = yaml.load(f, Loader=yaml.FullLoader)
for api in sparse_apis:
sparse_api = SparseAPI(api)
if sparse_api.is_dygraph_api:
print(sparse_api.api)
dygraph_header_file.write(sparse_api.gene_api_declaration())
dygraph_source_file.write(sparse_api.gene_api_code())
dygraph_header_file.write(sparse_namespace_pair[1])
dygraph_header_file.write(namespace[1])
dygraph_source_file.write(sparse_namespace_pair[1])
dygraph_source_file.write(namespace[1])
dygraph_header_file.close()
dygraph_source_file.close()
def main():
parser = argparse.ArgumentParser(
description='Generate PaddlePaddle C++ Sparse API files')
parser.add_argument(
'--api_yaml_path',
help='path to api yaml file',
default='python/paddle/utils/code_gen/api.yaml')
parser.add_argument(
'--sparse_api_yaml_path',
help='path to sparse api yaml file',
default='python/paddle/utils/code_gen/sparse_api.yaml')
parser.add_argument(
'--dygraph_api_header_path',
help='output of generated dygraph api header code file',
default='paddle/phi/api/lib/dygraph_api.h')
parser.add_argument(
'--dygraph_api_source_path',
help='output of generated dygraph api source code file',
default='paddle/phi/api/lib/dygraph_api.cc')
options = parser.parse_args()
api_yaml_path = options.api_yaml_path
sparse_api_yaml_path = options.sparse_api_yaml_path
dygraph_header_file_path = options.dygraph_api_header_path
dygraph_source_file_path = options.dygraph_api_source_path
generate_intermediate_api(api_yaml_path, sparse_api_yaml_path,
dygraph_header_file_path,
dygraph_source_file_path)
if __name__ == '__main__':
main()
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
kernel : kernel :
func : sparse_conv3d func : sparse_conv3d
layout : x layout : x
# intermediate : rulebook
backward : conv3d_grad backward : conv3d_grad
- api : to_dense - api : to_dense
......
...@@ -24,9 +24,6 @@ class SparseAPI(ForwardAPI): ...@@ -24,9 +24,6 @@ class SparseAPI(ForwardAPI):
def __init__(self, api_item_yaml): def __init__(self, api_item_yaml):
super(SparseAPI, self).__init__(api_item_yaml) super(SparseAPI, self).__init__(api_item_yaml)
def get_api_func_name(self):
return self.api
def gene_api_declaration(self): def gene_api_declaration(self):
return f""" return f"""
// {", ".join(self.outputs['names'])} // {", ".join(self.outputs['names'])}
...@@ -182,7 +179,6 @@ def source_include(header_file_path): ...@@ -182,7 +179,6 @@ def source_include(header_file_path):
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/phi/api/lib/api_registry.h"
#include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/data_transform.h"
#include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/kernel_dispatch.h"
...@@ -191,10 +187,6 @@ def source_include(header_file_path): ...@@ -191,10 +187,6 @@ def source_include(header_file_path):
""" """
def api_register():
return ""
def api_namespace(): def api_namespace():
return (""" return ("""
namespace paddle { namespace paddle {
...@@ -228,14 +220,14 @@ def generate_api(api_yaml_path, header_file_path, source_file_path): ...@@ -228,14 +220,14 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
for api in apis: for api in apis:
sparse_api = SparseAPI(api) sparse_api = SparseAPI(api)
if sparse_api.is_dygraph_api:
sparse_api.is_dygraph_api = False
header_file.write(sparse_api.gene_api_declaration()) header_file.write(sparse_api.gene_api_declaration())
source_file.write(sparse_api.gene_api_code()) source_file.write(sparse_api.gene_api_code())
header_file.write(namespace[1]) header_file.write(namespace[1])
source_file.write(namespace[1]) source_file.write(namespace[1])
source_file.write(api_register())
header_file.close() header_file.close()
source_file.close() source_file.close()
......
...@@ -106,7 +106,6 @@ def source_include(header_file_path): ...@@ -106,7 +106,6 @@ def source_include(header_file_path):
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/phi/api/lib/api_registry.h"
#include "paddle/phi/api/lib/api_gen_utils.h" #include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/kernel_dispatch.h"
#include "paddle/phi/api/lib/sparse_api_custom_impl.h" #include "paddle/phi/api/lib/sparse_api_custom_impl.h"
...@@ -114,10 +113,6 @@ def source_include(header_file_path): ...@@ -114,10 +113,6 @@ def source_include(header_file_path):
""" """
def api_register():
return ""
def api_namespace(): def api_namespace():
return (""" return ("""
namespace paddle { namespace paddle {
...@@ -157,8 +152,6 @@ def generate_api(api_yaml_path, header_file_path, source_file_path): ...@@ -157,8 +152,6 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
header_file.write(namespace[1]) header_file.write(namespace[1])
source_file.write(namespace[1]) source_file.write(namespace[1])
source_file.write(api_register())
header_file.close() header_file.close()
source_file.close() source_file.close()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册