Unverified commit 76d78c63, authored by Zhou Wei, committed by GitHub

fix conv_fusion_op conflict,test=develop (#24020)

Parent 931cba2e
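
The change is mechanical throughout the file: every PADDLE_ENFORCE_CUDA_SUCCESS(call, platform::errors::External("...")) becomes PADDLE_ENFORCE_CUDA_SUCCESS(call), on the assumption that the macro now derives its error report from the wrapped cuDNN call itself. The toy program below is only a sketch of that idea, not Paddle code; ENFORCE_CUDNN_SUCCESS and the stubbed cudnnSetConvolutionGroupCount are hypothetical stand-ins.

// Toy, self-contained sketch; this is not Paddle's PADDLE_ENFORCE_CUDA_SUCCESS
// and not the real cuDNN API, just hypothetical stand-ins for illustration.
#include <cstdio>
#include <cstdlib>

enum cudnnStatus_t { CUDNN_STATUS_SUCCESS = 0, CUDNN_STATUS_BAD_PARAM = 3 };

// Stand-in for a dynloaded cuDNN call: fails on a non-positive group count.
cudnnStatus_t cudnnSetConvolutionGroupCount(int* /*desc*/, int groups) {
  return groups > 0 ? CUDNN_STATUS_SUCCESS : CUDNN_STATUS_BAD_PARAM;
}

// ENFORCE-style macro: #expr stringifies the wrapped call, so the call site
// no longer has to repeat the function name in a hand-written error string.
#define ENFORCE_CUDNN_SUCCESS(expr)                                   \
  do {                                                                \
    cudnnStatus_t status_ = (expr);                                   \
    if (status_ != CUDNN_STATUS_SUCCESS) {                            \
      std::fprintf(stderr, "cuDNN call `%s` failed with status %d\n", \
                   #expr, static_cast<int>(status_));                 \
      std::abort();                                                   \
    }                                                                 \
  } while (0)

int main() {
  int conv_desc = 0;
  // Call-site shape after this commit: wrap the call alone, no message arg.
  ENFORCE_CUDNN_SUCCESS(cudnnSetConvolutionGroupCount(&conv_desc, /*groups=*/1));
  std::puts("group count set");
  return 0;
}

Because the macro can name the failing expression itself, the duplicated per-call-site message strings in the hunks below can simply be dropped without losing diagnostic information.
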
@@ -167,13 +167,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
         conv_desc.descriptor<T>(padding_common, strides, dilations);
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnSetConvolutionGroupCount(cudnn_conv_desc,
-                                                         groups),
-        platform::errors::External(
-            "Call of cudnnSetConvolutionGroupCount(cudnn_conv_desc, groups) "
-            "failed, where cudnn_conv_desc is configured: padding = [%s], "
-            "strides = [%s], dilations = [%s]; groups = %d",
-            framework::make_ddim(padding_common), framework::make_ddim(strides),
-            framework::make_ddim(dilations), groups));
+                                                         groups));
     cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
         layout, framework::vectorize<int>(transformed_input.dims()));
@@ -204,15 +198,8 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
     auto handle = dev_ctx.cudnn_handle();
     auto workspace_handle = dev_ctx.cudnn_workspace_handle();
-    PADDLE_ENFORCE_CUDA_SUCCESS(
-        platform::dynload::cudnnSetConvolutionMathType(cudnn_conv_desc,
-                                                       CUDNN_DEFAULT_MATH),
-        platform::errors::External(
-            "Call of cudnnSetConvolutionMathType(cudnn_conv_desc, "
-            "CUDNN_DEFAULT_MATH) failed, where cudnn_conv_desc is configured: "
-            "padding = %d, strides = %d, dilations = %d.",
-            framework::make_ddim(padding_common), framework::make_ddim(strides),
-            framework::make_ddim(dilations)));
+    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
+        cudnn_conv_desc, CUDNN_DEFAULT_MATH));
     auto x_dims = framework::vectorize(transformed_input.dims());
     auto f_dims = framework::vectorize(filter->dims());
@@ -221,9 +208,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
           platform::dynload::cudnnGetConvolutionForwardAlgorithm(
               handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
               cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
-              workspace_size_limit, &algo),
-          platform::errors::External(
-              "Call of cudnnGetConvolutionForwardAlgorithm failed."));
+              workspace_size_limit, &algo));
       VLOG(3) << "cuDNN forward algo " << algo;
     } else {
       std::function<cudnnConvolutionFwdAlgo_t()> search_func =
@@ -237,9 +222,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
                   handle, cudnn_input_desc, input_data, cudnn_filter_desc,
                   filter_data, cudnn_conv_desc, cudnn_output_desc, output_data,
                   kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
-                  fwd_perf_stat.data(), cudnn_workspace, workspace_size_limit),
-              platform::errors::External(
-                  "Call of cudnnFindConvolutionForwardAlgorithmEx failed."));
+                  fwd_perf_stat.data(), cudnn_workspace, workspace_size_limit));
         };
         workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);
         VLOG(3) << "Perf result: (algo: stat, time, memory)";
@@ -273,9 +256,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_CUDA_SUCCESS(
         platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
             handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
-            cudnn_output_desc, algo, &workspace_size_in_bytes),
-        platform::errors::External(
-            "Call of cudnnGetConvolutionForwardWorkspaceSize failed."));
+            cudnn_output_desc, algo, &workspace_size_in_bytes));
     PADDLE_ENFORCE_LE(
         workspace_size_in_bytes, workspace_size_limit,
         platform::errors::InvalidArgument(
@@ -292,20 +273,15 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
       // ------------- cudnn conv forward and bias add ---------------------
       ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
       auto cudnn_func = [&](void* cudnn_workspace) {
-        PADDLE_ENFORCE_CUDA_SUCCESS(
-            platform::dynload::cudnnConvolutionForward(
-                handle, &alpha, cudnn_input_desc, input_data, cudnn_filter_desc,
-                filter_data, cudnn_conv_desc, algo, cudnn_workspace,
-                workspace_size_in_bytes, &beta, cudnn_output_desc, output_data),
-            platform::errors::External(
-                "Call of cudnnConvolutionForward failed."));
+        PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnConvolutionForward(
+            handle, &alpha, cudnn_input_desc, input_data, cudnn_filter_desc,
+            filter_data, cudnn_conv_desc, algo, cudnn_workspace,
+            workspace_size_in_bytes, &beta, cudnn_output_desc, output_data));
       };
       workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
-      PADDLE_ENFORCE_CUDA_SUCCESS(
-          platform::dynload::cudnnAddTensor(handle, &alpha, cudnn_bias_desc,
-                                            bias_data, &alpha,
-                                            cudnn_output_desc, output_data),
-          platform::errors::External("Call of cudnnAddTensor failed."));
+      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnAddTensor(
+          handle, &alpha, cudnn_bias_desc, bias_data, &alpha, cudnn_output_desc,
+          output_data));
     } else {
       if (activation == "identity") {
         algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
@@ -320,9 +296,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
                 cudnn_filter_desc, filter_data, cudnn_conv_desc, algo,
                 cudnn_workspace, workspace_size_in_bytes, &alpha2,
                 cudnn_output_desc, residual_data, cudnn_bias_desc, bias_data,
-                cudnn_act_desc, cudnn_output_desc, output_data),
-            platform::errors::External(
-                "Call of cudnnConvolutionBiasActivationForward failed."));
+                cudnn_act_desc, cudnn_output_desc, output_data));
       };
       workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
     }