Unverified commit 24103cbb, authored by Wilber, committed by GitHub

[PTEN] Update gpu_context. (#39359)

* gpu_context..

* update

* update

* update
Parent 0fee0044
......@@ -288,7 +288,6 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
ctx.template device_context<platform::CUDADeviceContext>();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto& temp = ctx.cuda_device_context();
AlgorithmsCache<algo_t>& algo_cache =
*(framework::ConvSearchCache::Instance().GetForward());
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
namespace paddle {
namespace operators {
......@@ -73,12 +74,12 @@ __global__ void im2col(const T* data_im, int num_outs, int im_height,
* col =
* [input_channels, filter_height, filter_width, output_height, output_width]
*/
template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CUDADeviceContext, T> {
template <class DeviceContext, class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, DeviceContext,
T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& im, const std::vector<int>& dilation,
void operator()(const DeviceContext& context, const framework::Tensor& im,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* col,
const DataLayout data_layout) {
......@@ -184,12 +185,11 @@ __global__ void col2im(int n, const T* data_col, int im_height, int im_width,
* col =
* [input_channels, filter_height, filter_width, output_height, output_width]
*/
template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CUDADeviceContext, T> {
template <class DeviceContext, class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, DeviceContext,
T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& col,
void operator()(const DeviceContext& context, const framework::Tensor& col,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im,
......@@ -257,10 +257,18 @@ template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CUDADeviceContext, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CUDADeviceContext, double>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
pten::GPUContext, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
pten::GPUContext, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CUDADeviceContext, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CUDADeviceContext, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
pten::GPUContext, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
pten::GPUContext, double>;
template <class T>
__global__ void im2colOCF(const T* im_data, int im_channels, int im_height,
......@@ -299,12 +307,12 @@ __global__ void im2colOCF(const T* im_data, int im_channels, int im_height,
* col =
* [output_height, output_width, input_channels, filter_height, filter_width]
*/
template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CUDADeviceContext, T> {
template <class DeviceContext, class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, DeviceContext,
T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& im, const std::vector<int>& dilation,
void operator()(const DeviceContext& context, const framework::Tensor& im,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* col,
const DataLayout data_layout) {
......@@ -390,12 +398,11 @@ __global__ void col2imOCF(const T* col_data, int im_channels, int im_height,
* col =
* [output_height, output_width, input_channels, filter_height, filter_width]
*/
template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CUDADeviceContext, T> {
template <class DeviceContext, class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, DeviceContext,
T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& col,
void operator()(const DeviceContext& context, const framework::Tensor& col,
const std::vector<int>& dilation,
const std::vector<int>& stride,
const std::vector<int>& padding, framework::Tensor* im,
......@@ -464,10 +471,19 @@ template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CUDADeviceContext, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CUDADeviceContext, double>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
pten::GPUContext, float>;
template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
pten::GPUContext, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CUDADeviceContext, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CUDADeviceContext, double>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
pten::GPUContext, float>;
template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
pten::GPUContext, double>;
} // namespace math
} // namespace operators
......
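The functors in im2col.cu are now templated on the device context, and the explicit instantiations above keep both the `platform::CUDADeviceContext` and the new `pten::GPUContext` versions available to callers without changing the call signature. A minimal usage sketch, assuming pre-sized tensors, an already-initialized `pten::GPUContext`, and zero padding; the wrapper function name and shapes are illustrative and not part of this commit:

#include <vector>

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/pten/backends/gpu/gpu_context.h"

namespace math = paddle::operators::math;

// Illustrative only: unfold an NCHW image tensor into the
// [input_channels, filter_height, filter_width, output_height, output_width]
// column layout described in the comments above, on a pten::GPUContext.
void Im2ColSketch(const pten::GPUContext& ctx,
                  const paddle::framework::Tensor& im,  // 3-D: [C, H, W]
                  paddle::framework::Tensor* col) {     // pre-sized 5-D tensor
  std::vector<int> dilation{1, 1};
  std::vector<int> stride{1, 1};
  std::vector<int> padding{0, 0, 0, 0};

  math::Im2ColFunctor<math::ColFormat::kCFO, pten::GPUContext, float> im2col;
  im2col(ctx, im, dilation, stride, padding, col,
         paddle::framework::DataLayout::kNCHW);
}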
......@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/vol2col.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
namespace paddle {
namespace operators {
......@@ -82,93 +83,91 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth,
* [input_channels, filter_depth, filter_height, filter_width,
* output_depth, output_height, output_width]
*/
template <class T>
class Vol2ColFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& vol,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* col,
const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol.dims().size(), 4,
platform::errors::InvalidArgument(
"The dimension of vol should be 4, but received %d.",
vol.dims().size()));
PADDLE_ENFORCE_EQ(col->dims().size(), 7,
platform::errors::InvalidArgument(
"The dimension of col should be 7, but received %d.",
col->dims().size()));
// template <class DeviceContext, class T>
// class Vol2ColFunctor {
// public:
template <class DeviceContext, class T>
void Vol2ColFunctor<DeviceContext, T>::operator()(
const DeviceContext& context, const framework::Tensor& vol,
const std::vector<int>& dilations, const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* col,
const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol.dims().size(), 4,
platform::errors::InvalidArgument(
"The dimension of vol should be 4, but received %d.",
vol.dims().size()));
PADDLE_ENFORCE_EQ(col->dims().size(), 7,
platform::errors::InvalidArgument(
"The dimension of col should be 7, but received %d.",
col->dims().size()));
int input_channels =
(data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]);
int input_depth =
(data_layout != DataLayout::kNHWC ? vol.dims()[1] : vol.dims()[0]);
int input_height =
(data_layout != DataLayout::kNHWC ? vol.dims()[2] : vol.dims()[1]);
int input_width =
(data_layout != DataLayout::kNHWC ? vol.dims()[3] : vol.dims()[2]);
int filter_depth = col->dims()[1];
int filter_height = col->dims()[2];
int filter_width = col->dims()[3];
int output_depth = col->dims()[4];
int output_height = col->dims()[5];
int output_width = col->dims()[6];
int input_channels =
(data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]);
int input_depth =
(data_layout != DataLayout::kNHWC ? vol.dims()[1] : vol.dims()[0]);
int input_height =
(data_layout != DataLayout::kNHWC ? vol.dims()[2] : vol.dims()[1]);
int input_width =
(data_layout != DataLayout::kNHWC ? vol.dims()[3] : vol.dims()[2]);
int filter_depth = col->dims()[1];
int filter_height = col->dims()[2];
int filter_width = col->dims()[3];
int output_depth = col->dims()[4];
int output_height = col->dims()[5];
int output_width = col->dims()[6];
bool paddings_size_is_6 = (paddings.size() == 6);
int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0];
int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];
auto input_depth_tmp = (input_depth + pad_d_forth + pad_d_back -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1;
PADDLE_ENFORCE_EQ(
input_depth_tmp, output_depth,
platform::errors::InvalidArgument(
"input_depth(%d) and output_depth(%d) are mismatching.",
input_depth_tmp, output_depth));
auto input_height_tmp = (input_height + pad_h_up + pad_h_down -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1;
PADDLE_ENFORCE_EQ(
input_height_tmp, output_height,
platform::errors::InvalidArgument(
"input_height(%d) and output_height(%d) are mismatching.",
input_height_tmp, output_height));
auto input_width_tmp = (input_width + pad_w_left + pad_w_right -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1;
PADDLE_ENFORCE_EQ(
input_width_tmp, output_width,
platform::errors::InvalidArgument(
"input_width(%d) and output_width(%d) are mismatching.",
input_width_tmp, output_width));
bool paddings_size_is_6 = (paddings.size() == 6);
int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0];
int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];
auto input_depth_tmp = (input_depth + pad_d_forth + pad_d_back -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1;
PADDLE_ENFORCE_EQ(input_depth_tmp, output_depth,
platform::errors::InvalidArgument(
"input_depth(%d) and output_depth(%d) are mismatching.",
input_depth_tmp, output_depth));
auto input_height_tmp = (input_height + pad_h_up + pad_h_down -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1;
PADDLE_ENFORCE_EQ(
input_height_tmp, output_height,
platform::errors::InvalidArgument(
"input_height(%d) and output_height(%d) are mismatching.",
input_height_tmp, output_height));
auto input_width_tmp = (input_width + pad_w_left + pad_w_right -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1;
PADDLE_ENFORCE_EQ(input_width_tmp, output_width,
platform::errors::InvalidArgument(
"input_width(%d) and output_width(%d) are mismatching.",
input_width_tmp, output_width));
int num_outputs =
input_channels * output_depth * output_height * output_width;
int num_outputs =
input_channels * output_depth * output_height * output_width;
int max_threads = 1024;
int max_threads = 1024;
#ifdef WITH_NV_JETSON
platform::ChangeThreadNum(context, &max_threads);
platform::ChangeThreadNum(context, &max_threads);
#endif
const int threads = max_threads;
const int blocks = (num_outputs + max_threads - 1) / max_threads;
const int threads = max_threads;
const int blocks = (num_outputs + max_threads - 1) / max_threads;
vol2col<T><<<blocks, threads, 0, context.stream()>>>(
num_outputs, vol.data<T>(), input_depth, input_height, input_width,
dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up,
pad_w_left, output_depth, output_height, output_width, col->data<T>(),
data_layout);
}
};
vol2col<T><<<blocks, threads, 0, context.stream()>>>(
num_outputs, vol.data<T>(), input_depth, input_height, input_width,
dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up,
pad_w_left, output_depth, output_height, output_width, col->data<T>(),
data_layout);
}
// };
template <class T>
__global__ void col2vol(int num_kernels, const T* data_col, int depth,
......@@ -249,98 +248,101 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth,
* [input_channels, filter_depth, filter_height, filter_width,
* output_depth, output_height, output_width]
*/
template <class T>
class Col2VolFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& col,
const std::vector<int>& dilations,
const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* vol,
const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol->dims().size(), 4,
platform::errors::InvalidArgument(
"The dimension of vol should be 4, but received %d.",
vol->dims().size()));
PADDLE_ENFORCE_EQ(col.dims().size(), 7,
platform::errors::InvalidArgument(
"The dimension of col should be 7, but received %d.",
col.dims().size()));
// template <class DeviceContext, class T>
// class Col2VolFunctor<DeviceContext, T> {
// public:
template <class DeviceContext, class T>
void Col2VolFunctor<DeviceContext, T>::operator()(
const DeviceContext& context, const framework::Tensor& col,
const std::vector<int>& dilations, const std::vector<int>& strides,
const std::vector<int>& paddings, framework::Tensor* vol,
const DataLayout data_layout) const {
PADDLE_ENFORCE_EQ(vol->dims().size(), 4,
platform::errors::InvalidArgument(
"The dimension of vol should be 4, but received %d.",
vol->dims().size()));
PADDLE_ENFORCE_EQ(col.dims().size(), 7,
platform::errors::InvalidArgument(
"The dimension of col should be 7, but received %d.",
col.dims().size()));
int input_channels =
(data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]);
int input_depth =
(data_layout != DataLayout::kNHWC ? vol->dims()[1] : vol->dims()[0]);
int input_height =
(data_layout != DataLayout::kNHWC ? vol->dims()[2] : vol->dims()[1]);
int input_width =
(data_layout != DataLayout::kNHWC ? vol->dims()[3] : vol->dims()[2]);
int filter_depth = col.dims()[1];
int filter_height = col.dims()[2];
int filter_width = col.dims()[3];
int output_depth = col.dims()[4];
int output_height = col.dims()[5];
int output_width = col.dims()[6];
int input_channels =
(data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]);
int input_depth =
(data_layout != DataLayout::kNHWC ? vol->dims()[1] : vol->dims()[0]);
int input_height =
(data_layout != DataLayout::kNHWC ? vol->dims()[2] : vol->dims()[1]);
int input_width =
(data_layout != DataLayout::kNHWC ? vol->dims()[3] : vol->dims()[2]);
int filter_depth = col.dims()[1];
int filter_height = col.dims()[2];
int filter_width = col.dims()[3];
int output_depth = col.dims()[4];
int output_height = col.dims()[5];
int output_width = col.dims()[6];
bool paddings_size_is_6 = (paddings.size() == 6);
int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0];
int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];
bool paddings_size_is_6 = (paddings.size() == 6);
int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0];
int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];
auto input_depth_tmp = (input_depth + pad_d_forth + pad_d_back -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1;
PADDLE_ENFORCE_EQ(
input_depth_tmp, output_depth,
platform::errors::InvalidArgument(
"input_depth(%d) and output_depth(%d) are mismatching.",
input_depth_tmp, output_depth));
auto input_height_tmp = (input_height + pad_h_up + pad_h_down -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1;
PADDLE_ENFORCE_EQ(
input_height_tmp, output_height,
platform::errors::InvalidArgument(
"input_height(%d) and output_height(%d) are mismatching.",
input_height_tmp, output_height));
auto input_width_tmp = (input_width + pad_w_left + pad_w_right -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1;
PADDLE_ENFORCE_EQ(
input_width_tmp, output_width,
platform::errors::InvalidArgument(
"input_width(%d) and output_width(%d) are mismatching.",
input_width_tmp, output_width));
auto input_depth_tmp = (input_depth + pad_d_forth + pad_d_back -
((dilations[0] * (filter_depth - 1) + 1))) /
strides[0] +
1;
PADDLE_ENFORCE_EQ(input_depth_tmp, output_depth,
platform::errors::InvalidArgument(
"input_depth(%d) and output_depth(%d) are mismatching.",
input_depth_tmp, output_depth));
auto input_height_tmp = (input_height + pad_h_up + pad_h_down -
((dilations[1] * (filter_height - 1) + 1))) /
strides[1] +
1;
PADDLE_ENFORCE_EQ(
input_height_tmp, output_height,
platform::errors::InvalidArgument(
"input_height(%d) and output_height(%d) are mismatching.",
input_height_tmp, output_height));
auto input_width_tmp = (input_width + pad_w_left + pad_w_right -
((dilations[2] * (filter_width - 1) + 1))) /
strides[2] +
1;
PADDLE_ENFORCE_EQ(input_width_tmp, output_width,
platform::errors::InvalidArgument(
"input_width(%d) and output_width(%d) are mismatching.",
input_width_tmp, output_width));
int num_kernels = input_channels * input_depth * input_height * input_width;
int num_kernels = input_channels * input_depth * input_height * input_width;
int max_threads = 1024;
int max_threads = 1024;
#ifdef WITH_NV_JETSON
platform::ChangeThreadNum(context, &max_threads);
platform::ChangeThreadNum(context, &max_threads);
#endif
const int threads = max_threads;
const int blocks = (num_kernels + max_threads - 1) / max_threads;
const int threads = max_threads;
const int blocks = (num_kernels + max_threads - 1) / max_threads;
col2vol<T><<<blocks, threads, 0, context.stream()>>>(
num_kernels, col.data<T>(), input_depth, input_height, input_width,
dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up,
pad_w_left, output_depth, output_height, output_width, vol->data<T>(),
data_layout);
}
};
col2vol<T><<<blocks, threads, 0, context.stream()>>>(
num_kernels, col.data<T>(), input_depth, input_height, input_width,
dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up,
pad_w_left, output_depth, output_height, output_width, vol->data<T>(),
data_layout);
}
// };
template class Vol2ColFunctor<platform::CUDADeviceContext, float>;
template class Vol2ColFunctor<platform::CUDADeviceContext, double>;
template class Vol2ColFunctor<pten::GPUContext, float>;
template class Vol2ColFunctor<pten::GPUContext, double>;
template class Col2VolFunctor<platform::CUDADeviceContext, float>;
template class Col2VolFunctor<platform::CUDADeviceContext, double>;
template class Col2VolFunctor<pten::GPUContext, float>;
template class Col2VolFunctor<pten::GPUContext, double>;
} // namespace math
} // namespace operators
......
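Both functors above enforce the same size relation between the padded input, the dilated filter, and the output extent in each spatial dimension. A small self-checking sketch of that arithmetic, with illustrative numbers only:

// The relation checked by the PADDLE_ENFORCE_EQ calls above, with integer
// division exactly as in the functor code.
constexpr int ConvOutSize(int in, int pad_before, int pad_after, int dilation,
                          int filter, int stride) {
  return (in + pad_before + pad_after - (dilation * (filter - 1) + 1)) / stride + 1;
}

// Example: depth 16, symmetric padding 1, 3-wide filter, stride 1, dilation 1
// -> (16 + 1 + 1 - 3) / 1 + 1 = 16, so output_depth must be 16.
static_assert(ConvOutSize(16, 1, 1, /*dilation=*/1, /*filter=*/3, /*stride=*/1) == 16,
              "illustrative check of the vol2col size relation");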
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/stream/cuda_stream.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/allocator.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
......@@ -485,8 +486,11 @@ CUDAContext::~CUDAContext() {
CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
: pten::GPUContext(place) {
pten::GPUContext::PartialInitWithoutAllocator();
cuda_stream_.reset(
new stream::CUDAStream(pten::GPUContext::stream(), this->GetPlace()));
cuda_stream_.reset(new stream::CUDAStream(pten::GPUContext::stream(), place));
workspace_.reset(new pten::DnnWorkspaceHandle(
memory::allocation::AllocatorFacade::Instance()
.GetAllocator(place, pten::GPUContext::stream())
.get()));
}
CUDADeviceContext::~CUDADeviceContext() = default;
......@@ -571,8 +575,15 @@ void CUDADeviceContext::WaitStreamCallback() const {
pten::GPUContext::WaitStreamCallback();
}
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
  return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
pten::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
if (thread_ctx_.count(this)) {
// return workspace_.get();
return pten::DnnWorkspaceHandle(
memory::allocation::AllocatorFacade::Instance()
.GetAllocator(GetPlace(), pten::GPUContext::stream())
.get());
}
return pten::GPUContext::cudnn_workspace_handle();
}
gpuStream_t CUDADeviceContext::stream() const {
......
......@@ -566,7 +566,7 @@ class CUDADeviceContext : public pten::GPUContext {
* workspace. Once the handle is destructed, the lock would be released.
* CudnnWorkspaceHandle is an RAII object to implement thread-safe
* sequential cudnn function calls. */
CudnnWorkspaceHandle cudnn_workspace_handle() const;
pten::DnnWorkspaceHandle cudnn_workspace_handle() const;
/*! \brief Return cuda stream in the device context. */
gpuStream_t stream() const;
......@@ -607,6 +607,7 @@ class CUDADeviceContext : public pten::GPUContext {
// NOTE: Just for compatibility with the past, please delete if there is an
// elegant way.
std::unique_ptr<stream::CUDAStream> cuda_stream_;
std::unique_ptr<pten::DnnWorkspaceHandle> workspace_{nullptr};
DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
};
......
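With this change, `CUDADeviceContext::cudnn_workspace_handle()` returns a `pten::DnnWorkspaceHandle` by value instead of the old `CudnnWorkspaceHandle`, but the call pattern used throughout the conv kernels stays the same. A hedged sketch of that pattern; the lambda stands in for an arbitrary cuDNN call and the wrapper function is illustrative:

#include "paddle/fluid/platform/device_context.h"

// Illustrative only: acquire a workspace handle from the device context and
// run a cuDNN call that needs `workspace_bytes` of scratch memory.
void RunWithWorkspace(const paddle::platform::CUDADeviceContext& dev_ctx,
                      size_t workspace_bytes) {
  auto workspace_handle = dev_ctx.cudnn_workspace_handle();
  workspace_handle.RunFunc(
      [&](void* workspace_ptr) {
        // e.g. cudnnConvolutionForward(..., workspace_ptr, workspace_bytes, ...);
      },
      workspace_bytes);
}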
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/pten/backends/gpu/gpu_context.h"
#include <algorithm>
#include <array>
#include <functional>
#include <future>
......@@ -153,55 +154,14 @@ static void StreamCallbackFunc(gpuStream_t stream,
} // namespace internal
class DnnWorkspaceHandle {
public:
explicit inline DnnWorkspaceHandle(Allocator* allocator)
: allocator_(allocator) {}
inline void RunFunc(const std::function<void(void*)>& cudnn_func,
size_t required_workspace_bytes) {
if (required_workspace_bytes > WorkspaceSize()) {
ReallocWorkspace(required_workspace_bytes);
}
VLOG(2) << "Cudnn workspace size at RunFunc: "
<< static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB";
{
std::lock_guard<std::mutex> guard(mtx_);
cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
}
}
/*! \brief Threads which call RunFuncSync() would release gpu memory after
* running the function. Currently this function is only used during cudnn
* exhaustive searching, and callers have to guarantee that the input function
* is host blocking. */
inline void RunFuncSync(const std::function<void(void*)>& cudnn_func,
size_t required_workspace_bytes) {
RunFunc(cudnn_func, required_workspace_bytes);
ResetWorkspace();
}
void DnnWorkspaceHandle::ResetWorkspace() { allocation_ = nullptr; }
inline size_t WorkspaceSize() {
if (allocation_ == nullptr) {
return 0;
}
return allocation_->size();
}
void ResetWorkspace() { allocation_ = nullptr; }
void ReallocWorkspace(size_t required_workspace_bytes) {
if (required_workspace_bytes <= WorkspaceSize()) return;
// reset allocation first before re-allocate to save memory
allocation_.reset();
allocation_ = allocator_->Allocate(required_workspace_bytes);
}
private:
Allocator::AllocationPtr allocation_{nullptr};
Allocator* allocator_{nullptr};
std::mutex mtx_;
};
void DnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
if (required_workspace_bytes <= WorkspaceSize()) return;
// reset allocation first before re-allocate to save memory
allocation_.reset();
allocation_ = allocator_->Allocate(required_workspace_bytes);
}
struct GPUContext::Impl {
void Init() {
......@@ -341,9 +301,15 @@ struct GPUContext::Impl {
}
}
DnnWorkspaceHandle* GetDnnWorkspace() {
PD_CHECK(workspace_ != nullptr, "the gpu cudnn workspace is nullptr.");
return workspace_;
// TODO(wilber): The return type is a pointer, to be modified later.
// DnnWorkspaceHandle* GetDnnWorkspace() {
// PD_CHECK(workspace_ != nullptr, "the gpu cudnn workspace is nullptr.");
// return workspace_;
// }
DnnWorkspaceHandle GetDnnWorkspace() {
PD_CHECK(allocator_ != nullptr,
"the device allocator for gpu context is nullptr.");
return DnnWorkspaceHandle(allocator_);
}
void InitStream() {
......@@ -797,7 +763,7 @@ Eigen::GpuDevice* GPUContext::eigen_device() const {
return impl_->eigen_device();
}
DnnWorkspaceHandle* GPUContext::cudnn_workspace_handle() {
DnnWorkspaceHandle GPUContext::cudnn_workspace_handle() const {
return impl_->GetDnnWorkspace();
}
......
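`GetDnnWorkspace()` no longer hands out a pointer to a shared member; every call now builds a fresh `DnnWorkspaceHandle` around the context's device allocator, so the scratch allocation lives only as long as the handle. A sketch of the resulting lifetime, with illustrative sizes and function name:

// Illustrative only: each handle owns its own scratch allocation, backed by
// the allocator that the GPUContext was configured with.
void WorkspaceLifetimeSketch(const pten::GPUContext& gpu_ctx) {
  pten::DnnWorkspaceHandle handle = gpu_ctx.cudnn_workspace_handle();
  handle.ReallocWorkspace(64 << 20);      // request 64 MB of scratch memory
  size_t bytes = handle.WorkspaceSize();  // at least the requested size
  (void)bytes;
  handle.ResetWorkspace();                // release the allocation early
}  // anything still held is freed when the handle goes out of scope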
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <array>
#include <functional>
#include <mutex>
#include "paddle/pten/backends/gpu/forwards.h"
#include "paddle/pten/backends/gpu/gpu_decls.h"
#include "paddle/pten/backends/gpu/gpu_helper.h"
......@@ -24,7 +25,53 @@ limitations under the License. */
namespace pten {
class DnnWorkspaceHandle;
class DnnWorkspaceHandle {
public:
explicit inline DnnWorkspaceHandle(Allocator* allocator)
: allocator_(allocator) {
mtx_.reset(new std::mutex());
}
inline void RunFunc(const std::function<void(void*)>& cudnn_func,
size_t required_workspace_bytes) {
if (required_workspace_bytes > WorkspaceSize()) {
ReallocWorkspace(required_workspace_bytes);
}
{
std::lock_guard<std::mutex> guard(*mtx_);
cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
}
}
/*! \brief Threads which call RunFuncSync() would release gpu memory after
* running the function. Currently this function is only used during cudnn
* exhaustive searching, and callers have to guarantee that the input function
* is host blocking. */
inline void RunFuncSync(const std::function<void(void*)>& cudnn_func,
size_t required_workspace_bytes) {
RunFunc(cudnn_func, required_workspace_bytes);
ResetWorkspace();
}
inline size_t WorkspaceSize() {
if (allocation_ == nullptr) {
return 0;
}
return allocation_->size();
}
void ResetWorkspace();
void ReallocWorkspace(size_t required_workspace_bytes);
DnnWorkspaceHandle(DnnWorkspaceHandle&&) = default;
DnnWorkspaceHandle& operator=(DnnWorkspaceHandle&&) = delete;
private:
Allocator::AllocationPtr allocation_{nullptr};
Allocator* allocator_{nullptr};
std::unique_ptr<std::mutex> mtx_;
};
class GPUContext : public DeviceContext {
public:
......@@ -85,7 +132,8 @@ class GPUContext : public DeviceContext {
* would be acquired to prevent other threads from accessing the
* workspace. Once the handle is destructed, the lock would be released.
*/
DnnWorkspaceHandle* cudnn_workspace_handle();
// TODO(wilber): The return type is a pointer, to be modified later.
DnnWorkspaceHandle cudnn_workspace_handle() const;
public:
/*! \brief Call cublas function safely. */
......
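The comment above `RunFuncSync` describes its intended use: a host-blocking call such as cuDNN exhaustive algorithm search, after which the typically large workspace is released immediately instead of staying cached on the handle. A hedged sketch of that pattern; the lambda is a placeholder for the real cuDNN call:

// Illustrative only: run a host-blocking exhaustive search and drop the
// workspace as soon as it finishes.
void ExhaustiveSearchSketch(const pten::GPUContext& ctx, size_t max_ws_bytes) {
  pten::DnnWorkspaceHandle handle = ctx.cudnn_workspace_handle();
  handle.RunFuncSync(
      [&](void* workspace_ptr) {
        // e.g. cudnnFindConvolutionForwardAlgorithmEx(..., workspace_ptr,
        //                                             max_ws_bytes, ...);
      },
      max_ws_bytes);
}  // the workspace has already been reset inside RunFuncSync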