未验证 提交 5a8bd82c 编写于 作者: C chengduo 提交者: GitHub

Remove workspace_handle (#15376)

* remove workspace_handle
test=develop

* set constant for loss
test=develop
上级 f534c66d
......@@ -104,7 +104,9 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv algorithm ---------------------
cudnnConvolutionFwdAlgo_t algo;
auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
Tensor cudnn_workspace;
void* cudnn_workspace_ptr = nullptr;
CUDNN_ENFORCE(platform::dynload::cudnnSetConvolutionMathType(
cudnn_conv_desc, CUDNN_DEFAULT_MATH));
......@@ -118,19 +120,24 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
workspace_size_limit, &algo));
VLOG(3) << "cuDNN forward algo " << algo;
} else {
cudnn_workspace =
ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
framework::make_ddim(
{static_cast<int64_t>(workspace_size_limit)}),
dev_ctx);
cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
auto search_func = [&]() {
int returned_algo_count;
std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
fwd_perf_stat;
auto cudnn_find_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(
platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
CUDNN_ENFORCE(platform::dynload::cudnnFindConvolutionForwardAlgorithmEx(
handle, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, cudnn_output_desc, output_data,
kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
fwd_perf_stat.data(), cudnn_workspace, workspace_size_limit));
};
workspace_handle.RunFunc(cudnn_find_func, workspace_size_limit);
kNUM_CUDNN_FWD_ALGS, &returned_algo_count, fwd_perf_stat.data(),
cudnn_workspace_ptr, workspace_size_limit));
VLOG(3) << "Perf result: (algo: stat, time, memory)";
for (int i = 0; i < returned_algo_count; ++i) {
const auto& stat = fwd_perf_stat[i];
......@@ -181,6 +188,15 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_LE(workspace_size_in_bytes, workspace_size_limit,
"workspace_size to be allocated exceeds the limit");
if (!cudnn_workspace_ptr) {
cudnn_workspace =
ctx.AllocateTmpTensor<int8_t, platform::CUDADeviceContext>(
framework::make_ddim(
{static_cast<int64_t>(workspace_size_in_bytes)}),
dev_ctx);
cudnn_workspace_ptr = static_cast<void*>(cudnn_workspace.data<int8_t>());
}
if ((activation == "identity") && (!residual)) {
// Only the CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM algo is
// enabled with CUDNN_ACTIVATION_IDENTITY in cuDNN lib.
......@@ -188,13 +204,12 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
// cudnnConvolutionForward and cudnnAddTensor
// ------------- cudnn conv forward and bias add ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, algo, cudnn_workspace,
filter_data, cudnn_conv_desc, algo, cudnn_workspace_ptr,
workspace_size_in_bytes, &beta, cudnn_output_desc, output_data));
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
CUDNN_ENFORCE(platform::dynload::cudnnAddTensor(
handle, &alpha, cudnn_bias_desc, bias_data, &alpha, cudnn_output_desc,
output_data));
......@@ -205,15 +220,13 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv+bias+act forward --------------------
ScalingParamType<T> alpha1 = 1.0f;
ScalingParamType<T> alpha2 = residual ? 1.0f : 0.0f;
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBiasActivationForward(
handle, &alpha1, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, algo, cudnn_workspace,
filter_data, cudnn_conv_desc, algo, cudnn_workspace_ptr,
workspace_size_in_bytes, &alpha2, cudnn_output_desc, residual_data,
cudnn_bias_desc, bias_data, cudnn_act_desc, cudnn_output_desc,
output_data));
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
std::vector<int> channels = ctx.Attr<std::vector<int>>("split_channels");
if (channels.size()) {
......
......@@ -104,16 +104,18 @@ class CUDNNConvTransposeOpKernel : public framework::OpKernel<T> {
int output_offset = output->numel() / output->dims()[0] / groups;
int filter_offset = filter->numel() / groups;
T alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto temp_allocation =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
workspace_size_in_bytes);
void* cudnn_workspace = temp_allocation->ptr();
for (int g = 0; g < groups; g++) {
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardData(
handle, &alpha, cudnn_filter_desc, filter_data + filter_offset * g,
cudnn_input_desc, input_data + input_offset * g, cudnn_conv_desc,
algo, cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_output_desc, output_data + output_offset * g));
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
}
};
......@@ -209,20 +211,22 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
output_grad->numel() / output_grad->dims()[0] / groups;
int filter_offset = filter->numel() / groups;
T alpha = 1.0f, beta = 0.0f;
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto temp_allocation =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
workspace_size_in_bytes);
void* cudnn_workspace = temp_allocation->ptr();
if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad.
for (int g = 0; g < groups; g++) {
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_output_desc,
output_grad_data + output_grad_offset * g, cudnn_filter_desc,
filter_data + filter_offset * g, cudnn_conv_desc, data_algo,
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_input_desc,
input_grad_data + input_offset * g));
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
}
}
......@@ -232,15 +236,12 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
// Because beta is zero, it is unnecessary to reset filter_grad.
// Gradient with respect to the filter
for (int g = 0; g < groups; g++) {
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBackwardFilter(
handle, &alpha, cudnn_output_desc,
output_grad_data + output_grad_offset * g, cudnn_input_desc,
input_data + input_offset * g, cudnn_conv_desc, filter_algo,
cudnn_workspace, workspace_size_in_bytes, &beta,
cudnn_filter_desc, filter_grad_data + filter_offset * g));
};
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
cudnn_workspace, workspace_size_in_bytes, &beta, cudnn_filter_desc,
filter_grad_data + filter_offset * g));
}
}
}
......
......@@ -216,18 +216,19 @@ class CUDNNConvInceptionFusionOpKernel : public framework::OpKernel<T> {
out_datas.push_back(
static_cast<void*>(output_data + (oc0 + oc1 + oc2) * h * w));
auto temp_allocation =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
workspace_size_in_bytes);
void* cudnn_workspace = temp_allocation->ptr();
for (int i = 0; i < 4; ++i) {
auto func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnConvolutionBiasActivationForward(
handle, &alpha, in_desc[i], in_datas[i], filter_desc[i],
static_cast<const void*>(filters[i]->data<T>()), conv_desc[i],
algo[i], cudnn_workspace, workspace_size_in_bytes, &beta,
out_desc[i], out_datas[i], bias_desc[i],
algo[i], cudnn_workspace, workspace_size_in_bytes, &beta, out_desc[i],
out_datas[i], bias_desc[i],
static_cast<const void*>(bias[i]->data<T>()), cudnn_act_desc,
out_desc[i], out_datas[i]));
};
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
workspace_handle.RunFunc(func, workspace_size_in_bytes);
}
cudnnTensorDescriptor_t x_desc;
......
......@@ -144,17 +144,19 @@ class CudnnCTCKernel : public framework::OpKernel<T> {
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, cu_ctcloss_desc, &workspace_size));
T* loss_data = loss->mutable_data<T>(loss_dims, ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), loss, static_cast<T>(0));
auto temp_allocation =
platform::DeviceTemporaryAllocator::Instance().Get(dev_ctx).Allocate(
workspace_size);
void* cudnn_workspace = temp_allocation->ptr();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
auto cudnn_func = [&](void* cudnn_workspace) {
CUDNN_ENFORCE(platform::dynload::cudnnCTCLoss(
handle, cu_logits_desc, warpctc_logits_data, warpctc_label_data,
warpctc_label_lengths.data(), warpctc_logits_lengths.data(),
loss_data, cu_grad_desc, warpctc_grad_data,
CUDNN_CTC_LOSS_ALGO_DETERMINISTIC, cu_ctcloss_desc, cudnn_workspace,
workspace_size));
};
workspace_handle.RunFunc(cudnn_func, workspace_size);
warpctc_label_lengths.data(), warpctc_logits_lengths.data(), loss_data,
cu_grad_desc, warpctc_grad_data, CUDNN_CTC_LOSS_ALGO_DETERMINISTIC,
cu_ctcloss_desc, cudnn_workspace, workspace_size));
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册