提交 e50205e7 编写于 作者: S sabreshao

CMake refine for HIP support.

1. Add option WITH_AMD_GPU.
2. Add cmake/hip.cmake for HIP toolchain.
3. Some external module such as eigen may need HIP port.
4. Add macro hip_library/hip_binary/hip_test to cmake/generic.cmake.
5. Add one HIP source concat.hip.cu as an example. Each .cu may have its corresponding .hip.cu.
上级 45c988d8
......@@ -70,9 +70,6 @@ if(NOT CMAKE_BUILD_TYPE)
FORCE)
endif()
if(WITH_AMD_GPU)
endif()
if(ANDROID OR IOS)
if(ANDROID)
if(${CMAKE_SYSTEM_VERSION} VERSION_LESS "16")
......
INCLUDE(ExternalProject)
SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3)
INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3)
SET(EIGEN_INCLUDE_DIR ${EIGEN_SOURCE_DIR}/src/extern_eigen3)
INCLUDE_DIRECTORIES(${EIGEN_INCLUDE_DIR})
if(WITH_AMD_GPU)
ExternalProject_Add(
......
......@@ -27,9 +27,6 @@ endif(WITH_TESTING)
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_DEBUG})
elseif(CMAKE_BUILD_TYPE STREQUAL "Release")
# Disable optimization since one eigen symbol will be removed in math_function.cu
#list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_RELEASE})
elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo")
list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_RELWITHDEBINFO})
elseif(CMAKE_BUILD_TYPE STREQUAL "MinSizeRel")
......
......@@ -12,6 +12,8 @@ function(op_library TARGET)
set(OP_LIBRARY ${TARGET} ${OP_LIBRARY} PARENT_SCOPE)
set(cc_srcs)
set(cu_srcs)
set(hip_cu_srcs)
set(miopen_hip_cc_srcs)
set(cu_cc_srcs)
set(cudnn_cu_cc_srcs)
set(CUDNN_FILE)
......@@ -36,10 +38,19 @@ function(op_library TARGET)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
list(APPEND cu_srcs ${TARGET}.cu)
endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.hip.cu)
list(APPEND hip_cu_srcs ${TARGET}.hip.cu)
endif()
string(REPLACE "_op" "_cudnn_op" CUDNN_FILE "${TARGET}")
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${CUDNN_FILE}.cu.cc)
list(APPEND cudnn_cu_cc_srcs ${CUDNN_FILE}.cu.cc)
endif()
if(WITH_AMD_GPU)
string(REPLACE "_op" "_miopen_op" MIOPEN_FILE "${TARGET}")
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MIOPEN_FILE}.hip.cc)
list(APPEND miopen_hip_cc_srcs ${MIOPEN_FILE}.hip.cc)
endif()
endif()
if(WITH_MKLDNN)
string(REPLACE "_op" "_mkldnn_op" MKLDNN_FILE "${TARGET}")
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${MKLDNN_FILE}.cc)
......@@ -48,10 +59,14 @@ function(op_library TARGET)
endif()
else()
foreach(src ${op_library_SRCS})
if (${src} MATCHES ".*\\.cu$")
if (${src} MATCHES ".*\\.hip.cu$")
list(APPEND hip_cu_srcs ${src})
elseif (${src} MATCHES ".*\\.cu$")
list(APPEND cu_srcs ${src})
elseif(${src} MATCHES ".*_cudnn_op.cu.cc$")
list(APPEND cudnn_cu_cc_srcs ${src})
elseif(WITH_AMD_GPU AND ${src} MATCHES ".*_miopen_op.hip.cc$")
list(APPEND miopen_hip_cc_srcs ${src})
elseif(WITH_MKLDNN AND ${src} MATCHES ".*_mkldnn_op.cc$")
list(APPEND mkldnn_cc_srcs ${src})
elseif(${src} MATCHES ".*\\.cu.cc$")
......@@ -77,8 +92,8 @@ function(op_library TARGET)
nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
elseif (WITH_AMD_GPU)
hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS
${op_library_DEPS} ${op_common_deps})
hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cu_srcs} ${miopen_hip_cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
else()
cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
......@@ -91,7 +106,7 @@ function(op_library TARGET)
endif()
endforeach()
# The registration of USE_OP, please refer to paddle/framework/op_registry.h.
# The registration of USE_OP, please refer to paddle/fluid/framework/op_registry.h.
# Note that it's enough to just adding one operator to pybind in a *_op.cc file.
# And for detail pybind information, please see generated paddle/pybind/pybind.h.
file(READ ${TARGET}.cc TARGET_CONTENT)
......@@ -117,7 +132,10 @@ function(op_library TARGET)
list(LENGTH cu_srcs cu_srcs_len)
list(LENGTH cu_cc_srcs cu_cc_srcs_len)
list(LENGTH mkldnn_cc_srcs mkldnn_cc_srcs_len)
if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0)
list(LENGTH hip_cu_srcs hip_cu_srcs_len)
list(LENGTH miopen_hip_cc_srcs miopen_hip_cc_srcs_len)
if (${pybind_flag} EQUAL 0 AND ${mkldnn_cc_srcs_len} EQUAL 0 AND ${cu_srcs_len} EQUAL 0 AND ${cu_cc_srcs_len} EQUAL 0 AND
${hip_cu_srcs_len} EQUAL 0 AND ${miopen_hip_cc_srcs_len} EQUAL 0)
file(APPEND ${pybind_file} "USE_CPU_ONLY_OP(${TARGET});\n")
set(pybind_flag 1)
endif()
......@@ -128,6 +146,11 @@ function(op_library TARGET)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, CUDNN);\n")
endif()
# pybind USE_OP_DEVICE_KERNEL for MIOPEN
if (WITH_AMD_GPU AND ${miopen_hip_cc_srcs_len} GREATER 0)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MIOPEN);\n")
endif()
# pybind USE_OP_DEVICE_KERNEL for MKLDNN
if (WITH_MKLDNN AND ${mkldnn_cc_srcs_len} GREATER 0)
file(APPEND ${pybind_file} "USE_OP_DEVICE_KERNEL(${TARGET}, MKLDNN);\n")
......
......@@ -12,270 +12,4 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hip/hip_runtime.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/cuda_helper.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
__device__ T upper_bound(const T* first, T count, T val) {
const T* orig = first;
const T* it = nullptr;
T step = 0;
while (count > 0) {
it = first;
step = count / 2;
it += step;
if (!(val < *it)) {
first = ++it;
count -= step + 1;
} else {
count = step;
}
}
return first - orig;
}
template <typename T>
__global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
const int output_rows, const int output_cols,
T* output) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int segment = upper_bound<int>(input_cols, col_size, tid_x) - 1;
int curr_offset = input_cols[segment];
int curr_segment = segment;
for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
T curr_col_offset;
while ((curr_col_offset = input_cols[curr_segment + 1]) <= tid_x) {
curr_offset = curr_col_offset;
++curr_segment;
}
int local_col = tid_x - curr_offset;
int segment_width = curr_col_offset - curr_offset;
T* input_ptr = inputs[curr_segment];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
output[tid_y * output_cols + tid_x] =
input_ptr[tid_y * segment_width + local_col];
}
}
template <typename T>
__global__ void KernelConcat(T** inputs, const int input_col,
const int output_rows, const int output_cols,
T* output) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
double inv_input_col = 1.0 / input_col;
for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
int split = tid_x * inv_input_col;
int in_offset = tid_x - split * input_col;
T* input_ptr = inputs[split];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) {
output[tid_y * output_cols + tid_x] =
input_ptr[tid_y * input_col + in_offset];
}
}
}
template <typename T>
__global__ void KernelConcatGrad(const T* input, const int input_row,
const int input_col, const int* output_cols,
int col_size, T** outputs) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int segment = upper_bound<int>(output_cols, col_size, tid_x) - 1;
int curr_offset = output_cols[segment];
int curr_segment = segment;
for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
T curr_col_offset;
while ((curr_col_offset = output_cols[curr_segment + 1]) <= tid_x) {
curr_offset = curr_col_offset;
++curr_segment;
}
int local_col = tid_x - curr_offset;
int segment_width = curr_col_offset - curr_offset;
T* output_ptr = outputs[curr_segment];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * segment_width + local_col] =
input[tid_y * input_col + tid_x];
}
}
template <typename T>
__global__ void KernelConcatGrad(const T* input, const int input_row,
const int input_col, const int output_cols,
T** outputs) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
double inv_input_col = 1.0 / input_col;
for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
int split = tid_x * inv_input_col;
int in_offset = tid_x - split * input_col;
T* output_ptr = outputs[split];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * output_cols + in_offset] =
input[tid_y * input_col + tid_x];
}
}
/*
* All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
*/
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const std::vector<framework::Tensor>& input, const int axis,
framework::Tensor* output) {
// TODO(zcd): Add input data validity checking
int num = input.size();
int rows = 1;
auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int cols = input[0].numel() / rows;
int out_rows = rows, out_cols = 0;
framework::Vector<int16_t> inputs_data(num * sizeof(T*) / 2);
framework::Vector<int> inputs_cols(num + 1);
inputs_cols[0] = 0;
T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());
bool sameShape = true;
for (int i = 0; i < num; ++i) {
int t_cols = input[i].numel() / rows;
if (sameShape) {
if (t_cols != cols) sameShape = false;
}
out_cols += t_cols;
inputs_cols[i + 1] = out_cols;
inputs_ptr[i] = const_cast<T*>(input[i].data<T>());
}
T** ins_gpu =
reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace()));
const int* ins_col_gpu = inputs_cols.CUDAData(context.GetPlace());
// computation
// set the thread block and grid according to CurrentDeviceId
const int kThreadsPerBlock = 1024;
int block_cols = kThreadsPerBlock;
if (out_cols < kThreadsPerBlock) { // block_cols is aligned by 32.
block_cols = ((out_cols + 31) >> 5) << 5;
}
int block_rows = kThreadsPerBlock / block_cols;
dim3 block_size = dim3(block_cols, block_rows, 1);
int max_threads = context.GetMaxPhysicalThreadCount();
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
int grid_cols =
std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
int grid_rows =
std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));
dim3 grid_size = dim3(grid_cols, grid_rows, 1);
if (sameShape) {
hipLaunchKernelGGL((KernelConcat<T>), dim3(grid_size), dim3(block_size), 0, context.stream(),
ins_gpu, cols, out_rows, out_cols, output->data<T>());
} else {
hipLaunchKernelGGL((KernelConcat<T>), dim3(grid_size), dim3(block_size), 0, context.stream(),
ins_gpu, ins_col_gpu, static_cast<int>(inputs_cols.size()), out_rows,
out_cols, output->data<T>());
}
}
};
/*
* All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
*/
template <typename T>
class ConcatGradFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, const int axis,
std::vector<framework::Tensor>& outputs) {
// TODO(zcd): Add input data validity checking
int num = outputs.size();
int input_row = 1;
auto dim_0 = outputs[0].dims();
for (int i = 0; i < axis; ++i) {
input_row *= dim_0[i];
}
int output_col_0 = outputs[0].numel() / input_row;
int input_col = 0;
bool sameShape = true;
framework::Vector<int16_t> outputs_data(num * sizeof(T*) / 2);
framework::Vector<int> outputs_cols(num + 1);
outputs_cols[0] = 0;
T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());
for (int i = 0; i < num; ++i) {
int t_col = outputs[i].numel() / input_row;
if (sameShape) {
if (t_col != output_col_0) sameShape = false;
}
input_col += t_col;
outputs_cols[i + 1] = input_col;
outputs_ptr[i] = outputs[i].data<T>();
}
T** outs_gpu =
reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace()));
const int* outs_col_gpu = outputs_cols.CUDAData(context.GetPlace());
// computation
const int kThreadsPerBlock = 1024;
int block_cols = kThreadsPerBlock;
if (input_col < kThreadsPerBlock) { // block_cols is aligned by 32.
block_cols = ((input_col + 31) >> 5) << 5;
}
int block_rows = kThreadsPerBlock / block_cols;
dim3 block_size = dim3(block_cols, block_rows, 1);
int max_threads = context.GetMaxPhysicalThreadCount();
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
int grid_cols =
std::min((input_col + block_cols - 1) / block_cols, max_blocks);
int grid_rows =
std::min(max_blocks / grid_cols, std::max(input_row / block_rows, 1));
dim3 grid_size = dim3(grid_cols, grid_rows, 1);
if (sameShape) {
hipLaunchKernelGGL((KernelConcatGrad<T>), dim3(grid_size), dim3(block_size), 0, context.stream(),
input.data<T>(), input_row, input_col, output_col_0, outs_gpu);
} else {
hipLaunchKernelGGL((KernelConcatGrad<T>), dim3(grid_size), dim3(block_size), 0, context.stream(),
input.data<T>(), input_row, input_col, outs_col_gpu,
static_cast<int>(outputs_cols.size()), outs_gpu);
}
}
};
template class ConcatFunctor<platform::CUDADeviceContext, int>;
template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatFunctor<platform::CUDADeviceContext, float>;
template class ConcatFunctor<platform::CUDADeviceContext, double>;
template class ConcatGradFunctor<platform::CUDADeviceContext, int>;
template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
template class ConcatGradFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
#include <hip/hip_runtime.h>
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册