Add AMD GPU (HIP/ROCm) build support.

Adds the WITH_AMD_GPU option, the hip_library / hip_binary / hip_test CMake
helpers, cmake/hip.cmake, a hipeigen checkout for AMD builds, and a HIP port
of the concat kernels (concat.hip.cu).

Corrections folded into this revision of the patch:
  * CMakeLists.txt: the empty if(WITH_AMD_GPU) block is replaced by a check
    that rejects WITH_GPU and WITH_AMD_GPU both being ON, and
    find_package(HIP) is now REQUIRED.
  * cmake/external/eigen.cmake: EIGEN_INCLUDE_DIR stays defined.
  * cmake/generic.cmake: message(FATAL ...) becomes message(FATAL_ERROR ...)
    and the text names hip_library instead of nv_library.
  * cmake/hip.cmake: the CMAKE_HIP_* link rules carry their <...> rule
    placeholders.
  * concat.hip.cu: column offsets are int rather than T; the same-shape
    kernels use integer division, and KernelConcatGrad divides by
    output_cols instead of input_col.
  * paddle/fluid/pybind/CMakeLists.txt: recordio.cc stays in SRCS.

Note: the hunks were rebuilt from a whitespace-collapsed copy of the patch.
Hunk line counts are exact, but indentation, column alignment and the wrap
points of context / removed lines are best-effort; apply with
whitespace-insensitive matching (git apply --ignore-whitespace, patch -l)
and re-check any hunk that is rejected. The blob ids on the index lines are
those of the original submission.

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 0ec65bac84b0b0d89123473a8941f80c90f1b339..399bf50748ea0687c19c0e11b5ff315a6dc032ce 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -36,6 +36,7 @@ include(simd)
 
 ################################ Configurations #######################################
 option(WITH_GPU         "Compile PaddlePaddle with NVIDIA GPU"          ${CUDA_FOUND})
+option(WITH_AMD_GPU     "Compile PaddlePaddle with AMD GPU"             OFF)
 option(WITH_AVX         "Compile PaddlePaddle with AVX intrinsics"      ${AVX_FOUND})
 option(WITH_MKL         "Compile PaddlePaddle with MKL support."        ${AVX_FOUND})
 option(WITH_DSO         "Compile PaddlePaddle with dynamic linked CUDA" ON)
@@ -69,6 +70,12 @@ if(NOT CMAKE_BUILD_TYPE)
       FORCE)
 endif()
 
+# The CUDA and HIP back ends are configured by mutually exclusive branches
+# (see cmake/configure.cmake), so refuse a build that enables both.
+if(WITH_GPU AND WITH_AMD_GPU)
+  message(FATAL_ERROR "WITH_GPU and WITH_AMD_GPU cannot both be ON.")
+endif()
+
 if(ANDROID OR IOS)
   if(ANDROID)
     if(${CMAKE_SYSTEM_VERSION} VERSION_LESS "16")
@@ -180,6 +187,11 @@ if(WITH_GPU)
   include(cuda)
 endif(WITH_GPU)
 
+if(WITH_AMD_GPU)
+  find_package(HIP REQUIRED)
+  include(hip)
+endif(WITH_AMD_GPU)
+
 if(WITH_MKLML)
   list(APPEND EXTERNAL_LIBS ${MKLML_IOMP_LIB})
 endif()
diff --git a/cmake/configure.cmake b/cmake/configure.cmake
index 0f76f55270592c5625a9624b33f4c0f82efdc627..f726405c4773994f6ca6509e5218750805b03995 100644
--- a/cmake/configure.cmake
+++ b/cmake/configure.cmake
@@ -57,11 +57,7 @@ if(NOT WITH_GOLANG)
     add_definitions(-DPADDLE_WITHOUT_GOLANG)
 endif(NOT WITH_GOLANG)
 
-if(NOT WITH_GPU)
-    add_definitions(-DHPPL_STUB_FUNC)
-
-    list(APPEND CMAKE_CXX_SOURCE_FILE_EXTENSIONS cu)
-else()
+if(WITH_GPU)
     add_definitions(-DPADDLE_WITH_CUDA)
 
     FIND_PACKAGE(CUDA REQUIRED)
@@ -84,7 +80,14 @@ else()
     # Include cuda and cudnn
     include_directories(${CUDNN_INCLUDE_DIR})
     include_directories(${CUDA_TOOLKIT_INCLUDE})
