未验证 提交 f5476dad 编写于 作者: U umiswing 提交者: GitHub

Use copy_if_different to avoid recompilation of generated cutlass kernels (#53531)
上级 d91d758d
...@@ -42,20 +42,33 @@ ExternalProject_Add( ...@@ -42,20 +42,33 @@ ExternalProject_Add(
INSTALL_COMMAND "" INSTALL_COMMAND ""
TEST_COMMAND "") TEST_COMMAND "")
# Final in-tree locations of the two cutlass generator headers.
set(gemm_operations_file
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/generated/gemm/all_gemm_operations.h
)
set(configurations_file
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/generated/gemm/configurations.h
)
# Same paths with a ".tmp" suffix; the cutlass_codegen target copies these
# onto the final paths with `cmake -E copy_if_different`.
set(tmp_gemm_operations_file ${gemm_operations_file}.tmp)
set(tmp_configurations_file ${configurations_file}.tmp)
add_custom_target( add_custom_target(
cutlass_codegen cutlass_codegen
COMMAND
rm -rf
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/build
COMMAND
mkdir -p
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/build/generated/gemm
COMMAND COMMAND
${PYTHON_EXECUTABLE} -B ${PYTHON_EXECUTABLE} -B
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_generator.py ${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_generator.py
"${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/tools/library/scripts/" "${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/tools/library/scripts/"
"${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/build" "${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator"
"${CMAKE_CUDA_COMPILER_VERSION}" "${CMAKE_CUDA_COMPILER_VERSION}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_gemm_operations_file}
${gemm_operations_file}
COMMAND
${CMAKE_COMMAND} -E echo
"copy_if_different ${tmp_gemm_operations_file} to ${gemm_operations_file}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_configurations_file}
${configurations_file}
COMMAND
${CMAKE_COMMAND} -E echo
"copy_if_different ${tmp_configurations_file} to ${configurations_file}"
VERBATIM) VERBATIM)
add_library(cutlass INTERFACE) add_library(cutlass INTERFACE)
......
...@@ -29,6 +29,13 @@ class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary): ...@@ -29,6 +29,13 @@ class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary):
OperationKind.Gemm: EmitGatherGemmScatterConfigurationLibrary OperationKind.Gemm: EmitGatherGemmScatterConfigurationLibrary
} }
self.header_template = "#pragma once\n#ifdef PADDLE_WITH_CUTLASS\n#include \"paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h\"\n" self.header_template = "#pragma once\n#ifdef PADDLE_WITH_CUTLASS\n#include \"paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h\"\n"
self.configuration_header_template = """
/*
Generated by gemm_operation.py - Do not edit.
*/
#pragma once
#ifdef PADDLE_WITH_CUTLASS
"""
self.entry_template = "" self.entry_template = ""
self.configuration_prototype_template = "" self.configuration_prototype_template = ""
self.configuration_template = "" self.configuration_template = ""
...@@ -58,11 +65,23 @@ namespace sparse { ...@@ -58,11 +65,23 @@ namespace sparse {
self.top_level_path = os.path.join( self.top_level_path = os.path.join(
self.operation_path, self.operation_path,
"all_%s_operations.h" % OperationKindNames[self.kind], "all_%s_operations.h.tmp" % OperationKindNames[self.kind],
) )
self.configuration_path = os.path.join(
self.operation_path, "configurations.h.tmp"
).replace('\\', '/')
self.top_level_file = open(self.top_level_path, "w") self.top_level_file = open(self.top_level_path, "w")
self.top_level_file.write(self.header_template) self.top_level_file.write(self.header_template)
self.top_level_file.write(
'#include "' + self.operation_path + '/' + 'configurations.h"\n'
)
self.configuration_file = open(self.configuration_path, "w")
self.configuration_file.write(self.configuration_header_template)
self.configuration_file.write(self.namespace_template)
self.configuration_file.close()
self.source_files = [ self.source_files = [
self.top_level_path, self.top_level_path,
...@@ -105,14 +124,6 @@ launchKernel<""" ...@@ -105,14 +124,6 @@ launchKernel<"""
+ "<>>," + "<>>,"
) )
self.top_level_file.write(
'#include "'
+ self.operation_path
+ '/'
+ configuration_name
+ '.h"\n'
)
def __exit__(self, exception_type, exception_value, traceback): def __exit__(self, exception_type, exception_value, traceback):
self.top_level_file.write( self.top_level_file.write(
SubstituteTemplate( SubstituteTemplate(
...@@ -137,6 +148,10 @@ launchKernel<""" ...@@ -137,6 +148,10 @@ launchKernel<"""
self.top_level_file.write(self.epilogue_template) self.top_level_file.write(self.epilogue_template)
self.top_level_file.close() self.top_level_file.close()
self.configuration_file = open(self.configuration_path, "a")
self.configuration_file.write(self.epilogue_template)
self.configuration_file.close()
class GatherGemmScatterManifest(Manifest): class GatherGemmScatterManifest(Manifest):
def emit(self, target=GeneratorTarget.Library): def emit(self, target=GeneratorTarget.Library):
......
...@@ -207,7 +207,7 @@ class EmitGatherGemmScatterConfigurationLibrary(EmitGemmConfigurationLibrary): ...@@ -207,7 +207,7 @@ class EmitGatherGemmScatterConfigurationLibrary(EmitGemmConfigurationLibrary):
def __init__(self, operation_path, configuration_name): def __init__(self, operation_path, configuration_name):
self.configuration_name = configuration_name self.configuration_name = configuration_name
self.configuration_path = os.path.join( self.configuration_path = os.path.join(
operation_path, "%s.h" % configuration_name operation_path, "configurations.h.tmp"
).replace('\\', '/') ).replace('\\', '/')
self.instance_emitter = { self.instance_emitter = {
...@@ -246,9 +246,7 @@ namespace sparse { ...@@ -246,9 +246,7 @@ namespace sparse {
""" """
def __enter__(self): def __enter__(self):
self.configuration_file = open(self.configuration_path, "w") self.configuration_file = open(self.configuration_path, "a")
self.configuration_file.write(self.header_template)
self.configuration_file.write(self.separator)
self.includes = collections.OrderedDict([]) self.includes = collections.OrderedDict([])
self.instance_definitions = [] self.instance_definitions = []
...@@ -259,22 +257,10 @@ namespace sparse { ...@@ -259,22 +257,10 @@ namespace sparse {
def __exit__(self, exception_type, exception_value, traceback): def __exit__(self, exception_type, exception_value, traceback):
# Write includes
for incl, _ in self.includes.items():
include_statement = "#include \"%s\"\n" % incl
self.configuration_file.write(include_statement)
self.configuration_file.write(self.separator)
self.configuration_file.write(self.namespace_template)
# Write instance definitions in top-level namespace # Write instance definitions in top-level namespace
for instance_definition in self.instance_definitions: for instance_definition in self.instance_definitions:
self.configuration_file.write(instance_definition) self.configuration_file.write(instance_definition)
for instance_wrapper in self.instance_wrappers:
self.configuration_file.write(instance_wrapper)
self.configuration_file.write(self.epilogue_template)
self.configuration_file.close() self.configuration_file.close()
......
...@@ -18,7 +18,7 @@ ...@@ -18,7 +18,7 @@
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/autotune/auto_tune_base.h" #include "paddle/phi/kernels/autotune/auto_tune_base.h"
#include "paddle/phi/kernels/sparse/gpu/cutlass_generator/build/generated/gemm/all_gemm_operations.h" #include "paddle/phi/kernels/sparse/gpu/cutlass_generator/generated/gemm/all_gemm_operations.h"
namespace phi { namespace phi {
namespace sparse { namespace sparse {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册