提交 45c988d8 编写于 作者: S sabreshao

Demostration of cmake refine for HIP support.

1. Add option WITH_AMD_GPU.
2. Add cmake/hip.cmake for HIP toolchain.
3. Some external module such as eigen may need HIP port.
4. Add macro hip_library/hip_binary/hip_test to cmake/generic.cmake.
5. Add one HIP source concat.hip.cu as an example. Each .cu may have its corresponding .hip.cu.
上级 6d4a06f2
......@@ -36,6 +36,7 @@ include(simd)
################################ Configurations #######################################
option(WITH_GPU "Compile PaddlePaddle with NVIDIA GPU" ${CUDA_FOUND})
option(WITH_AMD_GPU "Compile PaddlePaddle with AMD GPU" OFF)
option(WITH_AVX "Compile PaddlePaddle with AVX intrinsics" ${AVX_FOUND})
option(WITH_MKL "Compile PaddlePaddle with MKL support." ${AVX_FOUND})
option(WITH_DSO "Compile PaddlePaddle with dynamic linked CUDA" ON)
......@@ -69,6 +70,9 @@ if(NOT CMAKE_BUILD_TYPE)
FORCE)
endif()
if(WITH_AMD_GPU)
endif()
if(ANDROID OR IOS)
if(ANDROID)
if(${CMAKE_SYSTEM_VERSION} VERSION_LESS "16")
......@@ -180,6 +184,11 @@ if(WITH_GPU)
include(cuda)
endif(WITH_GPU)
if(WITH_AMD_GPU)
find_package(HIP)
include(hip)
endif(WITH_AMD_GPU)
if(WITH_MKLML)
list(APPEND EXTERNAL_LIBS ${MKLML_IOMP_LIB})
endif()
......
......@@ -57,11 +57,7 @@ if(NOT WITH_GOLANG)
add_definitions(-DPADDLE_WITHOUT_GOLANG)
endif(NOT WITH_GOLANG)
if(NOT WITH_GPU)
add_definitions(-DHPPL_STUB_FUNC)
list(APPEND CMAKE_CXX_SOURCE_FILE_EXTENSIONS cu)
else()
if(WITH_GPU)
add_definitions(-DPADDLE_WITH_CUDA)
FIND_PACKAGE(CUDA REQUIRED)
......@@ -84,7 +80,14 @@ else()
# Include cuda and cudnn
include_directories(${CUDNN_INCLUDE_DIR})
include_directories(${CUDA_TOOLKIT_INCLUDE})
endif(NOT WITH_GPU)
elseif(WITH_AMD_GPU)
add_definitions(-DPADDLE_WITH_HIP)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__HIP_PLATFORM_HCC__")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__HIP_PLATFORM_HCC__")
else()
add_definitions(-DHPPL_STUB_FUNC)
list(APPEND CMAKE_CXX_SOURCE_FILE_EXTENSIONS cu)
endif()
if (WITH_MKLML AND MKLML_IOMP_LIB)
message(STATUS "Enable Intel OpenMP with ${MKLML_IOMP_LIB}")
......
INCLUDE(ExternalProject)
SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3)
SET(EIGEN_INCLUDE_DIR ${EIGEN_SOURCE_DIR}/src/extern_eigen3)
INCLUDE_DIRECTORIES(${EIGEN_INCLUDE_DIR})
ExternalProject_Add(
extern_eigen3
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/RLovelett/eigen.git"
GIT_TAG 70661066beef694cadf6c304d0d07e0758825c10
PREFIX ${EIGEN_SOURCE_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)
INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3)
if(WITH_AMD_GPU)
ExternalProject_Add(
extern_eigen3
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/sabreshao/hipeigen.git"
GIT_TAG 0cba03ff9f8f9f70bbd92ac5857b031aa8fed6f9
PREFIX ${EIGEN_SOURCE_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)
else()
ExternalProject_Add(
extern_eigen3
${EXTERNAL_PROJECT_LOG_ARGS}
GIT_REPOSITORY "https://github.com/RLovelett/eigen.git"
GIT_TAG 70661066beef694cadf6c304d0d07e0758825c10
PREFIX ${EIGEN_SOURCE_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
INSTALL_COMMAND ""
TEST_COMMAND ""
)
endif()
if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/eigen3_dummy.c)
......
......@@ -317,6 +317,82 @@ function(nv_test TARGET_NAME)
endif()
endfunction(nv_test)
function(hip_library TARGET_NAME)
if (WITH_AMD_GPU)
set(options STATIC static SHARED shared)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(hip_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(_sources ${hip_library_SRCS})
HIP_PREPARE_TARGET_COMMANDS(${TARGET_NAME} OBJ _generated_files _source_files ${_sources} HIPCC_OPTIONS ${_hipcc_options} HCC_OPTIONS ${_hcc_options} NVCC_OPTIONS ${_nvcc_options})
if(_source_files)
list(REMOVE_ITEM _sources ${_source_files})
endif()
if(hip_library_SRCS)
if (hip_library_SHARED OR hip_library_shared) # build *.so
add_library(${TARGET_NAME} SHARED ${_cmake_options} ${_generated_files} ${_sources})
set_target_properties(${TARGET_NAME} PROPERTIES LINKER_LANGUAGE HIP)
else()
add_library(${TARGET_NAME} STATIC ${_cmake_options} ${_generated_files} ${_sources})
set_target_properties(${TARGET_NAME} PROPERTIES LINKER_LANGUAGE CXX)
target_link_libraries(${TARGET_NAME} /opt/rocm/hip/lib/libhip_hcc.so /opt/rocm/hip/lib/libhip_device.a)
find_fluid_modules(${TARGET_NAME})
endif()
if (hip_library_DEPS)
add_dependencies(${TARGET_NAME} ${hip_library_DEPS})
target_link_libraries(${TARGET_NAME} ${hip_library_DEPS})
endif()
# cpplint code style
foreach(source_file ${hip_library_SRCS})
string(REGEX REPLACE "\\.[^.]*$" "" source ${source_file})
if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
list(APPEND hip_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
endif()
endforeach()
add_style_check_target(${TARGET_NAME} ${hip_library_SRCS} ${hip_library_HEADERS})
else(hip_library_SRCS)
if (hip_library_DEPS)
merge_static_libs(${TARGET_NAME} ${hip_library_DEPS})
else()
message(FATAL "Please specify source file or library in nv_library.")
endif()
endif(hip_library_SRCS)
endif()
endfunction(hip_library)
function(hip_binary TARGET_NAME)
if (WITH_AMD_GPU)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(hip_binary "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
hip_add_executable(${TARGET_NAME} ${hip_binary_SRCS})
if(hip_binary_DEPS)
target_link_libraries(${TARGET_NAME} ${hip_binary_DEPS})
add_dependencies(${TARGET_NAME} ${hip_binary_DEPS})
endif()
endif()
endfunction(hip_binary)
function(hip_test TARGET_NAME)
if (WITH_AMD_GPU AND WITH_TESTING)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(hip_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
set(_sources ${hip_test_SRCS})
HIP_PREPARE_TARGET_COMMANDS(${TARGET_NAME} OBJ _generated_files _source_files ${_sources} HIPCC_OPTIONS ${_hipcc_options} HCC_OPTIONS ${_hcc_options} NVCC_OPTIONS ${_nvcc_options})
if(_source_files)
list(REMOVE_ITEM _sources ${_source_files})
endif()
add_executable(${TARGET_NAME} ${_cmake_options} ${_generated_files} ${_sources})
set_target_properties(${TARGET_NAME} PROPERTIES LINKER_LANGUAGE HIP)
target_link_libraries(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
add_dependencies(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
add_test(${TARGET_NAME} ${TARGET_NAME})
endif()
endfunction(hip_test)
function(go_library TARGET_NAME)
set(options STATIC static SHARED shared)
set(oneValueArgs "")
......
if(NOT WITH_AMD_GPU)
return()
endif()
include_directories("/opt/rocm/include")
include_directories("/opt/rocm/hipblas/include")
include_directories("/opt/rocm/hiprand/include")
include_directories("/opt/rocm/rocrand/include")
include_directories("/opt/rocm/rccl/include")
include_directories("/opt/rocm/thrust")
list(APPEND EXTERNAL_LIBS "-L/opt/rocm/lib/ -lhip_hcc")
set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -fPIC -DPADDLE_WITH_HIP -std=c++14" )
if(WITH_DSO)
set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -DPADDLE_USE_DSO")
endif(WITH_DSO)
if(WITH_DOUBLE)
set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -DPADDLE_TYPE_DOUBLE")
endif(WITH_DOUBLE)
if(WITH_TESTING)
set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -DPADDLE_WITH_TESTING")
endif(WITH_TESTING)
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_DEBUG})
elseif(CMAKE_BUILD_TYPE STREQUAL "Release")
# Disable optimization since one eigen symbol will be removed in math_function.cu
#list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_RELEASE})
elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo")
list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_RELWITHDEBINFO})
elseif(CMAKE_BUILD_TYPE STREQUAL "MinSizeRel")
list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_MINSIZEREL})
endif()
if("x${HCC_HOME}" STREQUAL "x")
set(HCC_HOME "/opt/rocm/hcc")
endif()
set(CMAKE_HIP_LINK_EXECUTABLE "${HIP_HIPCC_CMAKE_LINKER_HELPER} ${HCC_HOME} <FLAGS> <CMAKE_CXX_LINK_FLAGS> <LINK_FLAGS> <OBJECTS> -o <TARGET> <LINK_LIBRARIES>")
set(CMAKE_HIP_CREATE_SHARED_LIBRARY "${HIP_HIPCC_CMAKE_LINKER_HELPER} ${HCC_HOME} <CMAKE_CXX_LINK_FLAGS> <LINK_FLAGS> <OBJECTS> -o <TARGET> <LINK_LIBRARIES> -shared")
set(CMAKE_HIP_CREATE_SHARED_MODULE "${HIP_HIPCC_CMAKE_LINKER_HELPER} ${HCC_HOME} <CMAKE_CXX_LINK_FLAGS> <LINK_FLAGS> <OBJECTS> -o <TARGET> <LINK_LIBRARIES> -shared")
......@@ -76,6 +76,9 @@ function(op_library TARGET)
if (WITH_GPU)
nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
elseif (WITH_AMD_GPU)
hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS
${op_library_DEPS} ${op_common_deps})
else()
cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
......
......@@ -6,6 +6,7 @@ function(math_library TARGET)
# But it handle split GPU/CPU code and link some common library.
set(cc_srcs)
set(cu_srcs)
set(hip_srcs)
set(math_common_deps device_context framework_proto)
set(multiValueArgs DEPS)
cmake_parse_arguments(math_library "${options}" "${oneValueArgs}"
......@@ -17,10 +18,15 @@ function(math_library TARGET)
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
list(APPEND cu_srcs ${TARGET}.cu)
endif()
if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.hip.cu)
list(APPEND hip_srcs ${TARGET}.hip.cu)
endif()
list(LENGTH cc_srcs cc_srcs_len)
if (WITH_GPU)
nv_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${math_library_DEPS} ${math_common_deps})
elseif (WITH_AMD_GPU)
hip_library(${TARGET} SRCS ${cc_srcs} ${hip_srcs} DEPS ${math_library_DEPS} ${math_common_deps})
elseif(${cc_srcs_len} GREATER 0)
cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${math_library_DEPS} ${math_common_deps})
endif()
......
/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "hip/hip_runtime.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/cuda_helper.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
__device__ T upper_bound(const T* first, T count, T val) {
const T* orig = first;
const T* it = nullptr;
T step = 0;
while (count > 0) {
it = first;
step = count / 2;
it += step;
if (!(val < *it)) {
first = ++it;
count -= step + 1;
} else {
count = step;
}
}
return first - orig;
}
template <typename T>
__global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
const int output_rows, const int output_cols,
T* output) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int segment = upper_bound<int>(input_cols, col_size, tid_x) - 1;
int curr_offset = input_cols[segment];
int curr_segment = segment;
for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
T curr_col_offset;
while ((curr_col_offset = input_cols[curr_segment + 1]) <= tid_x) {
curr_offset = curr_col_offset;
++curr_segment;
}
int local_col = tid_x - curr_offset;
int segment_width = curr_col_offset - curr_offset;
T* input_ptr = inputs[curr_segment];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
output[tid_y * output_cols + tid_x] =
input_ptr[tid_y * segment_width + local_col];
}
}
template <typename T>
__global__ void KernelConcat(T** inputs, const int input_col,
const int output_rows, const int output_cols,
T* output) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
double inv_input_col = 1.0 / input_col;
for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
int split = tid_x * inv_input_col;
int in_offset = tid_x - split * input_col;
T* input_ptr = inputs[split];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) {
output[tid_y * output_cols + tid_x] =
input_ptr[tid_y * input_col + in_offset];
}
}
}
template <typename T>
__global__ void KernelConcatGrad(const T* input, const int input_row,
const int input_col, const int* output_cols,
int col_size, T** outputs) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int segment = upper_bound<int>(output_cols, col_size, tid_x) - 1;
int curr_offset = output_cols[segment];
int curr_segment = segment;
for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
T curr_col_offset;
while ((curr_col_offset = output_cols[curr_segment + 1]) <= tid_x) {
curr_offset = curr_col_offset;
++curr_segment;
}
int local_col = tid_x - curr_offset;
int segment_width = curr_col_offset - curr_offset;
T* output_ptr = outputs[curr_segment];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * segment_width + local_col] =
input[tid_y * input_col + tid_x];
}
}
template <typename T>
__global__ void KernelConcatGrad(const T* input, const int input_row,
const int input_col, const int output_cols,
T** outputs) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
double inv_input_col = 1.0 / input_col;
for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
int split = tid_x * inv_input_col;
int in_offset = tid_x - split * input_col;
T* output_ptr = outputs[split];
int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
output_ptr[tid_y * output_cols + in_offset] =
input[tid_y * input_col + tid_x];
}
}
/*
* All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
*/
template <typename T>
class ConcatFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const std::vector<framework::Tensor>& input, const int axis,
framework::Tensor* output) {
// TODO(zcd): Add input data validity checking
int num = input.size();
int rows = 1;
auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int cols = input[0].numel() / rows;
int out_rows = rows, out_cols = 0;
framework::Vector<int16_t> inputs_data(num * sizeof(T*) / 2);
framework::Vector<int> inputs_cols(num + 1);
inputs_cols[0] = 0;
T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());
bool sameShape = true;
for (int i = 0; i < num; ++i) {
int t_cols = input[i].numel() / rows;
if (sameShape) {
if (t_cols != cols) sameShape = false;
}
out_cols += t_cols;
inputs_cols[i + 1] = out_cols;
inputs_ptr[i] = const_cast<T*>(input[i].data<T>());
}
T** ins_gpu =
reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace()));
const int* ins_col_gpu = inputs_cols.CUDAData(context.GetPlace());
// computation
// set the thread block and grid according to CurrentDeviceId
const int kThreadsPerBlock = 1024;
int block_cols = kThreadsPerBlock;
if (out_cols < kThreadsPerBlock) { // block_cols is aligned by 32.
block_cols = ((out_cols + 31) >> 5) << 5;
}
int block_rows = kThreadsPerBlock / block_cols;
dim3 block_size = dim3(block_cols, block_rows, 1);
int max_threads = context.GetMaxPhysicalThreadCount();
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
int grid_cols =
std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
int grid_rows =
std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));
dim3 grid_size = dim3(grid_cols, grid_rows, 1);
if (sameShape) {
hipLaunchKernelGGL((KernelConcat<T>), dim3(grid_size), dim3(block_size), 0, context.stream(),
ins_gpu, cols, out_rows, out_cols, output->data<T>());
} else {
hipLaunchKernelGGL((KernelConcat<T>), dim3(grid_size), dim3(block_size), 0, context.stream(),
ins_gpu, ins_col_gpu, static_cast<int>(inputs_cols.size()), out_rows,
out_cols, output->data<T>());
}
}
};
/*
* All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
*/
template <typename T>
class ConcatGradFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input, const int axis,
std::vector<framework::Tensor>& outputs) {
// TODO(zcd): Add input data validity checking
int num = outputs.size();
int input_row = 1;
auto dim_0 = outputs[0].dims();
for (int i = 0; i < axis; ++i) {
input_row *= dim_0[i];
}
int output_col_0 = outputs[0].numel() / input_row;
int input_col = 0;
bool sameShape = true;
framework::Vector<int16_t> outputs_data(num * sizeof(T*) / 2);
framework::Vector<int> outputs_cols(num + 1);
outputs_cols[0] = 0;
T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());
for (int i = 0; i < num; ++i) {
int t_col = outputs[i].numel() / input_row;
if (sameShape) {
if (t_col != output_col_0) sameShape = false;
}
input_col += t_col;
outputs_cols[i + 1] = input_col;
outputs_ptr[i] = outputs[i].data<T>();
}
T** outs_gpu =
reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace()));
const int* outs_col_gpu = outputs_cols.CUDAData(context.GetPlace());
// computation
const int kThreadsPerBlock = 1024;
int block_cols = kThreadsPerBlock;
if (input_col < kThreadsPerBlock) { // block_cols is aligned by 32.
block_cols = ((input_col + 31) >> 5) << 5;
}
int block_rows = kThreadsPerBlock / block_cols;
dim3 block_size = dim3(block_cols, block_rows, 1);
int max_threads = context.GetMaxPhysicalThreadCount();
int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
int grid_cols =
std::min((input_col + block_cols - 1) / block_cols, max_blocks);
int grid_rows =
std::min(max_blocks / grid_cols, std::max(input_row / block_rows, 1));
dim3 grid_size = dim3(grid_cols, grid_rows, 1);
if (sameShape) {
hipLaunchKernelGGL((KernelConcatGrad<T>), dim3(grid_size), dim3(block_size), 0, context.stream(),
input.data<T>(), input_row, input_col, output_col_0, outs_gpu);
} else {
hipLaunchKernelGGL((KernelConcatGrad<T>), dim3(grid_size), dim3(block_size), 0, context.stream(),
input.data<T>(), input_row, input_col, outs_col_gpu,
static_cast<int>(outputs_cols.size()), outs_gpu);
}
}
};
template class ConcatFunctor<platform::CUDADeviceContext, int>;
template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatFunctor<platform::CUDADeviceContext, float>;
template class ConcatFunctor<platform::CUDADeviceContext, double>;
template class ConcatGradFunctor<platform::CUDADeviceContext, int>;
template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
template class ConcatGradFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
if(WITH_PYTHON)
cc_library(paddle_pybind SHARED
SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc
DEPS pybind python backward proto_desc paddle_memory executor prune init profiler feed_fetch_method
${GLOB_OP_LIB})
if(NOT APPLE AND NOT ANDROID)
target_link_libraries(paddle_pybind rt)
endif(NOT APPLE AND NOT ANDROID)
if(WITH_AMD_GPU)
hip_library(paddle_pybind SHARED
SRCS pybind.cc exception.cc protobuf.cc const_value.cc
DEPS pybind python backward proto_desc paddle_memory executor prune init profiler feed_fetch_method
${GLOB_OP_LIB})
else()
cc_library(paddle_pybind SHARED
SRCS pybind.cc exception.cc protobuf.cc const_value.cc
DEPS pybind python backward proto_desc paddle_memory executor prune init profiler feed_fetch_method
${GLOB_OP_LIB})
if(NOT APPLE AND NOT ANDROID)
target_link_libraries(paddle_pybind rt)
endif(NOT APPLE AND NOT ANDROID)
endif(WITH_AMD_GPU)
endif(WITH_PYTHON)
......@@ -37,6 +37,7 @@ function cmake_gen() {
-DWITH_DSO=ON
-DWITH_DOC=OFF
-DWITH_GPU=${WITH_GPU:-OFF}
-DWITH_AMD_GPU=${WITH_AMD_GPU:-OFF}
-DWITH_DISTRIBUTE=${WITH_DISTRIBUTE:-OFF}
-DWITH_MKL=${WITH_MKL:-ON}
-DWITH_AVX=${WITH_AVX:-OFF}
......@@ -50,6 +51,7 @@ function cmake_gen() {
-DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-ON}
-DWITH_TESTING=${WITH_TESTING:-ON}
-DWITH_FAST_BUNDLE_TEST=ON
-DCMAKE_MODULE_PATH=/opt/rocm/hip/cmake
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON
========================================
EOF
......@@ -62,6 +64,7 @@ EOF
-DWITH_DSO=ON \
-DWITH_DOC=OFF \
-DWITH_GPU=${WITH_GPU:-OFF} \
-DWITH_AMD_GPU=${WITH_AMD_GPU:-OFF} \
-DWITH_DISTRIBUTE=${WITH_DISTRIBUTE:-OFF} \
-DWITH_MKL=${WITH_MKL:-ON} \
-DWITH_AVX=${WITH_AVX:-OFF} \
......@@ -74,6 +77,7 @@ EOF
-DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-ON} \
-DWITH_TESTING=${WITH_TESTING:-ON} \
-DWITH_FAST_BUNDLE_TEST=ON \
-DCMAKE_MODULE_PATH=/opt/rocm/hip/cmake \
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册