-endif(NOT WITH_GPU)
+elseif(WITH_AMD_GPU)
+    add_definitions(-DPADDLE_WITH_HIP)
+    set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__HIP_PLATFORM_HCC__")
+    set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -D__HIP_PLATFORM_HCC__")
+else()
+    add_definitions(-DHPPL_STUB_FUNC)
+    list(APPEND CMAKE_CXX_SOURCE_FILE_EXTENSIONS cu)
+endif()
 
 if (WITH_MKLML AND MKLML_IOMP_LIB)
     message(STATUS "Enable Intel OpenMP with ${MKLML_IOMP_LIB}")
diff --git a/cmake/external/eigen.cmake b/cmake/external/eigen.cmake
index 6a701e076c95372f903a09d35d4208ee73bd584c..5d88c5a0b091c9a22ad1dff41b9e76d406fcbc76 100644
--- a/cmake/external/eigen.cmake
+++ b/cmake/external/eigen.cmake
@@ -1,21 +1,36 @@
 INCLUDE(ExternalProject)
 
 SET(EIGEN_SOURCE_DIR ${THIRD_PARTY_PATH}/eigen3)
 SET(EIGEN_INCLUDE_DIR ${EIGEN_SOURCE_DIR}/src/extern_eigen3)
 INCLUDE_DIRECTORIES(${EIGEN_INCLUDE_DIR})
 
-ExternalProject_Add(
-    extern_eigen3
-    ${EXTERNAL_PROJECT_LOG_ARGS}
-    GIT_REPOSITORY  "https://github.com/RLovelett/eigen.git"
-    GIT_TAG         70661066beef694cadf6c304d0d07e0758825c10
-    PREFIX          ${EIGEN_SOURCE_DIR}
-    UPDATE_COMMAND  ""
-    CONFIGURE_COMMAND ""
-    BUILD_COMMAND     ""
-    INSTALL_COMMAND   ""
-    TEST_COMMAND      ""
-)
+if(WITH_AMD_GPU)
+    ExternalProject_Add(
+        extern_eigen3
+        ${EXTERNAL_PROJECT_LOG_ARGS}
+        GIT_REPOSITORY  "https://github.com/sabreshao/hipeigen.git"
+        GIT_TAG         0cba03ff9f8f9f70bbd92ac5857b031aa8fed6f9
+        PREFIX          ${EIGEN_SOURCE_DIR}
+        UPDATE_COMMAND  ""
+        CONFIGURE_COMMAND ""
+        BUILD_COMMAND     ""
+        INSTALL_COMMAND   ""
+        TEST_COMMAND      ""
+    )
+else()
+    ExternalProject_Add(
+        extern_eigen3
+        ${EXTERNAL_PROJECT_LOG_ARGS}
+        GIT_REPOSITORY  "https://github.com/RLovelett/eigen.git"
+        GIT_TAG         70661066beef694cadf6c304d0d07e0758825c10
+        PREFIX          ${EIGEN_SOURCE_DIR}
+        UPDATE_COMMAND  ""
+        CONFIGURE_COMMAND ""
+        BUILD_COMMAND     ""
+        INSTALL_COMMAND   ""
+        TEST_COMMAND      ""
+    )
+endif()
 
 if (${CMAKE_VERSION} VERSION_LESS "3.3.0")
     set(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/eigen3_dummy.c)
diff --git a/cmake/generic.cmake b/cmake/generic.cmake
index 471e3929069d0d28105404b4f0f6baa303faf0e0..c749c97f13649fe8432091414b56f7d0ea8ace8b 100644
--- a/cmake/generic.cmake
+++ b/cmake/generic.cmake
@@ -317,6 +317,90 @@ function(nv_test TARGET_NAME)
   endif()
 endfunction(nv_test)
 
