diff --git a/paddle/fluid/framework/conv_search_cache.h b/paddle/fluid/framework/conv_search_cache.h new file mode 100644 index 0000000000000000000000000000000000000000..720467d6f1cda14a4fbb3ffce9afbc6f93d7fa08 --- /dev/null +++ b/paddle/fluid/framework/conv_search_cache.h @@ -0,0 +1,62 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator_kernel_configs.h" +#include "paddle/fluid/platform/cudnn_helper.h" + +namespace paddle { +namespace framework { + +using framework::AlgorithmsCache; + +// ConvSearchCache using framework::AlgorithmsCache to search +// cudnnConvolutionFwdAlgo_t, cudnnConvolutionBwdDataAlgo_t or +// cudnnConvolutionBwdFilterAlgo_t +class ConvSearchCache { + public: + static ConvSearchCache& Instance() { + static ConvSearchCache instance; + return instance; + } + + AlgorithmsCache* GetForward() { + return &forward_cache_; + } + AlgorithmsCache* GetBackwardData() { + return &backward_data_cache_; + } + AlgorithmsCache* GetBackwardFilter() { + return &backward_filter_cache_; + } + AlgorithmsCache* GetConvFusion() { + return &fusion_forward_cache_; + } + + private: + ConvSearchCache() {} + ~ConvSearchCache() {} + ConvSearchCache(const ConvSearchCache&) {} + ConvSearchCache& operator=(const ConvSearchCache&) {} + + AlgorithmsCache forward_cache_; + AlgorithmsCache backward_data_cache_; + AlgorithmsCache backward_filter_cache_; + AlgorithmsCache fusion_forward_cache_; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/operators/conv_cudnn_helper.h b/paddle/fluid/operators/conv_cudnn_helper.h index df94de8a18f60bc3a93eddfda2617eaba568d47a..d20311d091ce48311693656f7b741ea769ef91db 100644 --- a/paddle/fluid/operators/conv_cudnn_helper.h +++ b/paddle/fluid/operators/conv_cudnn_helper.h @@ -18,10 +18,10 @@ limitations under the License. */ #include #include #include +#include "paddle/fluid/framework/conv_search_cache.h" #include "paddle/fluid/framework/operator_kernel_configs.h" #include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/platform/cudnn_desc.h" -// #include "paddle/fluid/platform/device_context.h" namespace paddle { namespace operators { @@ -90,43 +90,7 @@ std::ostream& operator<<(std::ostream& out, const std::vector& v) { return out; } -// ConvSearchCache using framework::AlgorithmsCache to search -// cudnnConvolutionFwdAlgo_t, cudnnConvolutionBwdDataAlgo_t or -// cudnnConvolutionBwdFilterAlgo_t -class ConvSearchCache { - public: - static ConvSearchCache& Instance() { - static ConvSearchCache instance; - return instance; - } - - framework::AlgorithmsCache* GetForward() { - return &forward_cache_; - } - framework::AlgorithmsCache* GetBackwardData() { - return &backward_data_cache_; - } - framework::AlgorithmsCache* - GetBackwardFilter() { - return &backward_filter_cache_; - } - framework::AlgorithmsCache* GetConvFusion() { - return &fusion_forward_cache_; - } - - private: - ConvSearchCache() {} - ~ConvSearchCache() {} - ConvSearchCache(const ConvSearchCache&) {} - ConvSearchCache& operator=(const ConvSearchCache&) {} - - framework::AlgorithmsCache forward_cache_; - framework::AlgorithmsCache - backward_data_cache_; - framework::AlgorithmsCache - backward_filter_cache_; - framework::AlgorithmsCache fusion_forward_cache_; -}; +using framework::ConvSearchCache; struct ConvArgs { cudnnHandle_t handle; @@ -228,7 +192,7 @@ struct SearchAlgorithm { auto& temp = ctx.cuda_device_context(); AlgorithmsCache& algo_cache = - *(ConvSearchCache::Instance().GetForward()); + *(framework::ConvSearchCache::Instance().GetForward()); auto x_dims = framework::vectorize(args.x->dims()); auto w_dims = framework::vectorize(args.w->dims()); @@ -367,7 +331,7 @@ struct SearchAlgorithm { auto workspace_handle = dev_ctx.cudnn_workspace_handle(); AlgorithmsCache& algo_cache = - *(ConvSearchCache::Instance().GetBackwardData()); + *(framework::ConvSearchCache::Instance().GetBackwardData()); auto x_dims = framework::vectorize(args.x->dims()); auto w_dims = framework::vectorize(args.w->dims()); @@ -495,7 +459,7 @@ struct SearchAlgorithm { ctx.template device_context(); auto workspace_handle = dev_ctx.cudnn_workspace_handle(); AlgorithmsCache& algo_cache = - *(ConvSearchCache::Instance().GetBackwardFilter()); + *(framework::ConvSearchCache::Instance().GetBackwardFilter()); auto x_dims = framework::vectorize(args.x->dims()); auto w_dims = framework::vectorize(args.w->dims()); diff --git a/paddle/fluid/operators/fused/conv_fusion_op.cu b/paddle/fluid/operators/fused/conv_fusion_op.cu index 92769bb93eae36a172c5169a699c6ac3fa2f0649..c4844300eaa67791d0c5a9edef408835e2648556 100644 --- a/paddle/fluid/operators/fused/conv_fusion_op.cu +++ b/paddle/fluid/operators/fused/conv_fusion_op.cu @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. */ #include +#include "paddle/fluid/framework/conv_search_cache.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/conv_cudnn_helper.h" #include "paddle/fluid/operators/conv_cudnn_op_cache.h" #include "paddle/fluid/operators/conv_op.h" #include "paddle/fluid/operators/math/padding.h" +#include "paddle/fluid/platform/cudnn_helper.h" DECLARE_int64(cudnn_exhaustive_search_times); @@ -32,6 +33,7 @@ using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; using ScopedActivationDescriptor = platform::ScopedActivationDescriptor; using DataLayout = platform::DataLayout; using framework::AlgorithmsCache; +using framework::ConvSearchCache; template using ScalingParamType = typename platform::CudnnDataType::ScalingParamType; @@ -233,7 +235,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel { return fwd_perf_stat[0].algo; }; AlgorithmsCache& algo_cache = - *(ConvSearchCache::Instance().GetConvFusion()); + *(framework::ConvSearchCache::Instance().GetConvFusion()); int search_times = ctx.Attr("search_times"); search_times = std::max( static_cast(FLAGS_cudnn_exhaustive_search_times), search_times);