未验证 提交 5a8bd82c 编写于 作者: C chengduo 提交者: GitHub

Remove workspace_handle (#15376)

* remove workspace_handle
test=develop

* set constant for loss
test=develop
上级 f534c66d
...@@ -104,7 +104,9 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -104,7 +104,9 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv algorithm --------------------- // ------------------- cudnn conv algorithm ---------------------
cudnnConvolutionFwdAlgo_t algo; cudnnConvolutionFwdAlgo_t algo;
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
Tensor cudnn_workspace;
void* cudnn_workspace_ptr = nullptr;
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType( CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_DEFAULT_MATH)); cudnn_conv_desc, CUDNN_DEFAULT_MATH));
...@@ -118,19 +120,24 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -118,19 +120,24 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
workspace_size_limit, &algo)); workspace_size_limit, &algo));
VLOG(3) << "cuDNN forward algo " << algo; VLOG(3) << "cuDNN forward algo " << algo;
} else { } else {
cudnn_workspace =
ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
framework::make_ddim(
{static_cast<int64_t>(workspace_size_limit)}),
dev_ctx);
cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
auto search_func = [&]() { auto search_func = [&]() {
int returned_algo_count; int returned_algo_count;
std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS> std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
fwd_perf_stat; fwd_perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE( CUDNN_ENFORCE(platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
handle, cudnn_input_desc, input_data, cudnn_filter_desc, handle, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, cudnn_output_desc, output_data, filter_data, cudnn_conv_desc, cudnn_output_desc, output_data,
kNUM_CUDNN_FWD_ALGS, &returned_algo_count, kNUM_CUDNN_FWD_ALGS, &returned_algo_count, fwd_perf_stat.data(),
fwd_perf_stat.data(), cudnn_workspace, workspace_size_limit)); cudnn_workspace_ptr, workspace_size_limit));
};
workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit);
VLOG(3) << "Perf result: (algo: stat, time, memory)"; VLOG(3) << "Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) { for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = fwd_perf_stat[i]; const auto& stat = fwd_perf_stat[i];
...@@ -181,6 +188,15 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -181,6 +188,15 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit, PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
"workspace_size to be allocated exceeds the limit"); "workspace_size to be allocated exceeds the limit");
if (!cudnn_workspace_ptr) {
cudnn_workspace =
ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
framework::make_ddim(
{static_cast<int64_t>(workspace_size_in_bytes)}),
dev_ctx);
cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
}
if ((activation == "identity") && (!residual)) { if ((activation == "identity") && (!residual)) {
// Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is // Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is
// enabled with CUDNN_ACTIVATION_IDENTITY in cuDNN lib. // enabled with CUDNN_ACTIVATION_IDENTITY in cuDNN lib.
...@@ -188,13 +204,12 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -188,13 +204,12 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
// cudnnConvolutionForward and cudnnAddTensor // cudnnConvolutionForward and cudnnAddTensor
// ------------- cudnn conv forward and bias add --------------------- // ------------- cudnn conv forward and bias add ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f; ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_input_desc, input_data, cudnn_filter_desc, handle, &alpha, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, algo, cudnn_workspace, filter_data, cudnn_conv_desc, algo, cudnn_workspace_ptr,
workspace_size_in_bytes, &beta, cudnn_output_desc, output_data)); workspace_size_in_bytes, &beta, cudnn_output_desc, output_data));
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
CUDNN_ENFORCE(platform::dynload::cudnnAddTensor( CUDNN_ENFORCE(platform::dynload::cudnnAddTensor(
handle, &alpha, cudnn_bias_desc, bias_data, &alpha, cudnn_output_desc, handle, &alpha, cudnn_bias_desc, bias_data, &alpha, cudnn_output_desc,
output_data)); output_data));
...@@ -205,15 +220,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -205,15 +220,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv+bias+act forward -------------------- // ------------------- cudnn conv+bias+act forward --------------------
ScalingParamType<T> alpha1 = 1.0f; ScalingParamType<T> alpha1 = 1.0f;
ScalingParamType<T> alpha2 = residual ? 1.0f : 0.0f; ScalingParamType<T> alpha2 = residual ? 1.0f : 0.0f;
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBiasActivationForward( CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBiasActivationForward(
handle, &alpha1, cudnn_input_desc, input_data, cudnn_filter_desc, handle, &alpha1, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, algo, cudnn_workspace, filter_data, cudnn_conv_desc, algo, cudnn_workspace_ptr,
workspace_size_in_bytes, &alpha2, cudnn_output_desc, residual_data, workspace_size_in_bytes, &alpha2, cudnn_output_desc, residual_data,
cudnn_bias_desc, bias_data, cudnn_act_desc, cudnn_output_desc, cudnn_bias_desc, bias_data, cudnn_act_desc, cudnn_output_desc,
output_data)); output_data));
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
} }
std::vector<int> channels = ctx.Attr<std::vector<int>>("split_channels"); std::vector<int> channels = ctx.Attr<std::vector<int>>("split_channels");
if (channels.size()) { if (channels.size()) {
......
...@@ -104,16 +104,18 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> { ...@@ -104,16 +104,18 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
int output_offset = output->numel() / output->dims()[0] / groups; int output_offset = output->numel() / output->dims()[0] / groups;
int filter_offset = filter->numel() / groups; int filter_offset = filter->numel() / groups;
T alpha = 1.0f, beta = 0.0f; T alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto temp_allocation =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
workspace_size_in_bytes);
void* cudnn_workspace = temp_allocation->ptr();
for (int g = 0; g < groups; g++) { for (int g = 0; g < groups; g++) {
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData( CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g, handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g,
cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc, cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc,
algo, cudnn_workspace, workspace_size_in_bytes, &beta, algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_output_desc, output_data + output_offset * g)); cudnn_output_desc, output_data + output_offset * g));
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
} }
} }
}; };
...@@ -209,20 +211,22 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> { ...@@ -209,20 +211,22 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
output_grad->numel() / output_grad->dims()[0] / groups; output_grad->numel() / output_grad->dims()[0] / groups;
int filter_offset = filter->numel() / groups; int filter_offset = filter->numel() / groups;
T alpha = 1.0f, beta = 0.0f; T alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto temp_allocation =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
workspace_size_in_bytes);
void* cudnn_workspace = temp_allocation->ptr();
if (input_grad) { if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace()); T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad. // Because beta is zero, it is unnecessary to reset input_grad.
for (int g = 0; g < groups; g++) { for (int g = 0; g < groups; g++) {
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward( CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_output_desc, handle, &alpha, cudnn_output_desc,
output_grad_data + output_grad_offset * g, cudnn_filter_desc, output_grad_data + output_grad_offset * g, cudnn_filter_desc,
filter_data + filter_offset * g, cudnn_conv_desc, data_algo, filter_data + filter_offset * g, cudnn_conv_desc, data_algo,
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc, cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
input_grad_data + input_offset * g)); input_grad_data + input_offset * g));
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
} }
} }
...@@ -232,15 +236,12 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> { ...@@ -232,15 +236,12 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
// Because beta is zero, it is unnecessary to reset filter_grad. // Because beta is zero, it is unnecessary to reset filter_grad.
// Gradient with respect to the filter // Gradient with respect to the filter
for (int g = 0; g < groups; g++) { for (int g = 0; g < groups; g++) {
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter( CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_output_desc, handle, &alpha, cudnn_output_desc,
output_grad_data + output_grad_offset * g, cudnn_input_desc, output_grad_data + output_grad_offset * g, cudnn_input_desc,
input_data + input_offset * g, cudnn_conv_desc, filter_algo, input_data + input_offset * g, cudnn_conv_desc, filter_algo,
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_filter_desc,
cudnn_filter_desc, filter_grad_data + filter_offset * g)); filter_grad_data + filter_offset * g));
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
} }
} }
} }
......
...@@ -216,18 +216,19 @@ class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel<T> { ...@@ -216,18 +216,19 @@ class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel<T> {
out_datas.push_back( out_datas.push_back(
static_cast<void*>(output_data + (oc0 + oc1 + oc2) * h * w)); static_cast<void*>(output_data + (oc0 + oc1 + oc2) * h * w));
auto temp_allocation =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
workspace_size_in_bytes);
void* cudnn_workspace = temp_allocation->ptr();
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
auto func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBiasActivationForward( CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBiasActivationForward(
handle, &alpha, in_desc[i], in_datas[i], filter_desc[i], handle, &alpha, in_desc[i], in_datas[i], filter_desc[i],
static_cast<const void*>(filters[i]->data<T>()), conv_desc[i], static_cast<const void*>(filters[i]->data<T>()), conv_desc[i],
algo[i], cudnn_workspace, workspace_size_in_bytes, &beta, algo[i], cudnn_workspace, workspace_size_in_bytes, &beta, out_desc[i],
out_desc[i], out_datas[i], bias_desc[i], out_datas[i], bias_desc[i],
static_cast<const void*>(bias[i]->data<T>()), cudnn_act_desc, static_cast<const void*>(bias[i]->data<T>()), cudnn_act_desc,
out_desc[i], out_datas[i])); out_desc[i], out_datas[i]));
};
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
workspace_handle.RunFunc(func, workspace_size_in_bytes);
} }
cudnnTensorDescriptor_t x_desc; cudnnTensorDescriptor_t x_desc;
......
...@@ -144,17 +144,19 @@ class CudnnCTCKernel : public framework::OpKernel<T> { ...@@ -144,17 +144,19 @@ class CudnnCTCKernel : public framework::OpKernel<T> {
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, cu_ctcloss_desc, &workspace_size)); CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, cu_ctcloss_desc, &workspace_size));
T* loss_data = loss->mutable_data<T>(loss_dims, ctx.GetPlace()); T* loss_data = loss->mutable_data<T>(loss_dims, ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), loss, static_cast<T>(0));
auto temp_allocation =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
workspace_size);
void* cudnn_workspace = temp_allocation->ptr();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnCTCLoss( CUDNN_ENFORCE(platform::dynload::cudnnCTCLoss(
handle, cu_logits_desc, warpctc_logits_data, warpctc_label_data, handle, cu_logits_desc, warpctc_logits_data, warpctc_label_data,
warpctc_label_lengths.data(), warpctc_logits_lengths.data(), warpctc_label_lengths.data(), warpctc_logits_lengths.data(), loss_data,
loss_data, cu_grad_desc, warpctc_grad_data, cu_grad_desc, warpctc_grad_data, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, cu_ctcloss_desc, cudnn_workspace, cu_ctcloss_desc, cudnn_workspace, workspace_size));
workspace_size));
};
workspace_handle.RunFunc(cudnn_func, workspace_size);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册