未验证 提交 2567afa3 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #12462 from reyoung/feature/fix_cudnn_deterministic

Fix bug in cudnn_determistic
...@@ -20,10 +20,10 @@ limitations under the License. */ ...@@ -20,10 +20,10 @@ limitations under the License. */
#include "paddle/fluid/platform/cudnn_helper.h" #include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
DEFINE_bool(cudnn_deterministic, true, DEFINE_bool(cudnn_deterministic, false,
"Whether allow using an autotuning algorithm for convolution " "Whether allow using an autotuning algorithm for convolution "
"operator. The autotuning algorithm may be non-deterministic. If " "operator. The autotuning algorithm may be non-deterministic. If "
"false, the algorithm is deterministic."); "true, the algorithm is deterministic.");
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -272,7 +272,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -272,7 +272,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>(); auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
auto handle = dev_ctx.cudnn_handle(); auto handle = dev_ctx.cudnn_handle();
if (input_grad) { if (input_grad) {
if (FLAGS_cudnn_deterministic) { if (!FLAGS_cudnn_deterministic) {
CUDNN_ENFORCE( CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm( platform::dynload::cudnnGetConvolutionBackwardDataAlgorithm(
handle, cudnn_filter_desc, handle, cudnn_filter_desc,
...@@ -297,7 +297,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -297,7 +297,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
} }
if (filter_grad) { if (filter_grad) {
if (FLAGS_cudnn_deterministic) { if (!FLAGS_cudnn_deterministic) {
CUDNN_ENFORCE( CUDNN_ENFORCE(
platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm( platform::dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
handle, cudnn_input_desc, cudnn_output_grad_desc, handle, cudnn_input_desc, cudnn_output_grad_desc,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册