+# hip_library(<name> [STATIC|SHARED] SRCS <files> DEPS <targets>)
+# Builds a library from sources prepared by HIP_PREPARE_TARGET_COMMANDS. With
+# no SRCS, DEPS are handed to merge_static_libs. No-op unless WITH_AMD_GPU.
+function(hip_library TARGET_NAME)
+  if (WITH_AMD_GPU)
+    set(options STATIC static SHARED shared)
+    set(oneValueArgs "")
+    set(multiValueArgs SRCS DEPS)
+    cmake_parse_arguments(hip_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+    set(_sources ${hip_library_SRCS})
+    HIP_PREPARE_TARGET_COMMANDS(${TARGET_NAME} OBJ _generated_files _source_files ${_sources} HIPCC_OPTIONS ${_hipcc_options} HCC_OPTIONS ${_hcc_options} NVCC_OPTIONS ${_nvcc_options})
+    if(_source_files)
+      list(REMOVE_ITEM _sources ${_source_files})
+    endif()
+    if(hip_library_SRCS)
+      if (hip_library_SHARED OR hip_library_shared) # build *.so
+        add_library(${TARGET_NAME} SHARED ${_cmake_options} ${_generated_files} ${_sources})
+        set_target_properties(${TARGET_NAME} PROPERTIES LINKER_LANGUAGE HIP)
+      else()
+        add_library(${TARGET_NAME} STATIC ${_cmake_options} ${_generated_files} ${_sources})
+        set_target_properties(${TARGET_NAME} PROPERTIES LINKER_LANGUAGE CXX)
+        target_link_libraries(${TARGET_NAME} /opt/rocm/hip/lib/libhip_hcc.so /opt/rocm/hip/lib/libhip_device.a)
+        find_fluid_modules(${TARGET_NAME})
+      endif()
+      if (hip_library_DEPS)
+        add_dependencies(${TARGET_NAME} ${hip_library_DEPS})
+        target_link_libraries(${TARGET_NAME} ${hip_library_DEPS})
+      endif()
+      # cpplint code style
+      foreach(source_file ${hip_library_SRCS})
+        string(REGEX REPLACE "\\.[^.]*$" "" source ${source_file})
+        if(EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
+          list(APPEND hip_library_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/${source}.h)
+        endif()
+      endforeach()
+      add_style_check_target(${TARGET_NAME} ${hip_library_SRCS} ${hip_library_HEADERS})
+    else(hip_library_SRCS)
+      if (hip_library_DEPS)
+        merge_static_libs(${TARGET_NAME} ${hip_library_DEPS})
+      else()
+        message(FATAL_ERROR "Please specify source file or library in hip_library.")
+      endif()
+    endif(hip_library_SRCS)
+  endif()
+endfunction(hip_library)
+
+# hip_binary(<name> SRCS <files> DEPS <targets>)
+# Builds an executable with hip_add_executable. No-op unless WITH_AMD_GPU.
+function(hip_binary TARGET_NAME)
+  if (WITH_AMD_GPU)
+    set(options "")
+    set(oneValueArgs "")
+    set(multiValueArgs SRCS DEPS)
+    cmake_parse_arguments(hip_binary "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+    hip_add_executable(${TARGET_NAME} ${hip_binary_SRCS})
+    if(hip_binary_DEPS)
+      target_link_libraries(${TARGET_NAME} ${hip_binary_DEPS})
+      add_dependencies(${TARGET_NAME} ${hip_binary_DEPS})
+    endif()
+  endif()
+endfunction(hip_binary)
+
+# hip_test(<name> SRCS <files> DEPS <targets>)
+# Builds a test executable linked against paddle_gtest_main and registers it
+# with add_test. No-op unless both WITH_AMD_GPU and WITH_TESTING are ON.
+function(hip_test TARGET_NAME)
+  if (WITH_AMD_GPU AND WITH_TESTING)
+    set(options "")
+    set(oneValueArgs "")
+    set(multiValueArgs SRCS DEPS)
+    cmake_parse_arguments(hip_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
+    set(_sources ${hip_test_SRCS})
+    HIP_PREPARE_TARGET_COMMANDS(${TARGET_NAME} OBJ _generated_files _source_files ${_sources} HIPCC_OPTIONS ${_hipcc_options} HCC_OPTIONS ${_hcc_options} NVCC_OPTIONS ${_nvcc_options})
+    if(_source_files)
+      list(REMOVE_ITEM _sources ${_source_files})
+    endif()
+    add_executable(${TARGET_NAME} ${_cmake_options} ${_generated_files} ${_sources})
+    set_target_properties(${TARGET_NAME} PROPERTIES LINKER_LANGUAGE HIP)
+    target_link_libraries(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
+    add_dependencies(${TARGET_NAME} ${hip_test_DEPS} paddle_gtest_main paddle_memory gtest gflags)
+    add_test(${TARGET_NAME} ${TARGET_NAME})
+  endif()
+endfunction(hip_test)
+
 function(go_library TARGET_NAME)
   set(options STATIC static SHARED shared)
   set(oneValueArgs "")
diff --git a/cmake/hip.cmake b/cmake/hip.cmake
new file mode 100644
index 0000000000000000000000000000000000000000..cd880603a728402384ecd8a044711bffea2c3daf
--- /dev/null
+++ b/cmake/hip.cmake
@@ -0,0 +1,51 @@
+# Tool-chain settings for building with AMD GPUs through HIP/ROCm. Included by
+# the top-level CMakeLists.txt after find_package(HIP); does nothing unless
+# WITH_AMD_GPU is ON.
+if(NOT WITH_AMD_GPU)
+    return()
+endif()
+
+include_directories("/opt/rocm/include")
+include_directories("/opt/rocm/hipblas/include")
+include_directories("/opt/rocm/hiprand/include")
+include_directories("/opt/rocm/rocrand/include")
+include_directories("/opt/rocm/rccl/include")
+include_directories("/opt/rocm/thrust")
+
+list(APPEND EXTERNAL_LIBS "-L/opt/rocm/lib/ -lhip_hcc")
+
+set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -fPIC -DPADDLE_WITH_HIP -std=c++14" )
+
+if(WITH_DSO)
+    set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -DPADDLE_USE_DSO")
+endif(WITH_DSO)
+
+if(WITH_DOUBLE)
+    set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -DPADDLE_TYPE_DOUBLE")
+endif(WITH_DOUBLE)
+
+if(WITH_TESTING)
+    set(HIP_HCC_FLAGS "${HIP_HCC_FLAGS} -DPADDLE_WITH_TESTING")
+endif(WITH_TESTING)
+
+if(CMAKE_BUILD_TYPE STREQUAL "Debug")
+    list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_DEBUG})
+elseif(CMAKE_BUILD_TYPE STREQUAL "Release")
+# Disable optimization since one eigen symbol will be removed in math_function.cu
+    #list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_RELEASE})
+elseif(CMAKE_BUILD_TYPE STREQUAL "RelWithDebInfo")
+    list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_RELWITHDEBINFO})
+elseif(CMAKE_BUILD_TYPE STREQUAL "MinSizeRel")
+    list(APPEND HIP_HCC_FLAGS ${CMAKE_CXX_FLAGS_MINSIZEREL})
+endif()
+
+if("x${HCC_HOME}" STREQUAL "x")
+    set(HCC_HOME "/opt/rocm/hcc")
+endif()
+
+# Link rules for LINKER_LANGUAGE HIP targets; the <...> tokens are CMake
+# rule-variable placeholders expanded when the link command is generated.
+set(CMAKE_HIP_LINK_EXECUTABLE "${HIP_HIPCC_CMAKE_LINKER_HELPER} ${HCC_HOME} <FLAGS> <CMAKE_CXX_LINK_FLAGS> <LINK_FLAGS> <OBJECTS> -o <TARGET> <LINK_LIBRARIES>")
+set(CMAKE_HIP_CREATE_SHARED_LIBRARY "${HIP_HIPCC_CMAKE_LINKER_HELPER} ${HCC_HOME} <CMAKE_CXX_LINK_FLAGS> <LINK_FLAGS> <OBJECTS> -o <TARGET> <LINK_LIBRARIES> -shared")
+set(CMAKE_HIP_CREATE_SHARED_MODULE "${HIP_HIPCC_CMAKE_LINKER_HELPER} ${HCC_HOME} <CMAKE_CXX_LINK_FLAGS> <LINK_FLAGS> <OBJECTS> -o <TARGET> <LINK_LIBRARIES> -shared")
+
diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index d30124d4a3b89b802a4abaae07a33b76526f163d..26d1dab1e95a5f28b89caba5f00c6e77596e36a8 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -76,6 +76,9 @@ function(op_library TARGET)
     if (WITH_GPU)
         nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cudnn_cu_cc_srcs} ${mkldnn_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
                 ${op_common_deps})
