未验证 提交 f5476dad 编写于 作者: U umiswing 提交者: GitHub

Use copy_if_different to avoid recompilation of generated cutlass (#53531)

kernels.
上级 d91d758d
......@@ -42,20 +42,33 @@ ExternalProject_Add(
INSTALL_COMMAND ""
TEST_COMMAND "")
set(gemm_operations_file
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/generated/gemm/all_gemm_operations.h
)
set(configurations_file
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/generated/gemm/configurations.h
)
set(tmp_gemm_operations_file ${gemm_operations_file}.tmp)
set(tmp_configurations_file ${configurations_file}.tmp)
add_custom_target(
cutlass_codegen
COMMAND
rm -rf
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/build
COMMAND
mkdir -p
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/build/generated/gemm
COMMAND
${PYTHON_EXECUTABLE} -B
${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/gather_gemm_scatter_generator.py
"${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/tools/library/scripts/"
"${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator/build"
"${CMAKE_SOURCE_DIR}/paddle/phi/kernels/sparse/gpu/cutlass_generator"
"${CMAKE_CUDA_COMPILER_VERSION}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_gemm_operations_file}
${gemm_operations_file}
COMMAND
${CMAKE_COMMAND} -E echo
"copy_if_different ${tmp_gemm_operations_file} to ${gemm_operations_file}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_configurations_file}
${configurations_file}
COMMAND
${CMAKE_COMMAND} -E echo
"copy_if_different ${tmp_configurations_file} to ${configurations_file}"
VERBATIM)
add_library(cutlass INTERFACE)
......
......@@ -29,6 +29,13 @@ class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary):
OperationKind.Gemm: EmitGatherGemmScatterConfigurationLibrary
}
self.header_template = "#pragma once\n#ifdef PADDLE_WITH_CUTLASS\n#include \"paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h\"\n"
self.configuration_header_template = """
/*
Generated by gemm_operation.py - Do not edit.
*/
#pragma once
#ifdef PADDLE_WITH_CUTLASS
"""
self.entry_template = ""
self.configuration_prototype_template = ""
self.configuration_template = ""
......@@ -58,11 +65,23 @@ namespace sparse {
self.top_level_path = os.path.join(
self.operation_path,
"all_%s_operations.h" % OperationKindNames[self.kind],
"all_%s_operations.h.tmp" % OperationKindNames[self.kind],
)
self.configuration_path = os.path.join(
self.operation_path, "configurations.h.tmp"
).replace('\\', '/')
self.top_level_file = open(self.top_level_path, "w")
self.top_level_file.write(self.header_template)
self.top_level_file.write(
'#include "' + self.operation_path + '/' + 'configurations.h"\n'
)
self.configuration_file = open(self.configuration_path, "w")
self.configuration_file.write(self.configuration_header_template)
self.configuration_file.write(self.namespace_template)
self.configuration_file.close()
self.source_files = [
self.top_level_path,
......@@ -105,14 +124,6 @@ launchKernel<"""
+ "<>>,"
)
self.top_level_file.write(
'#include "'
+ self.operation_path
+ '/'
+ configuration_name
+ '.h"\n'
)
def __exit__(self, exception_type, exception_value, traceback):
self.top_level_file.write(
SubstituteTemplate(
......@@ -137,6 +148,10 @@ launchKernel<"""
self.top_level_file.write(self.epilogue_template)
self.top_level_file.close()
self.configuration_file = open(self.configuration_path, "a")
self.configuration_file.write(self.epilogue_template)
self.configuration_file.close()
class GatherGemmScatterManifest(Manifest):
def emit(self, target=GeneratorTarget.Library):
......
......@@ -207,7 +207,7 @@ class EmitGatherGemmScatterConfigurationLibrary(EmitGemmConfigurationLibrary):
def __init__(self, operation_path, configuration_name):
self.configuration_name = configuration_name
self.configuration_path = os.path.join(
operation_path, "%s.h" % configuration_name
operation_path, "configurations.h.tmp"
).replace('\\', '/')
self.instance_emitter = {
......@@ -246,9 +246,7 @@ namespace sparse {
"""
def __enter__(self):
self.configuration_file = open(self.configuration_path, "w")
self.configuration_file.write(self.header_template)
self.configuration_file.write(self.separator)
self.configuration_file = open(self.configuration_path, "a")
self.includes = collections.OrderedDict([])
self.instance_definitions = []
......@@ -259,22 +257,10 @@ namespace sparse {
def __exit__(self, exception_type, exception_value, traceback):
# Write includes
for incl, _ in self.includes.items():
include_statement = "#include \"%s\"\n" % incl
self.configuration_file.write(include_statement)
self.configuration_file.write(self.separator)
self.configuration_file.write(self.namespace_template)
# Write instance definitions in top-level namespace
for instance_definition in self.instance_definitions:
self.configuration_file.write(instance_definition)
for instance_wrapper in self.instance_wrappers:
self.configuration_file.write(instance_wrapper)
self.configuration_file.write(self.epilogue_template)
self.configuration_file.close()
......
......@@ -18,7 +18,7 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/autotune/auto_tune_base.h"
#include "paddle/phi/kernels/sparse/gpu/cutlass_generator/build/generated/gemm/all_gemm_operations.h"
#include "paddle/phi/kernels/sparse/gpu/cutlass_generator/generated/gemm/all_gemm_operations.h"
namespace phi {
namespace sparse {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册