未验证 提交 05d3b19b 编写于 作者: Z Zhaolong Xing 提交者: GitHub

lite cuda init: can run a simple model with leaky_relu (#1860)

* paddle lite cuda init
can run model with leaky_relu

* add the missing file.
test=develop
上级 9ef59f74
......@@ -159,8 +159,13 @@ include(external/mkldnn) # download, build, install mkldnn
include(external/eigen) # download eigen3
include(external/xxhash) # download install xxhash needed for x86 jit
include(cudnn)
include(configure) # add paddle env configuration
if(LITE_WITH_CUDA)
include(cuda)
endif()
include(generic) # simplify cmake module
include(ccache) # set ccache for compilation
include(util) # set unittest and link libs
......
......@@ -61,8 +61,8 @@ if(WITH_PSLIB)
add_definitions(-DPADDLE_WITH_PSLIB)
endif()
if(LITE_WITH_GPU)
add_definitions(-DPADDLE_WITH_CUDA)
if(LITE_WITH_CUDA)
add_definitions(-DLITE_WITH_CUDA)
add_definitions(-DEIGEN_USE_GPU)
FIND_PACKAGE(CUDA REQUIRED)
......@@ -86,36 +86,6 @@ if(LITE_WITH_GPU)
include_directories(${CUDNN_INCLUDE_DIR})
include_directories(${CUDA_TOOLKIT_INCLUDE})
if(TENSORRT_FOUND)
if(${CUDA_VERSION_MAJOR} VERSION_LESS 8)
message(FATAL_ERROR "TensorRT needs CUDA >= 8.0 to compile")
endif()
if(${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
message(FATAL_ERROR "TensorRT needs CUDNN >= 7.0 to compile")
endif()
if(${TENSORRT_MAJOR_VERSION} VERSION_LESS 4)
message(FATAL_ERROR "Paddle needs TensorRT >= 4.0 to compile")
endif()
include_directories(${TENSORRT_INCLUDE_DIR})
endif()
if(WITH_ANAKIN)
if(${CUDA_VERSION_MAJOR} VERSION_LESS 8)
message(WARNING "Anakin needs CUDA >= 8.0 to compile. Force WITH_ANAKIN=OFF")
set(WITH_ANAKIN OFF CACHE STRING "Anakin is valid only when CUDA >= 8.0." FORCE)
endif()
if(${CUDNN_MAJOR_VERSION} VERSION_LESS 7)
message(WARNING "Anakin needs CUDNN >= 7.0 to compile. Force WITH_ANAKIN=OFF")
set(WITH_ANAKIN OFF CACHE STRING "Anakin is valid only when CUDNN >= 7.0." FORCE)
endif()
add_definitions(-DWITH_ANAKIN)
endif()
if(WITH_ANAKIN)
# NOTICE(minqiyang): the end slash is important because $CUDNN_INCLUDE_DIR
# is a softlink to real cudnn.h directory
set(ENV{CUDNN_INCLUDE_DIR} "${CUDNN_INCLUDE_DIR}/")
get_filename_component(CUDNN_LIBRARY_DIR ${CUDNN_LIBRARY} DIRECTORY)
set(ENV{CUDNN_LIBRARY} ${CUDNN_LIBRARY_DIR})
endif()
elseif(WITH_AMD_GPU)
add_definitions(-DPADDLE_WITH_HIP)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -D__HIP_PLATFORM_HCC__")
......@@ -168,10 +138,6 @@ endif()
# for lite
# TODO(Superjomn) not work fine with the option
if (LITE_WITH_CUDA)
add_definitions("-DLITE_WITH_CUDA")
endif()
if (LITE_WITH_X86)
add_definitions("-DLITE_WITH_X86")
endif()
......
if(NOT WITH_GPU)
if(NOT LITE_WITH_CUDA)
return()
endif()
......
if(NOT WITH_GPU)
if(NOT LITE_WITH_CUDA)
return()
endif()
......@@ -34,10 +34,7 @@ list(APPEND CUDNN_CHECK_LIBRARY_DIRS
${CUDA_TOOLKIT_ROOT_DIR}
${CUDA_TOOLKIT_ROOT_DIR}/lib/x64
)
set(CUDNN_LIB_NAME "")
if (LINUX)
set(CUDNN_LIB_NAME "libcudnn.so")
endif(LINUX)
if(WIN32)
# only support cudnn7
......
......@@ -387,8 +387,8 @@ function(cc_test TARGET_NAME)
endif()
endif(WIN32)
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} ${os_dependency_modules} paddle_gtest_main lod_tensor memory gtest gflags glog)
add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
target_link_libraries(${TARGET_NAME} ${cc_test_DEPS} ${os_dependency_modules} paddle_gtest_main memory gtest gflags glog)
add_dependencies(${TARGET_NAME} ${cc_test_DEPS} paddle_gtest_main memory gtest gflags glog)
common_link(${TARGET_NAME})
add_test(NAME ${TARGET_NAME}
COMMAND ${TARGET_NAME} ${cc_test_ARGS}
......@@ -447,7 +447,7 @@ function(_lite_cc_test args)
endfunction()
function(nv_library TARGET_NAME)
if (WITH_GPU)
if (LITE_WITH_CUDA)
set(options STATIC static SHARED shared)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
......@@ -481,7 +481,7 @@ function(nv_library TARGET_NAME)
endfunction(nv_library)
function(nv_binary TARGET_NAME)
if (WITH_GPU)
if (LITE_WITH_CUDA)
set(options "")
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
......@@ -496,15 +496,15 @@ function(nv_binary TARGET_NAME)
endfunction(nv_binary)
function(nv_test TARGET_NAME)
if (WITH_GPU AND WITH_TESTING)
if (LITE_WITH_CUDA AND WITH_TESTING)
set(options SERIAL)
set(oneValueArgs "")
set(multiValueArgs SRCS DEPS)
cmake_parse_arguments(nv_test "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
cuda_add_executable(${TARGET_NAME} ${nv_test_SRCS})
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog ${os_dependency_modules})
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main lod_tensor memory gtest gflags glog)
target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main memory gtest gflags glog ${os_dependency_modules})
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} paddle_gtest_main memory gtest gflags glog)
common_link(${TARGET_NAME})
add_test(${TARGET_NAME} ${TARGET_NAME})
if (nv_test_SERIAL)
......
......@@ -8,7 +8,7 @@ if (WITH_TESTING)
lite_cc_library(lite_api_test_helper SRCS lite_api_test_helper.cc
DEPS scope optimizer target_wrapper_host model_parser program
${ops} ${host_kernels}
CUDA_DEPS kernels_cuda
CUDA_DEPS ${cuda_kernels}
X86_DEPS ${x86_kernels})
endif()
if(LITE_WITH_FPGA)
......@@ -27,11 +27,6 @@ message(STATUS "get FPGA kernels ${fpga_kernels}")
if (NOT LITE_ON_TINY_PUBLISH)
set(cxx_api_deps
scope optimizer target_wrapper_host model_parser program)
if(LITE_WITH_CUDA)
set(cxx_api_deps ${cxx_api_deps} kernels_cuda)
lite_cc_library(cxx_api_cuda SRCS cxx_api.cc DEPS ${cxx_api_deps} target_wrapper_cuda)
nv_test(test_cxx_api_cuda SRCS cxx_api_test.cc DEPS cxx_api_cuda)
endif()
lite_cc_library(cxx_api
SRCS cxx_api.cc
DEPS ${cxx_api_deps} ${ops} ${host_kernels} program
......@@ -51,7 +46,7 @@ endif()
lite_cc_library(light_api SRCS light_api.cc
DEPS scope target_wrapper_host model_parser
${light_api_deps} ${ops} ${host_kernels} program
CUDA_DEPS target_wrapper_cuda
CUDA_DEPS ${cuda_kernels}
X86_DEPS ${x86_kernels}
ARM_DEPS ${arm_kernels}
NPU_DEPS ${npu_kernels} ${npu_bridges} npu_pass
......
......@@ -145,6 +145,7 @@ USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device);
USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host);
USE_LITE_KERNEL(io_copy_once, kCUDA, kAny, kAny, host_to_device);
USE_LITE_KERNEL(io_copy_once, kCUDA, kAny, kAny, device_to_host);
USE_LITE_KERNEL(leaky_relu, kCUDA, kFloat, kNCHW, def);
#endif
#ifdef LITE_WITH_OPENCL
......
......@@ -22,7 +22,7 @@
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/core/cpu_info.h"
#include "lite/core/device_info.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/string.h"
......
......@@ -20,7 +20,7 @@
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/core/cpu_info.h"
#include "lite/core/device_info.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/string.h"
......
......@@ -15,7 +15,7 @@
#pragma once
#include <cmath>
#include "lite/core/context.h"
#include "lite/core/cpu_info.h"
#include "lite/core/device_info.h"
#include "lite/core/tensor.h"
namespace paddle {
......
......@@ -14,7 +14,7 @@
#pragma once
#include <cmath>
#include "lite/core/cpu_info.h"
#include "lite/core/device_info.h"
namespace paddle {
namespace lite {
......
......@@ -16,7 +16,7 @@
#include <cmath>
#include "lite/core/context.h"
#include "lite/core/cpu_info.h"
#include "lite/core/device_info.h"
#include "lite/core/tensor.h"
namespace paddle {
......
......@@ -17,7 +17,7 @@
#include <cmath>
#include "lite/arm/math/packed_sgemm.h"
#include "lite/core/context.h"
#include "lite/core/cpu_info.h"
#include "lite/core/device_info.h"
namespace paddle {
namespace lite {
......
......@@ -30,12 +30,12 @@ lite_cc_library(types SRCS types.cc)
endif()
lite_cc_library(op_registry SRCS op_registry.cc DEPS kernel)
lite_cc_library(scope SRCS scope.cc DEPS tensor)
lite_cc_library(cpu_info SRCS cpu_info.cc DEPS tensor)
lite_cc_library(device_info SRCS device_info.cc DEPS tensor)
if (LITE_WITH_ARM)
lite_cc_library(context SRCS context.cc DEPS tensor any cpu_info CL_DEPS cl_context gflags NPU_DEPS ${npu_ddk_libs})
lite_cc_library(context SRCS context.cc DEPS tensor any device_info CL_DEPS cl_context gflags NPU_DEPS ${npu_ddk_libs})
else()
lite_cc_library(context SRCS context.cc DEPS tensor any cpu_info eigen3 CL_DEPS cl_context gflags)
lite_cc_library(context SRCS context.cc DEPS tensor any device_info eigen3 CL_DEPS cl_context gflags)
endif()
#----------------------------------------------- NOT CHANGE -----------------------------------------------
......
......@@ -35,7 +35,7 @@
#include <string>
#include <utility>
#include <vector>
#include "lite/core/cpu_info.h"
#include "lite/core/device_info.h"
#include "lite/core/target_wrapper.h"
#include "lite/core/tensor.h"
#include "lite/utils/all.h"
......@@ -153,11 +153,39 @@ class Context<TargetType::kFPGA> {
template <>
class Context<TargetType::kCUDA> {
public:
typename Env<TargetType::kCUDA>::Devs& devs =
Env<TargetType::kCUDA>::Global();
// NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() {
cublas_fp32_ = std::make_shared<lite::cuda::Blas<float>>();
}
void Init(int dev_id, int exec_stream_id = 0, int io_stream_id = 0) {
CHECK_GT(devs.size(), 0)
<< "Env is not initialized or current target is not exit!";
if (dev_id >= devs.size()) {
LOG(WARNING) << "device index exceeds the number of devices, set to "
"default device(0)!";
device_id_ = 0;
} else {
device_id_ = dev_id;
}
if (io_stream_id >= devs[dev_id].max_stream()) {
LOG(WARNING) << "data stream index exceeds the maximum stream number, "
"set to default stream(0)!";
io_stream_id = 0;
}
if (exec_stream_id >= devs[dev_id].max_stream()) {
LOG(WARNING) << "exec stream index exceeds the maximum stream number, "
"set to default stream(0)!";
exec_stream_id = 0;
}
exec_stream_ = devs[dev_id].exec_streams()[exec_stream_id];
io_stream_ = devs[dev_id].io_streams()[io_stream_id];
exec_stream_id_ = exec_stream_id;
io_stream_id_ = io_stream_id;
}
void CopySharedTo(CUDAContext* ctx) {
CHECK(ctx);
CHECK(cublas_fp32_) << "cublas_fp32 should be set first";
......@@ -190,7 +218,10 @@ class Context<TargetType::kCUDA> {
std::string name() const { return "CUDAContext"; }
private:
int device_id_;
// overall information
int exec_stream_id_;
int io_stream_id_;
cudaStream_t exec_stream_;
cudaStream_t io_stream_;
......@@ -292,10 +323,13 @@ class ContextScheduler {
break;
#endif
#ifdef LITE_WITH_CUDA
case TARGET(kCUDA):
case TARGET(kCUDA): {
int dev_id = TargetWrapper<TargetType::kCUDA>::GetCurDevice();
auto& context = ctx->As<CUDAContext>();
context.Init(dev_id);
kernel_contexts_[TargetType::kCUDA].As<CUDAContext>().CopySharedTo(
&ctx->As<CUDAContext>());
break;
&context);
} break;
#endif
#ifdef LITE_WITH_ARM
case TARGET(kARM):
......
......@@ -50,7 +50,7 @@
#include <algorithm>
#include <limits>
#include "lite/core/cpu_info.h"
#include "lite/core/device_info.h"
namespace paddle {
namespace lite {
......@@ -1114,5 +1114,39 @@ bool DeviceInfo::ExtendWorkspace(int size) {
#endif // LITE_WITH_ARM
#ifdef LITE_WITH_CUDA
void Device<TARGET(kCUDA)>::Init() {
GetInfo();
CreateStream();
}
void Device<TARGET(kCUDA)>::GetInfo() {
cudaGetDeviceProperties(&device_prop_, idx_);
cudaRuntimeGetVersion(&runtime_version_);
sm_version_ = (device_prop_.major << 8 | device_prop_.minor);
has_hmma_ =
(sm_version_ == 0x0700 || sm_version_ == 0x0702 || sm_version_ == 0x0705);
has_fp16_ = (sm_version_ == 0x0602 || sm_version_ == 0x0600 ||
sm_version_ == 0x0503 || has_hmma_);
has_imma_ = (sm_version_ == 0x0702 || sm_version_ == 0x0705);
has_int8_ = (sm_version_ == 0x0601 || sm_version_ == 0x0700 || has_imma_);
}
void Device<TARGET(kCUDA)>::CreateStream() {
exec_stream_.clear();
io_stream_.clear();
for (int i = 0; i < max_stream_; i++) {
cudaStream_t exec_stream;
cudaStream_t io_stream;
cudaStreamCreate(&exec_stream);
cudaStreamCreate(&io_stream);
exec_stream_.push_back(exec_stream);
io_stream_.push_back(io_stream);
}
}
#endif
} // namespace lite
} // namespace paddle
......@@ -122,5 +122,88 @@ class DeviceInfo {
#endif // LITE_WITH_ARM
template <TargetType Type>
class Device;
template <TargetType Type>
class Env {
public:
typedef TargetWrapper<Type> API;
typedef std::vector<Device<Type>> Devs;
static Devs& Global() {
static Devs* devs = new Devs();
return *devs;
}
static void Init(int max_stream = 4) {
Devs& devs = Global();
if (devs.size() > 0) {
return;
}
int count = 0;
// Get device count
count = API::num_devices();
if (count == 0) {
CHECK(false) << "No device found!";
} else {
LOG(INFO) << "Found " << count << " device(s)";
}
// create all device
for (int i = 0; i < count; i++) {
auto dev = Device<Type>(i, max_stream);
dev.Init();
devs.push_back(dev);
}
LOG(INFO) << "dev size = " << devs.size();
}
};
#ifdef LITE_WITH_CUDA
template <>
class Device<TARGET(kCUDA)> {
public:
Device(int dev_id, int max_stream = 1)
: idx_(dev_id), max_stream_(max_stream) {}
void Init();
int id() { return idx_; }
int max_stream() { return max_stream_; }
int SetId(int idx) { idx_ = idx; }
std::string name() { return device_prop_.name; }
int core_num() { return device_prop_.multiProcessorCount; }
float max_memory() { return device_prop_.totalGlobalMem / 1048576.; }
std::vector<cudaStream_t> exec_streams() { return exec_stream_; }
std::vector<cudaStream_t> io_streams() { return io_stream_; }
int sm_version() { return sm_version_; }
bool has_fp16() { return has_fp16_; }
bool has_int8() { return has_fp16_; }
bool has_hmma() { return has_fp16_; }
bool has_imma() { return has_fp16_; }
int runtime_version() { return runtime_version_; }
private:
void CreateStream();
void GetInfo();
private:
int max_stream_;
int idx_{0};
cudaDeviceProp device_prop_;
std::string device_name_;
float max_memory_;
int sm_version_;
bool has_fp16_;
bool has_int8_;
bool has_hmma_;
bool has_imma_;
int runtime_version_;
std::vector<cudaStream_t> exec_stream_;
std::vector<cudaStream_t> io_stream_;
};
template class Env<TARGET(kCUDA)>;
#endif
} // namespace lite
} // namespace paddle
......@@ -27,8 +27,7 @@ void* TargetMalloc(TargetType target, size_t size) {
break;
#ifdef LITE_WITH_CUDA
case TargetType::kCUDA:
data =
TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>::Malloc(size);
data = TargetWrapper<TARGET(kCUDA)>::Malloc(size);
break;
#endif // LITE_WITH_CUDA
#ifdef LITE_WITH_OPENCL
......
......@@ -21,6 +21,10 @@
#include "lite/opencl/target_wrapper.h"
#endif // LITE_WITH_OPENCL
#ifdef LITE_WITH_CUDA
#include "lite/cuda/target_wrapper.h"
#endif // LITE_WITH_CUDA
namespace paddle {
namespace lite {
......@@ -36,7 +40,7 @@ void LITE_API TargetFree(TargetType target, void* data);
void TargetCopy(TargetType target, void* dst, const void* src, size_t size);
template <TargetType Target>
void CopySync(void* dst, void* src, size_t size, IoDirection dir) {
void CopySync(void* dst, const void* src, size_t size, IoDirection dir) {
switch (Target) {
case TARGET(kX86):
case TARGET(kHost):
......
......@@ -26,7 +26,7 @@ TEST(memory, test) {
#ifdef LITE_WITH_CUDA
auto* buf_cuda = TargetMalloc(TARGET(kCUDA), 10);
ASSERT_TRUE(buf_cuda);
TargetFree(Target(kCUDA), buf_cuda);
TargetFree(TARGET(kCUDA), buf_cuda);
#endif
}
......
......@@ -105,6 +105,7 @@ void PatternMatcher::operator()(SSAGraph *graph,
ValidateByNodeRole(&subgraphs);
if (subgraphs.empty()) return;
LOG(INFO) << "detected " << subgraphs.size() << " subgraph";
int id = 0;
for (auto &g : subgraphs) {
VLOG(3) << "optimizing #" << id++ << " subgraph";
......
......@@ -165,44 +165,6 @@ class TargetWrapper<TARGET(kFPGA)> {
}
};
#endif
#ifdef LITE_WITH_CUDA
using TargetWrapperCuda =
TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;
// This interface should be specified by each kind of target.
template <>
class TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t> {
public:
using stream_t = cudaStream_t;
using event_t = cudaEvent_t;
static size_t num_devices() { return 0; }
static size_t maximum_stream() { return 0; }
static void CreateStream(stream_t* stream) {}
static void DestroyStream(const stream_t& stream) {}
static void CreateEvent(event_t* event) {}
static void DestroyEvent(const event_t& event) {}
static void RecordEvent(const event_t& event) {}
static void SyncEvent(const event_t& event) {}
static void StreamSync(const stream_t& stream) {}
static void* Malloc(size_t size);
static void Free(void* ptr);
static void MemcpySync(void* dst,
const void* src,
size_t size,
IoDirection dir);
static void MemcpyAsync(void* dst,
const void* src,
size_t size,
IoDirection dir,
const stream_t& stream);
};
#endif // LITE_WITH_CUDA
} // namespace lite
} // namespace paddle
......@@ -5,4 +5,3 @@ endif()
nv_library(target_wrapper_cuda SRCS target_wrapper.cc)
nv_library(cuda_blas SRCS blas.cc)
......@@ -17,20 +17,26 @@
namespace paddle {
namespace lite {
using TargetW = TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;
size_t TargetWrapperCuda::num_devices() {
int count = 0;
cudaGetDeviceCount(&count);
return count;
}
void* TargetW::Malloc(size_t size) {
void* TargetWrapperCuda::Malloc(size_t size) {
void* ptr{};
CHECK_EQ(cudaSuccess, cudaMalloc(&ptr, size));
return ptr;
}
void TargetW::Free(void* ptr) { CHECK_EQ(cudaSuccess, cudaFree(ptr)); }
void TargetWrapperCuda::Free(void* ptr) {
CHECK_EQ(cudaSuccess, cudaFree(ptr));
}
void TargetW::MemcpySync(void* dst,
const void* src,
size_t size,
IoDirection dir) {
void TargetWrapperCuda::MemcpySync(void* dst,
const void* src,
size_t size,
IoDirection dir) {
switch (dir) {
case IoDirection::DtoD:
CHECK(cudaSuccess ==
......@@ -47,11 +53,11 @@ void TargetW::MemcpySync(void* dst,
}
}
void TargetW::MemcpyAsync(void* dst,
const void* src,
size_t size,
IoDirection dir,
const stream_t& stream) {
void TargetWrapperCuda::MemcpyAsync(void* dst,
const void* src,
size_t size,
IoDirection dir,
const stream_t& stream) {
switch (dir) {
case IoDirection::DtoD:
CHECK(cudaSuccess ==
......
......@@ -19,11 +19,46 @@
namespace paddle {
namespace lite {
namespace cuda {
using TargetWrap = TargetWrapper<TARGET(kHost)>;
using TargetWrapAsync = TargetWrapper<TARGET(kHost), cudaStream_t, cudaEvent_t>;
using TargetWrapperCuda = TargetWrapper<TARGET(kCUDA)>;
} // namespace cuda
template <>
class TargetWrapper<TARGET(kCUDA)> {
public:
using stream_t = cudaStream_t;
using event_t = cudaEvent_t;
static size_t num_devices();
static size_t maximum_stream() { return 0; }
static size_t GetCurDevice() {
int dev_id;
cudaGetDevice(&dev_id);
return dev_id;
}
static void CreateStream(stream_t* stream) {}
static void DestroyStream(const stream_t& stream) {}
static void CreateEvent(event_t* event) {}
static void DestroyEvent(const event_t& event) {}
static void RecordEvent(const event_t& event) {}
static void SyncEvent(const event_t& event) {}
static void StreamSync(const stream_t& stream) {}
static void* Malloc(size_t size);
static void Free(void* ptr);
static void MemcpySync(void* dst,
const void* src,
size_t size,
IoDirection dir);
static void MemcpyAsync(void* dst,
const void* src,
size_t size,
IoDirection dir,
const stream_t& stream);
};
} // namespace lite
} // namespace paddle
......@@ -4,10 +4,16 @@ endif()
message(STATUS "compile with lite CUDA kernels")
nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS tensor)
lite_cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS tensor)
nv_library(mul_compute_cuda SRCS mul_compute.cc DEPS ${lite_kernel_deps} context)
lite_cc_library(io_copy_compute_cuda SRCS io_copy_compute.cc DEPS ${lite_kernel_deps})
nv_library(leaky_relu_compute_cuda SRCS leaky_relu_compute.cu DEPS ${lite_kernel_deps})
nv_library(kernels_cuda DEPS mul_compute_cuda io_copy_compute_cuda cuda_blas)
lite_cc_test(leaky_relu_compute_cuda_test SRCS leaky_relu_compute_test.cc DEPS leaky_relu_compute_cuda)
set(cuda_kernels
mul_compute_cuda
io_copy_compute_cuda
leaky_relu_compute_cuda
)
set(cuda_kernels "${cuda_kernels}" CACHE GLOBAL "cuda kernels")
......@@ -21,7 +21,7 @@ namespace lite {
namespace kernels {
namespace cuda {
using TargetW = TargetWrapper<TARGET(kCUDA), cudaStream_t, cudaEvent_t>;
using TargetW = TargetWrapper<TARGET(kCUDA)>;
// Host to CUDA memory.
void CopyFromHostSync(void* target, const void* source, size_t size) {
......@@ -89,7 +89,6 @@ class IoCopyCudaToHostCompute
auto& param = Param<operators::IoCopyParam>();
CHECK(param.x->target() == TARGET(kCUDA));
auto mem_size = param.x->memory_size();
LOG(INFO) << "copy size " << mem_size;
auto* data = param.y->mutable_data(TARGET(kHost), mem_size);
CopyToHostSync(data, param.x->raw_data(), mem_size);
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/leaky_relu_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T>
__global__ void LeakyReluKernel(const int num,
const T alpha,
const T* input,
T* output) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < num) {
#if __CUDA_ARCH__ >= 350
output[index] = __ldg(input + index) >= 0 ? __ldg(input + index)
: __ldg(input + index) * alpha;
#else
output[index] = input[index] >= 0 ? input[index] : input[index] * alpha;
#endif
}
}
void LeakyReluCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
int num = static_cast<int>(param.X->numel());
float alpha = param.Leaky_relu_alpha;
auto input = param.X->data<float>();
auto output = param.Out->mutable_data<float>(TARGET(kCUDA));
int threads = 1024;
int blocks = (num + threads - 1) / threads;
LeakyReluKernel<<<blocks, threads, 0, stream>>>(num, alpha, input, output);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(leaky_relu,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::LeakyReluCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class LeakyReluCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override;
virtual ~LeakyReluCompute() = default;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/leaky_relu_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
TEST(leaky_relu, normal) {
LeakyReluCompute leaky_relu_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::ActivationParam param;
Tensor x, y, x_cpu, y_cpu;
int h = 3, w = 3;
y.Resize({h, w});
x_cpu.Resize({h, w});
y_cpu.Resize({h, w});
auto* x_data = x.mutable_data<float>(TARGET(kCUDA));
auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
float* x_cpu_data = x_cpu.mutable_data<float>();
float* y_cpu_data = x_cpu.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); i++) {
x_cpu_data[i] = i - 5.0;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
param.X = &x;
param.Out = &y;
param.Leaky_relu_alpha = 10.0;
leaky_relu_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
leaky_relu_kernel.SetContext(std::move(ctx));
leaky_relu_kernel.Launch();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH);
for (int i = 0; i < y.numel(); i++) {
LOG(INFO) << y_cpu_data[i];
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include "lite/core/context.h"
#include "lite/core/kernel.h"
#include "lite/core/types.h"
#include "lite/cuda/blas.h"
......@@ -53,7 +54,7 @@ class MulCompute : public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
void Run() override {
CHECK(ctx_) << "running context should be set first";
auto& context = ctx_->As<CUDAContext>();
auto& context = this->ctx_->template As<CUDAContext>();
CHECK(context.cublas_fp32()) << "blas should init first";
/*
auto& blas = *context.cublas_fp32();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册