+    elseif (WITH_AMD_GPU)
+        hip_library(${TARGET} SRCS ${cc_srcs} ${hip_cc_srcs} ${miopen_cu_cc_srcs} ${mkldnn_cc_srcs} ${hip_srcs} DEPS
+                ${op_library_DEPS} ${op_common_deps})
     else()
         cc_library(${TARGET} SRCS ${cc_srcs} ${mkldnn_cc_srcs} DEPS ${op_library_DEPS}
             ${op_common_deps})
diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt
index fba1612d10f0494f4ab06fabdd0e799a74dafd53..1cac62472cada3bf129ed7157d196eefc8ed644a 100644
--- a/paddle/fluid/operators/math/CMakeLists.txt
+++ b/paddle/fluid/operators/math/CMakeLists.txt
@@ -6,6 +6,7 @@ function(math_library TARGET)
     # But it handle split GPU/CPU code and link some common library.
     set(cc_srcs)
     set(cu_srcs)
+    set(hip_srcs)
     set(math_common_deps device_context framework_proto)
     set(multiValueArgs DEPS)
     cmake_parse_arguments(math_library "${options}" "${oneValueArgs}"
@@ -17,10 +18,15 @@ function(math_library TARGET)
     if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
         list(APPEND cu_srcs ${TARGET}.cu)
     endif()
+    if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.hip.cu)
+        list(APPEND hip_srcs ${TARGET}.hip.cu)
+    endif()
 
     list(LENGTH cc_srcs cc_srcs_len)
     if (WITH_GPU)
         nv_library(${TARGET} SRCS ${cc_srcs} ${cu_srcs} DEPS ${math_library_DEPS} ${math_common_deps})
