未验证 提交 e7a38f15 编写于 作者: U umiswing 提交者: GitHub

Add macro SPCONV_WITH_CUTLASS (#54274)

上级 0b1086b9
......@@ -25,6 +25,7 @@ include_directories(
"${THIRD_PARTY_PATH}/cutlass/src/extern_cutlass/tools/util/include/")
add_definitions("-DPADDLE_WITH_CUTLASS")
+add_definitions("-DSPCONV_WITH_CUTLASS=0")
if(NOT PYTHON_EXECUTABLE)
find_package(PythonInterp REQUIRED)
......
......@@ -24,7 +24,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/sparse/gpu/conv.cu.h"
-#ifdef PADDLE_WITH_CUTLASS
+#if defined(PADDLE_WITH_CUTLASS) && SPCONV_WITH_CUTLASS
#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"
#endif
......@@ -134,7 +134,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
phi::backends::gpu::GpuMemsetAsync(
out_index_ptr, 0, sizeof(int) * x.nnz() * 2, dev_ctx.stream());
-#ifdef PADDLE_WITH_CUTLASS
+#if defined(PADDLE_WITH_CUTLASS) && SPCONV_WITH_CUTLASS
bool cutlass = true;
if (dev_ctx.GetComputeCapability() < 80) cutlass = false;
......@@ -177,7 +177,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
out_channels,
out_grad_features_ptr);
-#ifdef PADDLE_WITH_CUTLASS
+#if defined(PADDLE_WITH_CUTLASS) && SPCONV_WITH_CUTLASS
}
#endif
const T* kernel_ptr = kernel.data<T>();
......@@ -195,7 +195,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
T* tmp_d_x_ptr = d_x_features_ptr + offsets[i] * in_channels;
T* tmp_d_kernel_ptr = d_kernel_ptr + i * in_channels * out_channels;
-#ifdef PADDLE_WITH_CUTLASS
+#if defined(PADDLE_WITH_CUTLASS) && SPCONV_WITH_CUTLASS
if (cutlass) {
const IntT* gather_x_indices = rulebook_ptr + offsets[i];
const IntT* scatter_x_indices = rulebook_ptr + offsets[i];
......@@ -266,13 +266,13 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
tmp_kernel_ptr,
static_cast<T>(0),
tmp_d_x_ptr);
-#ifdef PADDLE_WITH_CUTLASS
+#if defined(PADDLE_WITH_CUTLASS) && SPCONV_WITH_CUTLASS
}
#endif
}
// 4. scatter
-#ifdef PADDLE_WITH_CUTLASS
+#if defined(PADDLE_WITH_CUTLASS) && SPCONV_WITH_CUTLASS
if (!cutlass) {
#endif
phi::funcs::sparse::ScatterV2<T>(dev_ctx,
......@@ -284,7 +284,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
in_channels,
2,
x_grad_values_ptr);
-#ifdef PADDLE_WITH_CUTLASS
+#if defined(PADDLE_WITH_CUTLASS) && SPCONV_WITH_CUTLASS
}
#endif
}
......
......@@ -23,7 +23,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/scatter.cu.h"
#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h"
#include "paddle/phi/kernels/sparse/gpu/conv.cu.h"
-#ifdef PADDLE_WITH_CUTLASS
+#if defined(PADDLE_WITH_CUTLASS) && SPCONV_WITH_CUTLASS
#include "paddle/phi/kernels/sparse/gpu/gather_gemm_scatter.h"
#endif
......@@ -159,7 +159,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
dev_ctx, x, key, tmp_rulebook, h_counter, out, rulebook, counter);
}
-#ifdef PADDLE_WITH_CUTLASS
+#if defined(PADDLE_WITH_CUTLASS) && SPCONV_WITH_CUTLASS
bool mixed_precision = dev_ctx.GetComputeCapability() >= 75 &&
dev_ctx.GetComputeCapability() < 80 &&
std::is_same<T, float>::value;
......@@ -273,7 +273,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
out_channels,
1,
out_values_ptr);
-#ifdef PADDLE_WITH_CUTLASS
+#if defined(PADDLE_WITH_CUTLASS) && SPCONV_WITH_CUTLASS
}
#endif
}
......
......@@ -14,7 +14,7 @@
#pragma once
-#ifdef PADDLE_WITH_CUTLASS
+#if defined(PADDLE_WITH_CUTLASS) && SPCONV_WITH_CUTLASS
#include "cutlass/arch/mma.h"
#include "cutlass/device_kernel.h"
#include "cutlass/epilogue/thread/linear_combination.h"
......
......@@ -28,13 +28,13 @@ class GatherGemmScatterEmitOperationKindLibrary(EmitOperationKindLibrary):
self.emitters = {
OperationKind.Gemm: EmitGatherGemmScatterConfigurationLibrary
}
-self.header_template = "#pragma once\n#ifdef PADDLE_WITH_CUTLASS\n#include \"paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h\"\n"
+self.header_template = "#pragma once\n#if defined(PADDLE_WITH_CUTLASS) && SPCONV_WITH_CUTLASS\n#include \"paddle/phi/kernels/sparse/gpu/cutlass_generator/common.h\"\n"
self.configuration_header_template = """
/*
Generated by gemm_operation.py - Do not edit.
*/
#pragma once
-#ifdef PADDLE_WITH_CUTLASS
+#if defined(PADDLE_WITH_CUTLASS) && SPCONV_WITH_CUTLASS
"""
self.entry_template = ""
self.configuration_prototype_template = ""
......
......@@ -232,7 +232,7 @@ class EmitGatherGemmScatterConfigurationLibrary(EmitGemmConfigurationLibrary):
Generated by gemm_operation.py - Do not edit.
*/
#pragma once
-#ifdef PADDLE_WITH_CUTLASS
+#if defined(PADDLE_WITH_CUTLASS) && SPCONV_WITH_CUTLASS
"""
self.namespace_template = """
......
......@@ -14,7 +14,7 @@
#pragma once
#include <type_traits>
-#ifdef PADDLE_WITH_CUTLASS
+#if defined(PADDLE_WITH_CUTLASS) && SPCONV_WITH_CUTLASS
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/autotune/auto_tune_base.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册