提交 0978d339 编写于 作者: W wanghaoshuang

Merge branch 'develop' of https://github.com/paddlepaddle/paddle into voc_dataset

......@@ -266,29 +266,6 @@ HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) {
return ((0 <= idx.head) && (idx.head < size.head));
}
/**
* \brief Check if a size and a stride create a Fortran order contiguous
* block of memory.
*/
template <int i>
HOST bool contiguous(const Dim<i>& size, const Dim<i>& stride, int mul = 1) {
if (product(size) == 0) return true;
int contiguous_stride = get<0>(size) == 1 ? 0 : mul;
return (get<0>(stride) == contiguous_stride &&
contiguous(size.tail, stride.tail, mul * get<0>(size)));
}
///\cond HIDDEN
// Base case of contiguous, check the nth stride is the size of
// the prefix multiply of n-1 dims.
template <>
inline bool contiguous(const Dim<1>& size, const Dim<1>& stride, int mul) {
if (get<0>(size) == 0) return true;
int contiguous_stride = get<0>(size) == 1 ? 0 : mul;
return get<0>(stride) == contiguous_stride;
}
///\endcond
/**
* \brief Compute exclusive prefix-multiply of a Dim.
*/
......@@ -306,31 +283,6 @@ HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) {
}
///\endcond
/**
* \brief Calculate strides of a contiguous array of the given size
*
* Sets the stride for any dimension with an extent of 1 to 0.
* \param size Dim object containing the size of the array.
* \param base The base stride to use.
* \return Dim object the same size as \p size with the strides.
*/
template <int i>
HOSTDEVICE Dim<i> contiguous_strides(const Dim<i>& size, int base = 1) {
int stride = size.head == 1 ? 0 : base;
return Dim<i>(stride, contiguous_strides(size.tail, base * size.head));
}
///\cond HIDDEN
// Base case of contiguous_strides
template <>
HOSTDEVICE inline Dim<1> contiguous_strides(const Dim<1>& size, int base) {
int stride = size.head == 1 ? 0 : base;
return Dim<1>(stride);
}
///\endcond
/**
* Add two dimensions together
*/
......
......@@ -58,24 +58,6 @@ TEST(Dim, Equality) {
EXPECT_EQ(paddle::framework::get<1>(c), 3);
EXPECT_EQ(paddle::framework::get<2>(c), 12);
// contiguous_strides
c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(10, 1, 10));
EXPECT_EQ(paddle::framework::get<0>(c), 1);
EXPECT_EQ(paddle::framework::get<1>(c), 0);
EXPECT_EQ(paddle::framework::get<2>(c), 10);
c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(10, 10, 1));
EXPECT_EQ(paddle::framework::get<0>(c), 1);
EXPECT_EQ(paddle::framework::get<1>(c), 10);
EXPECT_EQ(paddle::framework::get<2>(c), 0);
c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(1, 10, 10));
EXPECT_EQ(paddle::framework::get<0>(c), 0);
EXPECT_EQ(paddle::framework::get<1>(c), 1);
EXPECT_EQ(paddle::framework::get<2>(c), 10);
c = paddle::framework::contiguous_strides(paddle::framework::Dim<3>(2, 3, 4));
EXPECT_EQ(paddle::framework::get<0>(c), 1);
EXPECT_EQ(paddle::framework::get<1>(c), 2);
EXPECT_EQ(paddle::framework::get<2>(c), 6);
// generate from an index
auto size = paddle::framework::make_dim(4, 5, 2);
c = paddle::framework::Dim<3>(14, size);
......@@ -101,16 +83,6 @@ TEST(Dim, Bool) {
EXPECT_TRUE(a == a);
EXPECT_FALSE(a == b);
EXPECT_TRUE(a == c);
// contiguous check
int x = 4, y = 5, z = 2;
paddle::framework::Dim<3> sizef(x, y, z);
paddle::framework::Dim<3> stridea(1, x, x*y);
paddle::framework::Dim<3> strideb(2, 2*x, 2*x*y);
paddle::framework::Dim<3> stridec(1, x, 2*x*y);
EXPECT_TRUE(paddle::framework::contiguous(sizef, stridea));
EXPECT_FALSE(paddle::framework::contiguous(sizef, strideb));
EXPECT_FALSE(paddle::framework::contiguous(sizef, stridec));
}
TEST(Dim, Print) {
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include <cstdint>
#include <memory>
#include <type_traits>
#include "paddle/framework/ddim.h"
......@@ -26,31 +27,65 @@ namespace framework {
class Tensor {
public:
Tensor() : offset_(0) {}
explicit Tensor(const DDim& dims) : dims_(dims), offset_(0) {}
template <typename T>
const T* data() const {
PADDLE_ENFORCE(holder_ != nullptr,
"Tensor::data must be called after Tensor::mutable_data.");
return static_cast<const T*>(holder_->Ptr());
PADDLE_ENFORCE(
holder_ != nullptr,
"Tenosr has not been initialized. Call Tensor::mutable_data first.");
return reinterpret_cast<const T*>(
reinterpret_cast<uintptr_t>(holder_->Ptr()) + offset_);
}
template <typename T, // must be POD types
typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
T* mutable_data(DDim dims, paddle::platform::Place place) {
dims_ = dims;
if (holder_ == nullptr ||
!(holder_->Place() ==
place) /* some versions of boost::variant don't have operator!= */
|| holder_->Size() < product(dims) * sizeof(T)) {
|| holder_->Size() < product(dims) * sizeof(T) + offset_) {
holder_.reset(new PlaceholderImpl<T>(place, product(dims) * sizeof(T)));
offset_ = 0;
}
return static_cast<T*>(holder_->Ptr());
return reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(holder_->Ptr()) +
offset_);
}
template <typename T, // must be POD types
typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
T* mutable_data(DDim dims) {
return mutable_data<T>(dims, paddle::platform::get_place());
void ShareDataFrom(const Tensor& src) {
PADDLE_ENFORCE(src.holder_ != nullptr,
"Can not share data from an uninitialized tensor.");
holder_ = src.holder_;
dims_ = src.dims_;
offset_ = src.offset_;
}
Tensor Slice(const int& begin_idx, const int& end_idx) const {
PADDLE_ENFORCE(holder_ != nullptr,
"The sliced tenosr has not been initialized.");
PADDLE_ENFORCE(begin_idx >= 0 && end_idx <= dims_[0],
"Slice index is less than zero or out of bound.");
PADDLE_ENFORCE(begin_idx < end_idx,
"Begin index must be less than end index.");
PADDLE_ENFORCE(dims_[0] != 1, "Can not slice a tensor with dims_[0] = 1.");
std::vector<int> d = vectorize(dims_);
int base = 1;
for (size_t i = 1; i < d.size(); ++i) {
base *= d[i];
}
Tensor dst;
dst.holder_ = holder_;
dst.dims_ = dims_;
dst.dims_[0] = end_idx - begin_idx;
dst.offset_ = offset_ + begin_idx * base * holder_->TypeSize();
return dst;
}
DDim dims() const { return dims_; }
private:
// Placeholder hides type T, so it doesn't appear as a template
// parameter of Variable.
......@@ -59,6 +94,7 @@ class Tensor {
virtual void* Ptr() const = 0;
virtual paddle::platform::Place Place() const = 0;
virtual size_t Size() const = 0;
virtual size_t TypeSize() const = 0;
};
template <typename T>
......@@ -85,6 +121,7 @@ class Tensor {
virtual void* Ptr() const { return static_cast<void*>(ptr_.get()); }
virtual size_t Size() const { return size_; }
virtual paddle::platform::Place Place() const { return place_; }
virtual size_t TypeSize() const { return sizeof(T); }
std::unique_ptr<T, Deleter> ptr_;
paddle::platform::Place place_; // record the place of ptr_.
......@@ -92,6 +129,8 @@ class Tensor {
};
std::shared_ptr<Placeholder> holder_; // holds the memory block if allocated.
DDim dims_;
size_t offset_; // marks the begin of tensor data area.
};
} // namespace framework
......
......@@ -15,15 +15,27 @@
#include <gtest/gtest.h>
#include <string>
TEST(Tensor, ASSERT) {
paddle::framework::Tensor cpu_tensor;
TEST(Tensor, Dims) {
using namespace paddle::framework;
using namespace paddle::platform;
Tensor tt(make_ddim({2, 3, 4}));
DDim dims = tt.dims();
ASSERT_EQ(arity(dims), 3);
for (int i = 0; i < 3; ++i) {
EXPECT_EQ(i + 2, dims[i]);
}
}
TEST(Tensor, DataAssert) {
paddle::framework::Tensor src_tensor;
bool caught = false;
try {
const double* p __attribute__((unused)) = cpu_tensor.data<double>();
src_tensor.data<double>();
} catch (paddle::framework::EnforceNotMet err) {
caught = true;
std::string msg = "Tensor::data must be called after Tensor::mutable_data.";
std::string msg =
"Tenosr has not been initialized. Call Tensor::mutable_data first.";
const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
......@@ -32,54 +44,138 @@ TEST(Tensor, ASSERT) {
ASSERT_TRUE(caught);
}
/* mutable_data() is not tested at present
/* following tests are not available at present
because Memory::Alloc() and Memory::Free() have not been ready.
TEST(Tensor, MutableData) {
using namespace paddle::framework;
using namespace paddle::platform;
{
Tensor cpu_tensor;
Tensor src_tensor;
float* p1 = nullptr;
float* p2 = nullptr;
// initialization
p1 = cpu_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace());
p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace());
EXPECT_NE(p1, nullptr);
// set cpu_tensor a new dim with large size
// set src_tensor a new dim with large size
// momery is supposed to be re-allocated
p2 = cpu_tensor.mutable_data<float>(make_ddim({3, 4}));
p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), CPUPlace());
EXPECT_NE(p2, nullptr);
EXPECT_NE(p1, p2);
// set cpu_tensor a new dim with same size
// set src_tensor a new dim with same size
// momery block is supposed to be unchanged
p1 = cpu_tensor.mutable_data<float>(make_ddim({2, 2, 3}));
p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), CPUPlace());
EXPECT_EQ(p1, p2);
// set cpu_tensor a new dim with smaller size
// set src_tensor a new dim with smaller size
// momery block is supposed to be unchanged
p2 = cpu_tensor.mutable_data<float>(make_ddim({2, 2}));
p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CPUPlace());
EXPECT_EQ(p1, p2);
}
{
Tensor gpu_tensor;
Tensor src_tensor;
float* p1 = nullptr;
float* p2 = nullptr;
// initialization
p1 = gpu_tensor.mutable_data<float>(make_ddim({1, 2, 3}), GPUPlace());
p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), GPUPlace());
EXPECT_NE(p1, nullptr);
// set gpu_tensor a new dim with large size
// set src_tensor a new dim with large size
// momery is supposed to be re-allocated
p2 = gpu_tensor.mutable_data<float>(make_ddim({3, 4}));
p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), GPUPlace());
EXPECT_NE(p2, nullptr);
EXPECT_NE(p1, p2);
// set gpu_tensor a new dim with same size
// set src_tensor a new dim with same size
// momery block is supposed to be unchanged
p1 = gpu_tensor.mutable_data<float>(make_ddim({2, 2, 3}));
p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), GPUPlace());
EXPECT_EQ(p1, p2);
// set gpu_tensor a new dim with smaller size
// set src_tensor a new dim with smaller size
// momery block is supposed to be unchanged
p2 = gpu_tensor.mutable_data<float>(make_ddim({2, 2}));
p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), GPUPlace());
EXPECT_EQ(p1, p2);
}
}
*/
TEST(Tensor, ShareDataFrom) {
using namespace paddle::framework;
using namespace paddle::platform;
{
Tensor src_tensor;
Tensor dst_tensor;
// Try to share data form uninitialized tensor
bool caught = false;
try {
dst_tensor.ShareDataFrom(src_tensor);
} catch (EnforceNotMet err) {
caught = true;
std::string msg = "Can not share data from an uninitialized tensor.";
const char* what = err.what();
for (size_t i = 0; i < msg.length(); ++i) {
ASSERT_EQ(what[i], msg[i]);
}
}
ASSERT_TRUE(caught);
src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), CPUPlace());
dst_tensor.ShareDataFrom(src_tensor);
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
}
{
Tensor src_tensor;
Tensor dst_tensor;
src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), GPUPlace());
dst_tensor.ShareDataFrom(src_tensor);
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
}
}
TEST(Tensor, Slice) {
using namespace paddle::framework;
using namespace paddle::platform;
{
Tensor src_tensor;
src_tensor.mutable_data<int>(make_ddim({5, 3, 4}), CPUPlace());
Tensor slice_tensor = src_tensor.Slice(1, 3);
DDim slice_dims = slice_tensor.dims();
ASSERT_EQ(arity(slice_dims), 3);
EXPECT_EQ(slice_dims[0], 2);
EXPECT_EQ(slice_dims[1], 3);
EXPECT_EQ(slice_dims[2], 4);
uintptr_t src_data_address =
reinterpret_cast<uintptr_t>(src_tensor.data<int>());
uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
src_tensor.mutable_data<int>(src_tensor.dims(), CPUPlace()));
uintptr_t slice_data_address =
reinterpret_cast<uintptr_t>(slice_tensor.data<int>());
uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>(
slice_tensor.mutable_data<int>(slice_tensor.dims(), CPUPlace()));
EXPECT_EQ(src_data_address, src_mutable_data_address);
EXPECT_EQ(slice_data_address, slice_mutable_data_address);
EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address);
}
{
Tensor src_tensor;
src_tensor.mutable_data<double>(make_ddim({6, 9}), GPUPlace());
Tensor slice_tensor = src_tensor.Slice(2, 6);
DDim slice_dims = slice_tensor.dims();
ASSERT_EQ(arity(slice_dims), 2);
EXPECT_EQ(slice_dims[0], 4);
EXPECT_EQ(slice_dims[1], 9);
uintptr_t src_data_address =
reinterpret_cast<uintptr_t>(src_tensor.data<double>());
uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
src_tensor.mutable_data<double>(src_tensor.dims(), GPUPlace()));
uintptr_t slice_data_address =
reinterpret_cast<uintptr_t>(slice_tensor.data<double>());
uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>(
slice_tensor.mutable_data<double>(slice_tensor.dims(), GPUPlace()));
EXPECT_EQ(src_data_address, src_mutable_data_address);
EXPECT_EQ(slice_data_address, slice_mutable_data_address);
EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address);
}
}
*/
\ No newline at end of file
add_subdirectory(dynload)
nv_test(cuda_test SRCS cuda_test.cu)
nv_test(cuda_test SRCS cuda_test.cu DEPS dyload_cuda)
cc_library(place SRCS place.cc)
cc_test(place_test SRCS place_test.cc DEPS place glog gflags)
IF(WITH_GPU)
set(GPU_CTX_DEPS dyload_cuda dynamic_loader )
ELSE()
set(GPU_CTX_DEPS)
ENDIF()
nv_test(device_context_test SRCS device_context_test.cc DEPS dynamic_loader place eigen3 glog gflags)
cc_library(device_context SRCS device_context.cc DEPS place eigen3 ${GPU_CTX_DEPS})
nv_test(device_context_test SRCS device_context_test.cc DEPS device_context glog gflags)
......@@ -28,19 +28,19 @@ inline void throw_on_error(cudaError_t e, const char* message) {
}
}
int GetDeviceCount(void) {
inline int GetDeviceCount(void) {
int count;
throw_on_error(cudaGetDeviceCount(&count), "cudaGetDeviceCount failed");
return count;
}
int GetCurrentDeviceId(void) {
inline int GetCurrentDeviceId(void) {
int device_id;
throw_on_error(cudaGetDevice(&device_id), "cudaGetDevice failed");
return device_id;
}
void SetDeviceId(int device_id) {
inline void SetDeviceId(int device_id) {
throw_on_error(cudaSetDevice(device_id), "cudaSetDevice failed");
}
......
#include <paddle/platform/device_context.h>
namespace paddle {
namespace platform {
namespace dynload {
namespace dummy {
// Make DeviceContext A library.
int DUMMY_VAR_FOR_DEV_CTX = 0;
} // namespace dummy
} // namespace dynload
} // namespace platform
} // namespace paddle
\ No newline at end of file
cc_library(dynamic_loader SRCS dynamic_loader.cc DEPS glog gflags)
nv_library(dyload_cuda SRCS cublas.cc cudnn.cc curand.cc)
#include <paddle/platform/dynload/cublas.h>
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag cublas_dso_flag;
void *cublas_dso_handle = nullptr;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
CUBLAS_BLAS_ROUTINE_EACH(DEFINE_WRAP);
} // namespace dynload
} // namespace platform
} // namespace paddle
......@@ -23,8 +23,8 @@ namespace paddle {
namespace platform {
namespace dynload {
std::once_flag cublas_dso_flag;
void *cublas_dso_handle = nullptr;
extern std::once_flag cublas_dso_flag;
extern void *cublas_dso_handle;
/**
* The following macro definition can generate structs
......@@ -34,10 +34,10 @@ void *cublas_dso_handle = nullptr;
* note: default dynamic linked libs
*/
#ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
#define DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
cublasStatus_t operator()(Args... args) { \
inline cublasStatus_t operator()(Args... args) { \
typedef cublasStatus_t (*cublasFunc)(Args...); \
std::call_once(cublas_dso_flag, \
paddle::platform::dynload::GetCublasDsoHandle, \
......@@ -45,62 +45,46 @@ void *cublas_dso_handle = nullptr;
void *p_##__name = dlsym(cublas_dso_handle, #__name); \
return reinterpret_cast<cublasFunc>(p_##__name)(args...); \
} \
} __name; // struct DynLoad__##__name
}; \
extern DynLoad__##__name __name
#else
#define DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
cublasStatus_t operator()(Args... args) { \
return __name(args...); \
} \
} __name; // struct DynLoad__##__name
#define DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) \
struct DynLoad__##__name { \
inline template <typename... Args> \
cublasStatus_t operator()(Args... args) { \
return __name(args...); \
} \
}; \
extern DynLoad__##__name __name
#endif
#define DYNAMIC_LOAD_CUBLAS_V2_WRAP(__name) DYNAMIC_LOAD_CUBLAS_WRAP(__name)
#define DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(__name) \
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name)
// include all needed cublas functions in HPPL
// clang-format off
#define CUBLAS_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasSgemv) \
__macro(cublasDgemv) \
__macro(cublasSgemm) \
__macro(cublasDgemm) \
__macro(cublasSgeam) \
__macro(cublasDgeam) \
__macro(cublasSgemv); \
__macro(cublasDgemv); \
__macro(cublasSgemm); \
__macro(cublasDgemm); \
__macro(cublasSgeam); \
__macro(cublasDgeam);
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasCreate)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasDestroy)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetStream)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetPointerMode)
DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasGetPointerMode)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetrfBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetriBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetrfBatched)
DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetriBatched)
CUBLAS_BLAS_ROUTINE_EACH(DYNAMIC_LOAD_CUBLAS_V2_WRAP)
DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasCreate);
DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasDestroy);
DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetStream);
DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasSetPointerMode);
DECLARE_DYNAMIC_LOAD_CUBLAS_V2_WRAP(cublasGetPointerMode);
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgemmBatched);
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgemmBatched);
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasCgemmBatched);
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasZgemmBatched);
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetrfBatched);
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasSgetriBatched);
DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(cublasDgetrfBatched);
#undef DYNAMIC_LOAD_CUBLAS_WRAP
#undef DYNAMIC_LOAD_CUBLAS_V2_WRAP
#undef CUBLAS_BLAS_ROUTINE_EACH
CUBLAS_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP);
// clang-format on
#ifndef PADDLE_TYPE_DOUBLE
#define CUBLAS_GEAM paddle::platform::dynload::cublasSgeam
#define CUBLAS_GEMV paddle::platform::dynload::cublasSgemv
#define CUBLAS_GEMM paddle::platform::dynload::cublasSgemm
#define CUBLAS_GETRF paddle::platform::dynload::cublasSgetrfBatched
#define CUBLAS_GETRI paddle::platform::dynload::cublasSgetriBatched
#else
#define CUBLAS_GEAM paddle::platform::dynload::cublasDgeam
#define CUBLAS_GEMV paddle::platform::dynload::cublasDgemv
#define CUBLAS_GEMM paddle::platform::dynload::cublasDgemm
#define CUBLAS_GETRF paddle::platform::dynload::cublasDgetrfBatched
#define CUBLAS_GETRI paddle::platform::dynload::cublasDgetriBatched
#endif
#undef DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP
} // namespace dynload
} // namespace platform
} // namespace paddle
#include <paddle/platform/dynload/cudnn.h>
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag cudnn_dso_flag;
void* cudnn_dso_handle = nullptr;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
CUDNN_DNN_ROUTINE_EACH(DEFINE_WRAP);
CUDNN_DNN_ROUTINE_EACH_R2(DEFINE_WRAP);
#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_R3
CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DEFINE_WRAP);
#endif
#ifdef CUDNN_DNN_ROUTINE_EACH_AFTER_R4
CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DEFINE_WRAP);
#endif
#ifdef CUDNN_DNN_ROUTINE_EACH_R5
CUDNN_DNN_ROUTINE_EACH_R5(DEFINE_WRAP);
#endif
} // namespace dynload
} // namespace platform
} // namespace paddle
\ No newline at end of file
......@@ -23,12 +23,12 @@ namespace paddle {
namespace platform {
namespace dynload {
std::once_flag cudnn_dso_flag;
void* cudnn_dso_handle = nullptr;
extern std::once_flag cudnn_dso_flag;
extern void* cudnn_dso_handle;
#ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \
#define DECLARE_DYNAMIC_LOAD_CUDNN_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
......@@ -39,17 +39,19 @@ void* cudnn_dso_handle = nullptr;
void* p_##__name = dlsym(cudnn_dso_handle, #__name); \
return reinterpret_cast<cudnn_func>(p_##__name)(args...); \
} \
} __name; /* struct DynLoad__##__name */
}; \
extern struct DynLoad__##__name __name
#else
#define DYNAMIC_LOAD_CUDNN_WRAP(__name) \
#define DECLARE_DYNAMIC_LOAD_CUDNN_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
return __name(args...); \
} \
} __name; /* struct DynLoad__##__name */
}; \
extern DynLoad__##__name __name
#endif
......@@ -57,80 +59,73 @@ void* cudnn_dso_handle = nullptr;
* include all needed cudnn functions in HPPL
* different cudnn version has different interfaces
**/
// clang-format off
#define CUDNN_DNN_ROUTINE_EACH(__macro) \
__macro(cudnnSetTensor4dDescriptor) \
__macro(cudnnSetTensor4dDescriptorEx) \
__macro(cudnnGetConvolutionNdForwardOutputDim) \
__macro(cudnnGetConvolutionForwardAlgorithm) \
__macro(cudnnCreateTensorDescriptor) \
__macro(cudnnDestroyTensorDescriptor) \
__macro(cudnnCreateFilterDescriptor) \
__macro(cudnnSetFilter4dDescriptor) \
__macro(cudnnSetPooling2dDescriptor) \
__macro(cudnnDestroyFilterDescriptor) \
__macro(cudnnCreateConvolutionDescriptor) \
__macro(cudnnCreatePoolingDescriptor) \
__macro(cudnnDestroyPoolingDescriptor) \
__macro(cudnnSetConvolution2dDescriptor) \
__macro(cudnnDestroyConvolutionDescriptor) \
__macro(cudnnCreate) \
__macro(cudnnDestroy) \
__macro(cudnnSetStream) \
__macro(cudnnActivationForward) \
__macro(cudnnConvolutionForward) \
__macro(cudnnConvolutionBackwardBias) \
__macro(cudnnGetConvolutionForwardWorkspaceSize) \
__macro(cudnnTransformTensor) \
__macro(cudnnPoolingForward) \
__macro(cudnnPoolingBackward) \
__macro(cudnnSoftmaxBackward) \
__macro(cudnnSoftmaxForward) \
__macro(cudnnGetVersion) \
__macro(cudnnGetErrorString)
CUDNN_DNN_ROUTINE_EACH(DYNAMIC_LOAD_CUDNN_WRAP)
#define CUDNN_DNN_ROUTINE_EACH_R2(__macro) \
__macro(cudnnAddTensor) \
__macro(cudnnConvolutionBackwardData) \
__macro(cudnnConvolutionBackwardFilter)
CUDNN_DNN_ROUTINE_EACH_R2(DYNAMIC_LOAD_CUDNN_WRAP)
#define CUDNN_DNN_ROUTINE_EACH(__macro) \
__macro(cudnnSetTensor4dDescriptor); \
__macro(cudnnSetTensor4dDescriptorEx); \
__macro(cudnnGetConvolutionNdForwardOutputDim); \
__macro(cudnnGetConvolutionForwardAlgorithm); \
__macro(cudnnCreateTensorDescriptor); \
__macro(cudnnDestroyTensorDescriptor); \
__macro(cudnnCreateFilterDescriptor); \
__macro(cudnnSetFilter4dDescriptor); \
__macro(cudnnSetPooling2dDescriptor); \
__macro(cudnnDestroyFilterDescriptor); \
__macro(cudnnCreateConvolutionDescriptor); \
__macro(cudnnCreatePoolingDescriptor); \
__macro(cudnnDestroyPoolingDescriptor); \
__macro(cudnnSetConvolution2dDescriptor); \
__macro(cudnnDestroyConvolutionDescriptor); \
__macro(cudnnCreate); \
__macro(cudnnDestroy); \
__macro(cudnnSetStream); \
__macro(cudnnActivationForward); \
__macro(cudnnConvolutionForward); \
__macro(cudnnConvolutionBackwardBias); \
__macro(cudnnGetConvolutionForwardWorkspaceSize); \
__macro(cudnnTransformTensor); \
__macro(cudnnPoolingForward); \
__macro(cudnnPoolingBackward); \
__macro(cudnnSoftmaxBackward); \
__macro(cudnnSoftmaxForward); \
__macro(cudnnGetVersion); \
__macro(cudnnGetErrorString);
CUDNN_DNN_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#define CUDNN_DNN_ROUTINE_EACH_R2(__macro) \
__macro(cudnnAddTensor); \
__macro(cudnnConvolutionBackwardData); \
__macro(cudnnConvolutionBackwardFilter);
CUDNN_DNN_ROUTINE_EACH_R2(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
// APIs available after R3:
#if CUDNN_VERSION >= 3000
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R3(__macro) \
__macro(cudnnGetConvolutionBackwardFilterWorkspaceSize) \
__macro(cudnnGetConvolutionBackwardDataAlgorithm) \
__macro(cudnnGetConvolutionBackwardFilterAlgorithm) \
__macro(cudnnGetConvolutionBackwardDataWorkspaceSize)
CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R3
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R3(__macro) \
__macro(cudnnGetConvolutionBackwardFilterWorkspaceSize); \
__macro(cudnnGetConvolutionBackwardDataAlgorithm); \
__macro(cudnnGetConvolutionBackwardFilterAlgorithm); \
__macro(cudnnGetConvolutionBackwardDataWorkspaceSize);
CUDNN_DNN_ROUTINE_EACH_AFTER_R3(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif
// APIs available after R4:
#if CUDNN_VERSION >= 4007
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R4(__macro) \
__macro(cudnnBatchNormalizationForwardTraining) \
__macro(cudnnBatchNormalizationForwardInference) \
__macro(cudnnBatchNormalizationBackward)
CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_AFTER_R4
#define CUDNN_DNN_ROUTINE_EACH_AFTER_R4(__macro) \
__macro(cudnnBatchNormalizationForwardTraining); \
__macro(cudnnBatchNormalizationForwardInference); \
__macro(cudnnBatchNormalizationBackward);
CUDNN_DNN_ROUTINE_EACH_AFTER_R4(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif
// APIs in R5
#if CUDNN_VERSION >= 5000
#define CUDNN_DNN_ROUTINE_EACH_R5(__macro) \
__macro(cudnnCreateActivationDescriptor) \
__macro(cudnnSetActivationDescriptor) \
__macro(cudnnGetActivationDescriptor) \
__macro(cudnnDestroyActivationDescriptor)
CUDNN_DNN_ROUTINE_EACH_R5(DYNAMIC_LOAD_CUDNN_WRAP)
#undef CUDNN_DNN_ROUTINE_EACH_R5
#define CUDNN_DNN_ROUTINE_EACH_R5(__macro) \
__macro(cudnnCreateActivationDescriptor); \
__macro(cudnnSetActivationDescriptor); \
__macro(cudnnGetActivationDescriptor); \
__macro(cudnnDestroyActivationDescriptor);
CUDNN_DNN_ROUTINE_EACH_R5(DECLARE_DYNAMIC_LOAD_CUDNN_WRAP)
#endif
#undef CUDNN_DNN_ROUTINE_EACH
// clang-format on
} // namespace dynload
} // namespace platform
} // namespace paddle
#include <paddle/platform/dynload/curand.h>
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag curand_dso_flag;
void *curand_dso_handle;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
CURAND_RAND_ROUTINE_EACH(DEFINE_WRAP);
}
}
}
\ No newline at end of file
......@@ -22,10 +22,10 @@ limitations under the License. */
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag curand_dso_flag;
void *curand_dso_handle = nullptr;
extern std::once_flag curand_dso_flag;
extern void *curand_dso_handle;
#ifdef PADDLE_USE_DSO
#define DYNAMIC_LOAD_CURAND_WRAP(__name) \
#define DECLARE_DYNAMIC_LOAD_CURAND_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
curandStatus_t operator()(Args... args) { \
......@@ -36,32 +36,29 @@ void *curand_dso_handle = nullptr;
void *p_##__name = dlsym(curand_dso_handle, #__name); \
return reinterpret_cast<curandFunc>(p_##__name)(args...); \
} \
} __name; /* struct DynLoad__##__name */
}; \
extern DynLoad__##__name __name
#else
#define DYNAMIC_LOAD_CURAND_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
curandStatus_t operator()(Args... args) { \
return __name(args...); \
} \
} __name; /* struct DynLoad__##__name */
#define DECLARE_DYNAMIC_LOAD_CURAND_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
curandStatus_t operator()(Args... args) { \
return __name(args...); \
} \
}; \
extern DynLoad__##__name __name
#endif
/* include all needed curand functions in HPPL */
// clang-format off
#define CURAND_RAND_ROUTINE_EACH(__macro) \
__macro(curandCreateGenerator) \
__macro(curandSetStream) \
__macro(curandSetPseudoRandomGeneratorSeed)\
__macro(curandGenerateUniform) \
__macro(curandGenerateUniformDouble) \
__macro(curandDestroyGenerator)
// clang-format on
#define CURAND_RAND_ROUTINE_EACH(__macro) \
__macro(curandCreateGenerator); \
__macro(curandSetStream); \
__macro(curandSetPseudoRandomGeneratorSeed); \
__macro(curandGenerateUniform); \
__macro(curandGenerateUniformDouble); \
__macro(curandDestroyGenerator);
CURAND_RAND_ROUTINE_EACH(DYNAMIC_LOAD_CURAND_WRAP)
CURAND_RAND_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CURAND_WRAP);
#undef CURAND_RAND_ROUTINE_EACH
#undef DYNAMIC_LOAD_CURAND_WRAP
} // namespace dynload
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册