+    elseif (WITH_AMD_GPU)
+        hip_library(${TARGET} SRCS ${cc_srcs} ${hip_srcs} DEPS ${math_library_DEPS} ${math_common_deps})
     elseif(${cc_srcs_len} GREATER 0)
         cc_library(${TARGET} SRCS ${cc_srcs} DEPS ${math_library_DEPS} ${math_common_deps})
     endif()
diff --git a/paddle/fluid/operators/math/concat.hip.cu b/paddle/fluid/operators/math/concat.hip.cu
new file mode 100644
index 0000000000000000000000000000000000000000..91efd8ea57d628a0ffeb8d406779b6ae3bcc9571
--- /dev/null
+++ b/paddle/fluid/operators/math/concat.hip.cu
@@ -0,0 +1,289 @@
+/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "hip/hip_runtime.h"
+#include "paddle/fluid/framework/mixed_vector.h"
+#include "paddle/fluid/operators/math/concat.h"
+#include "paddle/fluid/platform/cuda_helper.h"
+
+namespace paddle {
+namespace operators {
+namespace math {
+
+// Binary search over the sorted range [first, first + count): returns the
+// index of the first element that is greater than val.
+template <typename T>
+__device__ T upper_bound(const T* first, T count, T val) {
+  const T* orig = first;
+  const T* it = nullptr;
+  T step = 0;
+  while (count > 0) {
+    it = first;
+    step = count / 2;
+    it += step;
+    if (!(val < *it)) {
+      first = ++it;
+      count -= step + 1;
+    } else {
+      count = step;
+    }
+  }
+  return first - orig;
+}
+
+// Concatenates inputs of differing widths along columns. input_cols holds the
+// col_size cumulative column offsets of the inputs (input_cols[0] == 0).
+template <typename T>
+__global__ void KernelConcat(T** inputs, const int* input_cols, int col_size,
+                             const int output_rows, const int output_cols,
+                             T* output) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  int segment = upper_bound<int>(input_cols, col_size, tid_x) - 1;
+
+  int curr_offset = input_cols[segment];
+  int curr_segment = segment;
+  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
+    int curr_col_offset;
+    while ((curr_col_offset = input_cols[curr_segment + 1]) <= tid_x) {
+      curr_offset = curr_col_offset;
+      ++curr_segment;
+    }
+
+    int local_col = tid_x - curr_offset;
+    int segment_width = curr_col_offset - curr_offset;
+    T* input_ptr = inputs[curr_segment];
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y)
+      output[tid_y * output_cols + tid_x] =
+          input_ptr[tid_y * segment_width + local_col];
+  }
+}
+
+// Concatenates inputs that all have exactly input_col columns.
+template <typename T>
+__global__ void KernelConcat(T** inputs, const int input_col,
+                             const int output_rows, const int output_cols,
+                             T* output) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  // Exact integer division; multiplying by 1.0 / input_col can round down.
+  for (; tid_x < output_cols; tid_x += blockDim.x * gridDim.x) {
+    int split = tid_x / input_col;
+    int in_offset = tid_x - split * input_col;
+    T* input_ptr = inputs[split];
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    for (; tid_y < output_rows; tid_y += blockDim.y * gridDim.y) {
+      output[tid_y * output_cols + tid_x] =
+          input_ptr[tid_y * input_col + in_offset];
+    }
+  }
+}
+
+// Splits input column-wise into outputs of differing widths. output_cols holds
+// the col_size cumulative column offsets of the outputs (output_cols[0] == 0).
+template <typename T>
+__global__ void KernelConcatGrad(const T* input, const int input_row,
+                                 const int input_col, const int* output_cols,
+                                 int col_size, T** outputs) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  int segment = upper_bound<int>(output_cols, col_size, tid_x) - 1;
+  int curr_offset = output_cols[segment];
+  int curr_segment = segment;
+  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
+    int curr_col_offset;
+    while ((curr_col_offset = output_cols[curr_segment + 1]) <= tid_x) {
+      curr_offset = curr_col_offset;
+      ++curr_segment;
+    }
+
+    int local_col = tid_x - curr_offset;
+    int segment_width = curr_col_offset - curr_offset;
+    T* output_ptr = outputs[curr_segment];
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
+      output_ptr[tid_y * segment_width + local_col] =
+          input[tid_y * input_col + tid_x];
+  }
+}
+
+// Splits input column-wise into outputs that all have output_cols columns.
+template <typename T>
+__global__ void KernelConcatGrad(const T* input, const int input_row,
+                                 const int input_col, const int output_cols,
+                                 T** outputs) {
+  int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
+  // Each output holds output_cols columns, so that is the divisor here.
+  for (; tid_x < input_col; tid_x += blockDim.x * gridDim.x) {
+    int split = tid_x / output_cols;
+    int in_offset = tid_x - split * output_cols;
+    T* output_ptr = outputs[split];
+    int tid_y = blockIdx.y * blockDim.y + threadIdx.y;
+    for (; tid_y < input_row; tid_y += blockDim.y * gridDim.y)
+      output_ptr[tid_y * output_cols + in_offset] =
+          input[tid_y * input_col + tid_x];
+  }
+}
+
+/*
+ * All tensors' dimension should be the same and the values of
+ * each dimension are the same, except the axis dimension.
+ */
+template <typename T>
+class ConcatFunctor<platform::CUDADeviceContext, T> {
+ public:
+  void operator()(const platform::CUDADeviceContext& context,
+                  const std::vector<framework::Tensor>& input, const int axis,
+                  framework::Tensor* output) {
+    // TODO(zcd): Add input data validity checking
+    int num = input.size();
+    int rows = 1;
+    auto dim_0 = input[0].dims();
+    for (int i = 0; i < axis; ++i) {
+      rows *= dim_0[i];
+    }
+    int cols = input[0].numel() / rows;
+    int out_rows = rows, out_cols = 0;
+
+    framework::Vector<int16_t> inputs_data(num * sizeof(T*) / 2);
+    framework::Vector<int> inputs_cols(num + 1);
+    inputs_cols[0] = 0;
+    T** inputs_ptr = reinterpret_cast<T**>(inputs_data.data());
+
+    bool sameShape = true;
+    for (int i = 0; i < num; ++i) {
+      int t_cols = input[i].numel() / rows;
+      if (sameShape) {
+        if (t_cols != cols) sameShape = false;
+      }
+      out_cols += t_cols;
+      inputs_cols[i + 1] = out_cols;
+      inputs_ptr[i] = const_cast<T*>(input[i].data<T>());
+    }
+
+    T** ins_gpu =
+        reinterpret_cast<T**>(inputs_data.CUDAMutableData(context.GetPlace()));
+    const int* ins_col_gpu = inputs_cols.CUDAData(context.GetPlace());
+
+    // computation
+    // set the thread block and grid according to CurrentDeviceId
+    const int kThreadsPerBlock = 1024;
+    int block_cols = kThreadsPerBlock;
+    if (out_cols < kThreadsPerBlock) {  // block_cols is aligned by 32.
+      block_cols = ((out_cols + 31) >> 5) << 5;
+    }
+    int block_rows = kThreadsPerBlock / block_cols;
+    dim3 block_size = dim3(block_cols, block_rows, 1);
+
+    int max_threads = context.GetMaxPhysicalThreadCount();
+    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
+
+    int grid_cols =
+        std::min((out_cols + block_cols - 1) / block_cols, max_blocks);
+    int grid_rows =
+        std::min(max_blocks / grid_cols, std::max(out_rows / block_rows, 1));
+    dim3 grid_size = dim3(grid_cols, grid_rows, 1);
+
+    if (sameShape) {
+      hipLaunchKernelGGL((KernelConcat<T>), dim3(grid_size), dim3(block_size), 0, context.stream(),
+          ins_gpu, cols, out_rows, out_cols, output->data<T>());
+    } else {
+      hipLaunchKernelGGL((KernelConcat<T>), dim3(grid_size), dim3(block_size), 0, context.stream(),
+          ins_gpu, ins_col_gpu, static_cast<int>(inputs_cols.size()), out_rows,
+          out_cols, output->data<T>());
+    }
+  }
+};
+
+/*
+ * All tensors' dimension should be the same and the values of
+ * each dimension are the same, except the axis dimension.
+ */
+template <typename T>
+class ConcatGradFunctor<platform::CUDADeviceContext, T> {
+ public:
+  void operator()(const platform::CUDADeviceContext& context,
+                  const framework::Tensor& input, const int axis,
+                  std::vector<framework::Tensor>& outputs) {
+    // TODO(zcd): Add input data validity checking
+    int num = outputs.size();
+    int input_row = 1;
+    auto dim_0 = outputs[0].dims();
+    for (int i = 0; i < axis; ++i) {
+      input_row *= dim_0[i];
+    }
+
+    int output_col_0 = outputs[0].numel() / input_row;
+    int input_col = 0;
+    bool sameShape = true;
+
+    framework::Vector<int16_t> outputs_data(num * sizeof(T*) / 2);
+    framework::Vector<int> outputs_cols(num + 1);
+    outputs_cols[0] = 0;
+    T** outputs_ptr = reinterpret_cast<T**>(outputs_data.data());
+
+    for (int i = 0; i < num; ++i) {
+      int t_col = outputs[i].numel() / input_row;
+      if (sameShape) {
+        if (t_col != output_col_0) sameShape = false;
+      }
+      input_col += t_col;
+      outputs_cols[i + 1] = input_col;
+      outputs_ptr[i] = outputs[i].data<T>();
+    }
+
+    T** outs_gpu =
+        reinterpret_cast<T**>(outputs_data.CUDAMutableData(context.GetPlace()));
+    const int* outs_col_gpu = outputs_cols.CUDAData(context.GetPlace());
+
+    // computation
+    const int kThreadsPerBlock = 1024;
+    int block_cols = kThreadsPerBlock;
+    if (input_col < kThreadsPerBlock) {  // block_cols is aligned by 32.
+      block_cols = ((input_col + 31) >> 5) << 5;
+    }
+    int block_rows = kThreadsPerBlock / block_cols;
+    dim3 block_size = dim3(block_cols, block_rows, 1);
+
+    int max_threads = context.GetMaxPhysicalThreadCount();
+    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
+
+    int grid_cols =
+        std::min((input_col + block_cols - 1) / block_cols, max_blocks);
+    int grid_rows =
+        std::min(max_blocks / grid_cols, std::max(input_row / block_rows, 1));
+    dim3 grid_size = dim3(grid_cols, grid_rows, 1);
+
+    if (sameShape) {
+      hipLaunchKernelGGL((KernelConcatGrad<T>), dim3(grid_size), dim3(block_size), 0, context.stream(),
+          input.data<T>(), input_row, input_col, output_col_0, outs_gpu);
+    } else {
+      hipLaunchKernelGGL((KernelConcatGrad<T>), dim3(grid_size), dim3(block_size), 0, context.stream(),
+          input.data<T>(), input_row, input_col, outs_col_gpu,
+          static_cast<int>(outputs_cols.size()), outs_gpu);
+    }
+  }
+};
+
+template class ConcatFunctor<platform::CUDADeviceContext, int>;
+template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
+template class ConcatFunctor<platform::CUDADeviceContext, float>;
+template class ConcatFunctor<platform::CUDADeviceContext, double>;
+
+template class ConcatGradFunctor<platform::CUDADeviceContext, int>;
+template class ConcatGradFunctor<platform::CUDADeviceContext, int64_t>;
+template class ConcatGradFunctor<platform::CUDADeviceContext, float>;
+template class ConcatGradFunctor<platform::CUDADeviceContext, double>;
+
+}  // namespace math
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/pybind/CMakeLists.txt b/paddle/fluid/pybind/CMakeLists.txt
index 8942b5c9430ffa4e499b0ad1d2b5acf6d18ec0ab..d523ad7f73df48515252e5fe6c36df13f794c0a2 100644
--- a/paddle/fluid/pybind/CMakeLists.txt
+++ b/paddle/fluid/pybind/CMakeLists.txt
@@ -1,9 +1,16 @@
 if(WITH_PYTHON)
