未验证 提交 76d78c63 编写于 作者: Z Zhou Wei 提交者: GitHub

fix conv_fusion_op conflict,test=develop (#24020)

上级 931cba2e
...@@ -167,13 +167,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -167,13 +167,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
conv_desc.descriptor<T>(padding_common, strides, dilations); conv_desc.descriptor<T>(padding_common, strides, dilations);
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnSetConvolutionGroupCount(cudnn_conv_desc, platform::dynload::cudnnSetConvolutionGroupCount(cudnn_conv_desc,
groups), groups));
platform::errors::External(
"Call of cudnnSetConvolutionGroupCount(cudnn_conv_desc, groups) "
"failed, where cudnn_conv_desc is configured: padding = [%s], "
"strides = [%s], dilations = [%s]; groups = %d",
framework::make_ddim(padding_common), framework::make_ddim(strides),
framework::make_ddim(dilations), groups));
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>( cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor<T>(
layout, framework::vectorize<int>(transformed_input.dims())); layout, framework::vectorize<int>(transformed_input.dims()));
...@@ -204,15 +198,8 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -204,15 +198,8 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
auto workspace_handle = dev_ctx.cudnn_workspace_handle(); auto workspace_handle = dev_ctx.cudnn_workspace_handle();
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSetConvolutionMathType(
platform::dynload::cudnnSetConvolutionMathType(cudnn_conv_desc, cudnn_conv_desc, CUDNN_DEFAULT_MATH));
CUDNN_DEFAULT_MATH),
platform::errors::External(
"Call of cudnnSetConvolutionMathType(cudnn_conv_desc, "
"CUDNN_DEFAULT_MATH) failed, where cudnn_conv_desc is configured: "
"padding = %d, strides = %d, dilations = %d.",
framework::make_ddim(padding_common), framework::make_ddim(strides),
framework::make_ddim(dilations)));
auto x_dims = framework::vectorize(transformed_input.dims()); auto x_dims = framework::vectorize(transformed_input.dims());
auto f_dims = framework::vectorize(filter->dims()); auto f_dims = framework::vectorize(filter->dims());
...@@ -221,9 +208,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -221,9 +208,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
platform::dynload::cudnnGetConvolutionForwardAlgorithm( platform::dynload::cudnnGetConvolutionForwardAlgorithm(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, cudnn_output_desc, CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
workspace_size_limit, &algo), workspace_size_limit, &algo));
platform::errors::External(
"Call of cudnnGetConvolutionForwardAlgorithm failed."));
VLOG(3) << "cuDNN forward algo " << algo; VLOG(3) << "cuDNN forward algo " << algo;
} else { } else {
std::function<cudnnConvolutionFwdAlgo_t()> search_func = std::function<cudnnConvolutionFwdAlgo_t()> search_func =
...@@ -237,9 +222,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -237,9 +222,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
handle, cudnn_input_desc, input_data, cudnn_filter_desc, handle, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, cudnn_output_desc, output_data, filter_data, cudnn_conv_desc, cudnn_output_desc, output_data,
kNUM_CUDNN_FWD_ALGS, &returned_algo_count, kNUM_CUDNN_FWD_ALGS, &returned_algo_count,
fwd_perf_stat.data(), cudnn_workspace, workspace_size_limit), fwd_perf_stat.data(), cudnn_workspace, workspace_size_limit));
platform::errors::External(
"Call of cudnnFindConvolutionForwardAlgorithmEx failed."));
}; };
workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit); workspace_handle.RunFuncSync(cudnn_find_func, workspace_size_limit);
VLOG(3) << "Perf result: (algo: stat, time, memory)"; VLOG(3) << "Perf result: (algo: stat, time, memory)";
...@@ -273,9 +256,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -273,9 +256,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(
platform::dynload::cudnnGetConvolutionForwardWorkspaceSize( platform::dynload::cudnnGetConvolutionForwardWorkspaceSize(
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc, handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, algo, &workspace_size_in_bytes), cudnn_output_desc, algo, &workspace_size_in_bytes));
platform::errors::External(
"Call of cudnnGetConvolutionForwardWorkspaceSize failed."));
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
workspace_size_in_bytes, workspace_size_limit, workspace_size_in_bytes, workspace_size_limit,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -292,20 +273,15 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -292,20 +273,15 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
// ------------- cudnn conv forward and bias add --------------------- // ------------- cudnn conv forward and bias add ---------------------
ScalingParamType<T> alpha = 1.0f, beta = 0.0f; ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
auto cudnn_func = [&](void* cudnn_workspace) { auto cudnn_func = [&](void* cudnn_workspace) {
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnConvolutionForward(
platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_input_desc, input_data, cudnn_filter_desc, handle, &alpha, cudnn_input_desc, input_data, cudnn_filter_desc,
filter_data, cudnn_conv_desc, algo, cudnn_workspace, filter_data, cudnn_conv_desc, algo, cudnn_workspace,
workspace_size_in_bytes, &beta, cudnn_output_desc, output_data), workspace_size_in_bytes, &beta, cudnn_output_desc, output_data));
platform::errors::External(
"Call of cudnnConvolutionForward failed."));
}; };
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
PADDLE_ENFORCE_CUDA_SUCCESS( PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnAddTensor(
platform::dynload::cudnnAddTensor(handle, &alpha, cudnn_bias_desc, handle, &alpha, cudnn_bias_desc, bias_data, &alpha, cudnn_output_desc,
bias_data, &alpha, output_data));
cudnn_output_desc, output_data),
platform::errors::External("Call of cudnnAddTensor failed."));
} else { } else {
if (activation == "identity") { if (activation == "identity") {
algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM; algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
...@@ -320,9 +296,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> { ...@@ -320,9 +296,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
cudnn_filter_desc, filter_data, cudnn_conv_desc, algo, cudnn_filter_desc, filter_data, cudnn_conv_desc, algo,
cudnn_workspace, workspace_size_in_bytes, &alpha2, cudnn_workspace, workspace_size_in_bytes, &alpha2,
cudnn_output_desc, residual_data, cudnn_bias_desc, bias_data, cudnn_output_desc, residual_data, cudnn_bias_desc, bias_data,
cudnn_act_desc, cudnn_output_desc, output_data), cudnn_act_desc, cudnn_output_desc, output_data));
platform::errors::External(
"Call of cudnnConvolutionBiasActivationForward failed."));
}; };
workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes); workspace_handle.RunFunc(cudnn_func, workspace_size_in_bytes);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册