Unverified · Commit 24103cbb authored by W Wilber, committed by GitHub

[PTEN] Update gpu_context. (#39359)

* gpu_context: return DnnWorkspaceHandle from cudnn_workspace_handle() by value instead of by pointer, and move the class definition into the public header.

* update: template the Im2ColFunctor/Col2ImFunctor specializations on DeviceContext and add pten::GPUContext instantiations.

* update: define Vol2ColFunctor/Col2VolFunctor operator() generically and instantiate it for pten::GPUContext.

* update: build CUDADeviceContext's cuDNN workspace from the AllocatorFacade allocator.
Parent 0fee0044
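In effect, `cudnn_workspace_handle()` now hands out a `pten::DnnWorkspaceHandle` by value rather than a raw pointer, with each handle borrowing the context's allocator. A minimal usage sketch, assuming an initialized `pten::GPUContext` named `ctx` and a caller-chosen workspace size (neither appears in this diff):

```cpp
#include "paddle/pten/backends/gpu/gpu_context.h"

// Hypothetical call site; `ctx` and `workspace_bytes` are assumptions.
void RunWithCudnnWorkspace(const pten::GPUContext& ctx,
                           size_t workspace_bytes) {
  // Returned by value: a move-only handle backed by the context's allocator.
  auto workspace = ctx.cudnn_workspace_handle();
  workspace.RunFunc(
      [](void* ws_ptr) {
        // The cuDNN routine would consume `ws_ptr` here; the handle holds
        // its internal mutex while this functor runs.
      },
      workspace_bytes);
}  // the workspace allocation is released together with the handle
```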
@@ -288,7 +288,6 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
         ctx.template device_context<platform::CUDADeviceContext>();
     auto workspace_handle = dev_ctx.cudnn_workspace_handle();
-    auto& temp = ctx.cuda_device_context();
     AlgorithmsCache<algo_t>& algo_cache =
         *(framework::ConvSearchCache::Instance().GetForward());
...
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/im2col.h"
 #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+#include "paddle/pten/backends/gpu/gpu_context.h"
 
 namespace paddle {
 namespace operators {
@@ -73,12 +74,12 @@ __global__ void im2col(const T* data_im, int num_outs, int im_height,
  * col =
  * [input_channels, filter_height, filter_width, output_height, output_width]
  */
-template <class T>
-class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
-                    platform::CUDADeviceContext, T> {
+template <class DeviceContext, class T>
+class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, DeviceContext,
+                    T> {
  public:
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& im, const std::vector<int>& dilation,
+  void operator()(const DeviceContext& context, const framework::Tensor& im,
+                  const std::vector<int>& dilation,
                   const std::vector<int>& stride,
                   const std::vector<int>& padding, framework::Tensor* col,
                   const DataLayout data_layout) {
@@ -184,12 +185,11 @@ __global__ void col2im(int n, const T* data_col, int im_height, int im_width,
  * col =
  * [input_channels, filter_height, filter_width, output_height, output_width]
  */
-template <class T>
-class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
-                    platform::CUDADeviceContext, T> {
+template <class DeviceContext, class T>
+class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, DeviceContext,
+                    T> {
  public:
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& col,
+  void operator()(const DeviceContext& context, const framework::Tensor& col,
                   const std::vector<int>& dilation,
                   const std::vector<int>& stride,
                   const std::vector<int>& padding, framework::Tensor* im,
@@ -257,10 +257,18 @@ template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                              platform::CUDADeviceContext, float>;
 template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
                              platform::CUDADeviceContext, double>;
+template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
+                             pten::GPUContext, float>;
+template class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
+                             pten::GPUContext, double>;
 template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                              platform::CUDADeviceContext, float>;
 template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
                              platform::CUDADeviceContext, double>;
+template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
+                             pten::GPUContext, float>;
+template class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
+                             pten::GPUContext, double>;
 
 template <class T>
 __global__ void im2colOCF(const T* im_data, int im_channels, int im_height,
@@ -299,12 +307,12 @@ __global__ void im2colOCF(const T* im_data, int im_channels, int im_height,
  * col =
  * [output_height, output_width, input_channels, filter_height, filter_width]
  */
-template <class T>
-class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
-                    platform::CUDADeviceContext, T> {
+template <class DeviceContext, class T>
+class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, DeviceContext,
+                    T> {
  public:
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& im, const std::vector<int>& dilation,
+  void operator()(const DeviceContext& context, const framework::Tensor& im,
+                  const std::vector<int>& dilation,
                   const std::vector<int>& stride,
                   const std::vector<int>& padding, framework::Tensor* col,
                   const DataLayout data_layout) {
@@ -390,12 +398,11 @@ __global__ void col2imOCF(const T* col_data, int im_channels, int im_height,
  * col =
  * [output_height, output_width, input_channels, filter_height, filter_width]
  */
-template <class T>
-class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
-                    platform::CUDADeviceContext, T> {
+template <class DeviceContext, class T>
+class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, DeviceContext,
+                    T> {
  public:
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& col,
+  void operator()(const DeviceContext& context, const framework::Tensor& col,
                   const std::vector<int>& dilation,
                   const std::vector<int>& stride,
                   const std::vector<int>& padding, framework::Tensor* im,
@@ -464,10 +471,19 @@ template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                              platform::CUDADeviceContext, float>;
 template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
                              platform::CUDADeviceContext, double>;
+template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
+                             pten::GPUContext, float>;
+template class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
+                             pten::GPUContext, double>;
 template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                              platform::CUDADeviceContext, float>;
 template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
                              platform::CUDADeviceContext, double>;
+template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
+                             pten::GPUContext, float>;
+template class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
+                             pten::GPUContext, double>;
 
 }  // namespace math
 }  // namespace operators
...
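With the functors templated on the device context, one definition now serves both runtimes. A sketch of a call site for the kCFO specialization instantiated above, assuming a `pten::GPUContext` and suitably shaped tensors (the names and the dilation/stride/padding values are illustrative, not part of this diff):

```cpp
#include <vector>

#include "paddle/fluid/operators/math/im2col.h"
#include "paddle/pten/backends/gpu/gpu_context.h"

// Hypothetical call site; `dev_ctx`, `im`, and `col` are assumptions.
void Im2ColExample(const pten::GPUContext& dev_ctx,
                   const paddle::framework::Tensor& im,
                   paddle::framework::Tensor* col) {
  paddle::operators::math::Im2ColFunctor<
      paddle::operators::math::ColFormat::kCFO, pten::GPUContext, float>
      im2col;
  const std::vector<int> dilation = {1, 1};
  const std::vector<int> stride = {1, 1};
  const std::vector<int> padding = {0, 0, 0, 0};
  // The same line would compile unchanged with a CUDADeviceContext argument,
  // since the context type is deduced from the template parameter.
  im2col(dev_ctx, im, dilation, stride, padding, col,
         paddle::framework::DataLayout::kNCHW);
}
```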
@@ -17,6 +17,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/math/vol2col.h"
 #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
+#include "paddle/pten/backends/gpu/gpu_context.h"
 
 namespace paddle {
 namespace operators {
@@ -82,93 +83,91 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth,
  * [input_channels, filter_depth, filter_height, filter_width,
  * output_depth, output_height, output_width]
  */
-template <class T>
-class Vol2ColFunctor<platform::CUDADeviceContext, T> {
- public:
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& vol,
-                  const std::vector<int>& dilations,
-                  const std::vector<int>& strides,
-                  const std::vector<int>& paddings, framework::Tensor* col,
-                  const DataLayout data_layout) const {
+// template <class DeviceContext, class T>
+// class Vol2ColFunctor {
+//  public:
+template <class DeviceContext, class T>
+void Vol2ColFunctor<DeviceContext, T>::operator()(
+    const DeviceContext& context, const framework::Tensor& vol,
+    const std::vector<int>& dilations, const std::vector<int>& strides,
+    const std::vector<int>& paddings, framework::Tensor* col,
+    const DataLayout data_layout) const {
   PADDLE_ENFORCE_EQ(vol.dims().size(), 4,
                     platform::errors::InvalidArgument(
                         "The dimension of vol should be 4, but received %d.",
                         vol.dims().size()));
   PADDLE_ENFORCE_EQ(col->dims().size(), 7,
                     platform::errors::InvalidArgument(
                         "The dimension of col should be 7, but received %d.",
                         col->dims().size()));
   int input_channels =
       (data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]);
   int input_depth =
       (data_layout != DataLayout::kNHWC ? vol.dims()[1] : vol.dims()[0]);
   int input_height =
       (data_layout != DataLayout::kNHWC ? vol.dims()[2] : vol.dims()[1]);
   int input_width =
       (data_layout != DataLayout::kNHWC ? vol.dims()[3] : vol.dims()[2]);
   int filter_depth = col->dims()[1];
   int filter_height = col->dims()[2];
   int filter_width = col->dims()[3];
   int output_depth = col->dims()[4];
   int output_height = col->dims()[5];
   int output_width = col->dims()[6];
   bool paddings_size_is_6 = (paddings.size() == 6);
   int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0];
   int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
   int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
   int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
   int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
   int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];
   auto input_depth_tmp = (input_depth + pad_d_forth + pad_d_back -
                           ((dilations[0] * (filter_depth - 1) + 1))) /
                              strides[0] +
                          1;
-  PADDLE_ENFORCE_EQ(
-      input_depth_tmp, output_depth,
-      platform::errors::InvalidArgument(
-          "input_depth(%d) and output_depth(%d) are mismatching.",
-          input_depth_tmp, output_depth));
+  PADDLE_ENFORCE_EQ(input_depth_tmp, output_depth,
+                    platform::errors::InvalidArgument(
+                        "input_depth(%d) and output_depth(%d) are mismatching.",
+                        input_depth_tmp, output_depth));
   auto input_height_tmp = (input_height + pad_h_up + pad_h_down -
                            ((dilations[1] * (filter_height - 1) + 1))) /
                               strides[1] +
                           1;
   PADDLE_ENFORCE_EQ(
       input_height_tmp, output_height,
       platform::errors::InvalidArgument(
           "input_height(%d) and output_height(%d) are mismatching.",
           input_height_tmp, output_height));
   auto input_width_tmp = (input_width + pad_w_left + pad_w_right -
                           ((dilations[2] * (filter_width - 1) + 1))) /
                              strides[2] +
                          1;
-  PADDLE_ENFORCE_EQ(
-      input_width_tmp, output_width,
-      platform::errors::InvalidArgument(
-          "input_width(%d) and output_width(%d) are mismatching.",
-          input_width_tmp, output_width));
+  PADDLE_ENFORCE_EQ(input_width_tmp, output_width,
+                    platform::errors::InvalidArgument(
+                        "input_width(%d) and output_width(%d) are mismatching.",
+                        input_width_tmp, output_width));
   int num_outputs =
       input_channels * output_depth * output_height * output_width;
   int max_threads = 1024;
 #ifdef WITH_NV_JETSON
   platform::ChangeThreadNum(context, &max_threads);
 #endif
   const int threads = max_threads;
   const int blocks = (num_outputs + max_threads - 1) / max_threads;
   vol2col<T><<<blocks, threads, 0, context.stream()>>>(
       num_outputs, vol.data<T>(), input_depth, input_height, input_width,
       dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
       filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up,
       pad_w_left, output_depth, output_height, output_width, col->data<T>(),
       data_layout);
 }
-};
+// };
 
 template <class T>
 __global__ void col2vol(int num_kernels, const T* data_col, int depth,
@@ -249,98 +248,101 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth,
  * [input_channels, filter_depth, filter_height, filter_width,
  * output_depth, output_height, output_width]
  */
-template <class T>
-class Col2VolFunctor<platform::CUDADeviceContext, T> {
- public:
-  void operator()(const platform::CUDADeviceContext& context,
-                  const framework::Tensor& col,
-                  const std::vector<int>& dilations,
-                  const std::vector<int>& strides,
-                  const std::vector<int>& paddings, framework::Tensor* vol,
-                  const DataLayout data_layout) const {
+// template <class DeviceContext, class T>
+// class Col2VolFunctor<DeviceContext, T> {
+//  public:
+template <class DeviceContext, class T>
+void Col2VolFunctor<DeviceContext, T>::operator()(
+    const DeviceContext& context, const framework::Tensor& col,
+    const std::vector<int>& dilations, const std::vector<int>& strides,
+    const std::vector<int>& paddings, framework::Tensor* vol,
+    const DataLayout data_layout) const {
   PADDLE_ENFORCE_EQ(vol->dims().size(), 4,
                     platform::errors::InvalidArgument(
                         "The dimension of vol should be 4, but received %d.",
                         vol->dims().size()));
   PADDLE_ENFORCE_EQ(col.dims().size(), 7,
                     platform::errors::InvalidArgument(
                         "The dimension of col should be 7, but received %d.",
                         col.dims().size()));
   int input_channels =
       (data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]);
   int input_depth =
       (data_layout != DataLayout::kNHWC ? vol->dims()[1] : vol->dims()[0]);
   int input_height =
       (data_layout != DataLayout::kNHWC ? vol->dims()[2] : vol->dims()[1]);
   int input_width =
       (data_layout != DataLayout::kNHWC ? vol->dims()[3] : vol->dims()[2]);
   int filter_depth = col.dims()[1];
   int filter_height = col.dims()[2];
   int filter_width = col.dims()[3];
   int output_depth = col.dims()[4];
   int output_height = col.dims()[5];
   int output_width = col.dims()[6];
   bool paddings_size_is_6 = (paddings.size() == 6);
   int pad_d_forth = paddings_size_is_6 ? paddings[0] : paddings[0];
   int pad_d_back = paddings_size_is_6 ? paddings[1] : paddings[0];
   int pad_h_up = paddings_size_is_6 ? paddings[2] : paddings[1];
   int pad_h_down = paddings_size_is_6 ? paddings[3] : paddings[1];
   int pad_w_left = paddings_size_is_6 ? paddings[4] : paddings[2];
   int pad_w_right = paddings_size_is_6 ? paddings[5] : paddings[2];
   auto input_depth_tmp = (input_depth + pad_d_forth + pad_d_back -
                           ((dilations[0] * (filter_depth - 1) + 1))) /
                              strides[0] +
                          1;
-  PADDLE_ENFORCE_EQ(
-      input_depth_tmp, output_depth,
-      platform::errors::InvalidArgument(
-          "input_depth(%d) and output_depth(%d) are mismatching.",
-          input_depth_tmp, output_depth));
+  PADDLE_ENFORCE_EQ(input_depth_tmp, output_depth,
+                    platform::errors::InvalidArgument(
+                        "input_depth(%d) and output_depth(%d) are mismatching.",
+                        input_depth_tmp, output_depth));
   auto input_height_tmp = (input_height + pad_h_up + pad_h_down -
                            ((dilations[1] * (filter_height - 1) + 1))) /
                               strides[1] +
                           1;
   PADDLE_ENFORCE_EQ(
       input_height_tmp, output_height,
       platform::errors::InvalidArgument(
           "input_height(%d) and output_height(%d) are mismatching.",
           input_height_tmp, output_height));
   auto input_width_tmp = (input_width + pad_w_left + pad_w_right -
                           ((dilations[2] * (filter_width - 1) + 1))) /
                              strides[2] +
                          1;
-  PADDLE_ENFORCE_EQ(
-      input_width_tmp, output_width,
-      platform::errors::InvalidArgument(
-          "input_width(%d) and output_width(%d) are mismatching.",
-          input_width_tmp, output_width));
+  PADDLE_ENFORCE_EQ(input_width_tmp, output_width,
+                    platform::errors::InvalidArgument(
+                        "input_width(%d) and output_width(%d) are mismatching.",
+                        input_width_tmp, output_width));
   int num_kernels = input_channels * input_depth * input_height * input_width;
   int max_threads = 1024;
 #ifdef WITH_NV_JETSON
   platform::ChangeThreadNum(context, &max_threads);
 #endif
   const int threads = max_threads;
   const int blocks = (num_kernels + max_threads - 1) / max_threads;
   col2vol<T><<<blocks, threads, 0, context.stream()>>>(
       num_kernels, col.data<T>(), input_depth, input_height, input_width,
       dilations[0], dilations[1], dilations[2], filter_depth, filter_height,
       filter_width, strides[0], strides[1], strides[2], pad_d_forth, pad_h_up,
       pad_w_left, output_depth, output_height, output_width, vol->data<T>(),
       data_layout);
 }
-};
+// };
 
 template class Vol2ColFunctor<platform::CUDADeviceContext, float>;
 template class Vol2ColFunctor<platform::CUDADeviceContext, double>;
+template class Vol2ColFunctor<pten::GPUContext, float>;
+template class Vol2ColFunctor<pten::GPUContext, double>;
 template class Col2VolFunctor<platform::CUDADeviceContext, float>;
 template class Col2VolFunctor<platform::CUDADeviceContext, double>;
+template class Col2VolFunctor<pten::GPUContext, float>;
+template class Col2VolFunctor<pten::GPUContext, double>;
 
 }  // namespace math
 }  // namespace operators
...
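Vol2ColFunctor and Col2VolFunctor switch here from a per-context class specialization to one generic definition of operator() plus explicit instantiations, so the implementation can stay in the .cu file while the header only declares the primary template. A self-contained sketch of that pattern with hypothetical names:

```cpp
#include <iostream>

// Stand-ins for the two device-context types (hypothetical).
struct LegacyContext {};
struct NewContext {};

// The primary template is declared in a header (as in vol2col.h)...
template <class DeviceContext, class T>
struct Vol2ColLike {
  void operator()(const DeviceContext& ctx, T value) const;
};

// ...its definition lives in a single source file...
template <class DeviceContext, class T>
void Vol2ColLike<DeviceContext, T>::operator()(const DeviceContext&,
                                               T value) const {
  std::cout << "vol2col-like work on value " << value << '\n';
}

// ...and explicit instantiations emit code for each supported pair,
// mirroring the `template class` lines at the end of the hunk above.
template struct Vol2ColLike<LegacyContext, float>;
template struct Vol2ColLike<NewContext, float>;

int main() {
  NewContext ctx;
  Vol2ColLike<NewContext, float>{}(ctx, 1.0f);
}
```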
@@ -16,6 +16,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/stream/cuda_stream.h"
 #include "paddle/pten/backends/gpu/gpu_context.h"
+#include "paddle/pten/core/allocator.h"
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 #include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
@@ -485,8 +486,11 @@ CUDAContext::~CUDAContext() {
 CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
     : pten::GPUContext(place) {
   pten::GPUContext::PartialInitWithoutAllocator();
-  cuda_stream_.reset(
-      new stream::CUDAStream(pten::GPUContext::stream(), this->GetPlace()));
+  cuda_stream_.reset(new stream::CUDAStream(pten::GPUContext::stream(), place));
+  workspace_.reset(new pten::DnnWorkspaceHandle(
+      memory::allocation::AllocatorFacade::Instance()
+          .GetAllocator(place, pten::GPUContext::stream())
+          .get()));
 }
 
 CUDADeviceContext::~CUDADeviceContext() = default;
@@ -571,8 +575,15 @@ void CUDADeviceContext::WaitStreamCallback() const {
   pten::GPUContext::WaitStreamCallback();
 }
 
-CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
-  return CudnnWorkspaceHandle(*this, &cudnn_handle_mtx_);
+pten::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
+  if (thread_ctx_.count(this)) {
+    // return workspace_.get();
+    return pten::DnnWorkspaceHandle(
+        memory::allocation::AllocatorFacade::Instance()
+            .GetAllocator(GetPlace(), pten::GPUContext::stream())
+            .get());
+  }
+  return pten::GPUContext::cudnn_workspace_handle();
 }
 
 gpuStream_t CUDADeviceContext::stream() const {
...
@@ -566,7 +566,7 @@ class CUDADeviceContext : public pten::GPUContext {
    * workspace. Once the handle is destructed, the lock would be released.
    * CudnnWorkspaceHandle is an RAII object to implement thread-safe
    * sequential cudnn function calls. */
-  CudnnWorkspaceHandle cudnn_workspace_handle() const;
+  pten::DnnWorkspaceHandle cudnn_workspace_handle() const;
 
   /*! \brief Return cuda stream in the device context. */
   gpuStream_t stream() const;
@@ -607,6 +607,7 @@ class CUDADeviceContext : public pten::GPUContext {
   // NOTE: Just for compatibility with the past, please delete if there is an
   // elegant way.
   std::unique_ptr<stream::CUDAStream> cuda_stream_;
+  std::unique_ptr<pten::DnnWorkspaceHandle> workspace_{nullptr};
 
   DISABLE_COPY_AND_ASSIGN(CUDADeviceContext);
 };
...
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/pten/backends/gpu/gpu_context.h"
+#include <algorithm>
 #include <array>
 #include <functional>
 #include <future>
@@ -153,55 +154,14 @@ static void StreamCallbackFunc(gpuStream_t stream,
 }  // namespace internal
 
-class DnnWorkspaceHandle {
- public:
-  explicit inline DnnWorkspaceHandle(Allocator* allocator)
-      : allocator_(allocator) {}
-
-  inline void RunFunc(const std::function<void(void*)>& cudnn_func,
-                      size_t required_workspace_bytes) {
-    if (required_workspace_bytes > WorkspaceSize()) {
-      ReallocWorkspace(required_workspace_bytes);
-    }
-    VLOG(2) << "Cudnn workspace size at RunFunc: "
-            << static_cast<double>(WorkspaceSize()) / (1 << 20) << " MB";
-    {
-      std::lock_guard<std::mutex> guard(mtx_);
-      cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
-    }
-  }
-
-  /*! \brief Thread which call RunFuncSync() would release gpu memory after
-   *  running the function. Currently this function is only used when cudnn
-   *  exhaustive searching and callers have to guarantee that the input function
-   *  is host blocking */
-  inline void RunFuncSync(const std::function<void(void*)>& cudnn_func,
-                          size_t required_workspace_bytes) {
-    RunFunc(cudnn_func, required_workspace_bytes);
-    ResetWorkspace();
-  }
-
-  inline size_t WorkspaceSize() {
-    if (allocation_ == nullptr) {
-      return 0;
-    }
-    return allocation_->size();
-  }
-
-  void ResetWorkspace() { allocation_ = nullptr; }
-
-  void ReallocWorkspace(size_t required_workspace_bytes) {
-    if (required_workspace_bytes <= WorkspaceSize()) return;
-    // reset allocation first before re-allocate to save memory
-    allocation_.reset();
-    allocation_ = allocator_->Allocate(required_workspace_bytes);
-  }
-
- private:
-  Allocator::AllocationPtr allocation_{nullptr};
-  Allocator* allocator_{nullptr};
-  std::mutex mtx_;
-};
+void DnnWorkspaceHandle::ResetWorkspace() { allocation_ = nullptr; }
+
+void DnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
+  if (required_workspace_bytes <= WorkspaceSize()) return;
+  // reset allocation first before re-allocate to save memory
+  allocation_.reset();
+  allocation_ = allocator_->Allocate(required_workspace_bytes);
+}
 
 struct GPUContext::Impl {
   void Init() {
@@ -341,9 +301,15 @@ struct GPUContext::Impl {
     }
   }
 
-  DnnWorkspaceHandle* GetDnnWorkspace() {
-    PD_CHECK(workspace_ != nullptr, "the gpu cudnn workspace is nullptr.");
-    return workspace_;
+  // TODO(wilber): The return type is a pointer, to be modified later.
+  // DnnWorkspaceHandle* GetDnnWorkspace() {
+  //   PD_CHECK(workspace_ != nullptr, "the gpu cudnn workspace is nullptr.");
+  //   return workspace_;
+  // }
+  DnnWorkspaceHandle GetDnnWorkspace() {
+    PD_CHECK(allocator_ != nullptr,
+             "the device allocator for gpu context is nullptr.");
+    return DnnWorkspaceHandle(allocator_);
   }
 
   void InitStream() {
@@ -797,7 +763,7 @@ Eigen::GpuDevice* GPUContext::eigen_device() const {
   return impl_->eigen_device();
 }
 
-DnnWorkspaceHandle* GPUContext::cudnn_workspace_handle() {
+DnnWorkspaceHandle GPUContext::cudnn_workspace_handle() const {
   return impl_->GetDnnWorkspace();
 }
...
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <array>
 #include <functional>
+#include <mutex>
 
 #include "paddle/pten/backends/gpu/forwards.h"
 #include "paddle/pten/backends/gpu/gpu_decls.h"
 #include "paddle/pten/backends/gpu/gpu_helper.h"
@@ -24,7 +25,53 @@ limitations under the License. */
 namespace pten {
 
-class DnnWorkspaceHandle;
+class DnnWorkspaceHandle {
+ public:
+  explicit inline DnnWorkspaceHandle(Allocator* allocator)
+      : allocator_(allocator) {
+    mtx_.reset(new std::mutex());
+  }
+
+  inline void RunFunc(const std::function<void(void*)>& cudnn_func,
+                      size_t required_workspace_bytes) {
+    if (required_workspace_bytes > WorkspaceSize()) {
+      ReallocWorkspace(required_workspace_bytes);
+    }
+    {
+      std::lock_guard<std::mutex> guard(*mtx_);
+      cudnn_func(allocation_ ? allocation_->ptr() : nullptr);
+    }
+  }
+
+  /*! \brief Thread which call RunFuncSync() would release gpu memory after
+   *  running the function. Currently this function is only used when cudnn
+   *  exhaustive searching and callers have to guarantee that the input function
+   *  is host blocking */
+  inline void RunFuncSync(const std::function<void(void*)>& cudnn_func,
+                          size_t required_workspace_bytes) {
+    RunFunc(cudnn_func, required_workspace_bytes);
+    ResetWorkspace();
+  }
+
+  inline size_t WorkspaceSize() {
+    if (allocation_ == nullptr) {
+      return 0;
+    }
+    return allocation_->size();
+  }
+
+  void ResetWorkspace();
+
+  void ReallocWorkspace(size_t required_workspace_bytes);
+
+  DnnWorkspaceHandle(DnnWorkspaceHandle&&) = default;
+  DnnWorkspaceHandle& operator=(DnnWorkspaceHandle&&) = delete;
+
+ private:
+  Allocator::AllocationPtr allocation_{nullptr};
+  Allocator* allocator_{nullptr};
+  std::unique_ptr<std::mutex> mtx_;
+};
 
 class GPUContext : public DeviceContext {
  public:
@@ -85,7 +132,8 @@ class GPUContext : public DeviceContext {
    * would be acquired to prevent other threads from accessing the
    * workspace. Once the handle is destructed, the lock would be released.
    */
-  DnnWorkspaceHandle* cudnn_workspace_handle();
+  // TODO(wilber): The return type is a pointer, to be modified later.
+  DnnWorkspaceHandle cudnn_workspace_handle() const;
 
  public:
   /*! \brief Call cublas function safely. */
...
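Note that the header-level DnnWorkspaceHandle keeps its mutex behind a unique_ptr: std::mutex itself is neither movable nor copyable, so this keeps the handle move-constructible while the lock object stays in place. A sketch of the exhaustive-search pattern that the RunFuncSync comment above describes, assuming a pten::GPUContext named `ctx` and an illustrative 64 MB size (both assumptions, not from this diff):

```cpp
// Hypothetical call site for RunFuncSync.
void ExhaustiveSearchExample(const pten::GPUContext& ctx) {
  auto workspace = ctx.cudnn_workspace_handle();
  workspace.RunFuncSync(
      [](void* ws_ptr) {
        // A host-blocking cuDNN exhaustive search would use `ws_ptr` here.
      },
      /*required_workspace_bytes=*/static_cast<size_t>(64) << 20);
  // RunFuncSync calls ResetWorkspace() afterwards, so the potentially large
  // search workspace is released as soon as the functor returns.
}
```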