-  cc_library(paddle_pybind SHARED
-    SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc
-    DEPS pybind python backward proto_desc paddle_memory executor prune init profiler feed_fetch_method
-    ${GLOB_OP_LIB})
-  if(NOT APPLE AND NOT ANDROID)
-    target_link_libraries(paddle_pybind rt)
-  endif(NOT APPLE AND NOT ANDROID)
+  if(WITH_AMD_GPU)
+    hip_library(paddle_pybind SHARED
+      SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc
+      DEPS pybind python backward proto_desc paddle_memory executor prune init profiler feed_fetch_method
+      ${GLOB_OP_LIB})
+  else()
+    cc_library(paddle_pybind SHARED
+      SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc
+      DEPS pybind python backward proto_desc paddle_memory executor prune init profiler feed_fetch_method
+      ${GLOB_OP_LIB})
+    if(NOT APPLE AND NOT ANDROID)
+      target_link_libraries(paddle_pybind rt)
+    endif(NOT APPLE AND NOT ANDROID)
+  endif(WITH_AMD_GPU)
 endif(WITH_PYTHON)
diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh
old mode 100644
new mode 100755
index 6be2bd8fad9e33cf4e1dcafdd6b8f39111bdbe88..02f2d7ba1244a2a90a37c22823e976b9c9619dca
--- a/paddle/scripts/docker/build.sh
+++ b/paddle/scripts/docker/build.sh
@@ -37,6 +37,7 @@ function cmake_gen() {
         -DWITH_DSO=ON
         -DWITH_DOC=OFF
         -DWITH_GPU=${WITH_GPU:-OFF}
+        -DWITH_AMD_GPU=${WITH_AMD_GPU:-OFF}
         -DWITH_DISTRIBUTE=${WITH_DISTRIBUTE:-OFF}
         -DWITH_MKL=${WITH_MKL:-ON}
         -DWITH_AVX=${WITH_AVX:-OFF}
@@ -50,6 +51,7 @@ function cmake_gen() {
         -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-ON}
         -DWITH_TESTING=${WITH_TESTING:-ON}
         -DWITH_FAST_BUNDLE_TEST=ON
+        -DCMAKE_MODULE_PATH=/opt/rocm/hip/cmake
         -DCMAKE_EXPORT_COMPILE_COMMANDS=ON
     ========================================
 EOF
@@ -62,6 +64,7 @@ EOF
         -DWITH_DSO=ON \
         -DWITH_DOC=OFF \
         -DWITH_GPU=${WITH_GPU:-OFF} \
+        -DWITH_AMD_GPU=${WITH_AMD_GPU:-OFF} \
         -DWITH_DISTRIBUTE=${WITH_DISTRIBUTE:-OFF} \
         -DWITH_MKL=${WITH_MKL:-ON} \
         -DWITH_AVX=${WITH_AVX:-OFF} \
@@ -74,6 +77,7 @@ EOF
         -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-ON} \
         -DWITH_TESTING=${WITH_TESTING:-ON} \
         -DWITH_FAST_BUNDLE_TEST=ON \
+        -DCMAKE_MODULE_PATH=/opt/rocm/hip/cmake \
         -DCMAKE_EXPORT_COMPILE_COMMANDS=ON
 }
 