未验证 提交 0e9b393b 编写于 作者: D dzhwinter 提交者: GitHub

"derived cudnnDevice context" (#6585)

* "derived cudnnDevice context"

* "leave remove cudnn handle from CUDADeviceContext"

* "fix math function error"
上级 49b8ac80
...@@ -273,6 +273,13 @@ void set_constant_with_place<platform::GPUPlace>( ...@@ -273,6 +273,13 @@ void set_constant_with_place<platform::GPUPlace>(
TensorSetConstantGPU(context, tensor, value)); TensorSetConstantGPU(context, tensor, value));
} }
// Specialization of set_constant_with_place for platform::CudnnPlace.
// It forwards its arguments unchanged to the platform::GPUPlace
// specialization, so filling a tensor on a CudnnPlace takes the same code
// path as filling one on a GPUPlace.
template <>
void set_constant_with_place<platform::CudnnPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor,
float value) {
set_constant_with_place<platform::GPUPlace>(context, tensor, value);
}
template struct RowwiseAdd<platform::CUDADeviceContext, float>; template struct RowwiseAdd<platform::CUDADeviceContext, float>;
template struct RowwiseAdd<platform::CUDADeviceContext, double>; template struct RowwiseAdd<platform::CUDADeviceContext, double>;
template struct ColwiseSum<platform::CUDADeviceContext, float>; template struct ColwiseSum<platform::CUDADeviceContext, float>;
......
...@@ -125,6 +125,22 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; } ...@@ -125,6 +125,22 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() const { return cudnn_handle_; }
cudaStream_t CUDADeviceContext::stream() const { return stream_; } cudaStream_t CUDADeviceContext::stream() const { return stream_; }
// Builds a cuDNN-enabled device context for `place`.
// `place` is forwarded to the CUDADeviceContext base constructor and a copy
// is kept in place_. A cuDNN handle is then created and bound to the stream
// returned by stream(). Both cuDNN calls are wrapped in PADDLE_ENFORCE.
// NOTE(review): assumes the base constructor has already selected the target
// device before cudnnCreate runs -- confirm against CUDADeviceContext.
CudnnDeviceContext::CudnnDeviceContext(CudnnPlace place)
: CUDADeviceContext(place), place_(place) {
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
// Bind the new handle to this context's stream.
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream()));
}
// Tears down the cuDNN state owned by this context: selects the device
// recorded in place_, calls Wait(), then destroys cudnn_handle_.
// NOTE(review): presumably Wait() blocks until work queued on the stream has
// finished, so the handle is idle when destroyed -- confirm.
// NOTE(review): if PADDLE_ENFORCE reports failure by throwing, doing so from
// a destructor would terminate the program -- confirm this is acceptable.
CudnnDeviceContext::~CudnnDeviceContext() {
SetDeviceId(place_.device);
Wait();
PADDLE_ENFORCE(dynload::cudnnDestroy(cudnn_handle_));
}
// Returns the place this context was constructed for, wrapped in the Place
// variant. The stored place_ is returned (not a default-constructed
// CudnnPlace) so that the device index supplied to the constructor is
// preserved for callers.
Place CudnnDeviceContext::GetPlace() const { return place_; }
// Returns the cuDNN handle stored in this context's cudnn_handle_ member.
// The context keeps ownership; its destructor destroys the handle.
cudnnHandle_t CudnnDeviceContext::cudnn_handle() const { return cudnn_handle_; }
#endif #endif
} // namespace platform } // namespace platform
......
...@@ -86,6 +86,22 @@ class CUDADeviceContext : public DeviceContext { ...@@ -86,6 +86,22 @@ class CUDADeviceContext : public DeviceContext {
cublasHandle_t cublas_handle_; cublasHandle_t cublas_handle_;
}; };
// Device context derived from CUDADeviceContext that additionally holds a
// cuDNN handle and the CudnnPlace it was created for.
// NOTE(review): cudnn_handle() is declared here without `override`; if the
// base class also declares a non-virtual cudnn_handle(), calls made through
// a CUDADeviceContext reference will not reach this one -- confirm intent.
class CudnnDeviceContext : public CUDADeviceContext {
public:
/*! \brief Construct a context for the given place. */
explicit CudnnDeviceContext(CudnnPlace place);
virtual ~CudnnDeviceContext();
/*! \brief Return place in the device context. */
Place GetPlace() const final;
/*! \brief Return cudnn handle in the device context. */
cudnnHandle_t cudnn_handle() const;
private:
// cuDNN handle held by this context.
cudnnHandle_t cudnn_handle_;
// Place supplied at construction.
CudnnPlace place_;
};
#endif #endif
} // namespace platform } // namespace platform
......
...@@ -46,3 +46,19 @@ TEST(Device, CUDADeviceContext) { ...@@ -46,3 +46,19 @@ TEST(Device, CUDADeviceContext) {
delete device_context; delete device_context;
} }
} }
// Checks that a CudnnDeviceContext can be constructed for every CUDA device
// index and that it exposes a non-null cuDNN handle and a non-null stream.
// The test body is skipped when HasCUDNN() reports false.
TEST(Device, CudnnDeviceContext) {
  using paddle::platform::CudnnDeviceContext;
  using paddle::platform::CudnnPlace;

  if (!paddle::platform::dynload::HasCUDNN()) {
    return;
  }

  const int count = paddle::platform::GetCUDADeviceCount();
  for (int i = 0; i < count; ++i) {
    // Automatic storage instead of new/delete: a failing ASSERT_* returns
    // from the test immediately, and the destructor must still run so the
    // context is not leaked.
    CudnnPlace place(i);
    CudnnDeviceContext device_context(place);

    cudnnHandle_t cudnn_handle = device_context.cudnn_handle();
    ASSERT_NE(nullptr, cudnn_handle);
    ASSERT_NE(nullptr, device_context.stream());
  }
}
...@@ -43,6 +43,11 @@ struct GPUPlace { ...@@ -43,6 +43,11 @@ struct GPUPlace {
int device; int device;
}; };
// Place type derived from GPUPlace. It adds no members of its own; both
// constructors simply forward to the corresponding GPUPlace constructor.
struct CudnnPlace : public GPUPlace {
// Default-constructs the GPUPlace base.
CudnnPlace() : GPUPlace() {}
// Forwards `d` to GPUPlace(int).
// NOTE(review): presumably `d` is the device index -- confirm against GPUPlace.
explicit CudnnPlace(int d) : GPUPlace(d) {}
};
struct IsGPUPlace : public boost::static_visitor<bool> { struct IsGPUPlace : public boost::static_visitor<bool> {
bool operator()(const CPUPlace &) const { return false; } bool operator()(const CPUPlace &) const { return false; }
bool operator()(const GPUPlace &gpu) const { return true; } bool operator()(const GPUPlace &gpu) const { return true; }
...@@ -52,7 +57,7 @@ struct IsGPUPlace : public boost::static_visitor<bool> { ...@@ -52,7 +57,7 @@ struct IsGPUPlace : public boost::static_visitor<bool> {
// should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) // should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
#define NUM_PLACE_TYPE_LIMIT_IN_BIT 4 #define NUM_PLACE_TYPE_LIMIT_IN_BIT 4
typedef boost::variant<GPUPlace, CPUPlace> Place; typedef boost::variant<CudnnPlace, GPUPlace, CPUPlace> Place;
// static check number of place types is less equal than // static check number of place types is less equal than
// 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT) // 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册