Unverified commit b4b6763a, authored by zhongpu, committed by GitHub

fix bug for exhaustive_search in conv_fusion_op, test=develop (#23727)

Parent d77bc12e
paddle/fluid/framework/conv_search_cache.h (new file)

/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/platform/cudnn_helper.h"

namespace paddle {
namespace framework {

using framework::AlgorithmsCache;

// ConvSearchCache uses framework::AlgorithmsCache to cache the searched
// cudnnConvolutionFwdAlgo_t, cudnnConvolutionBwdDataAlgo_t, or
// cudnnConvolutionBwdFilterAlgo_t.
class ConvSearchCache {
 public:
  static ConvSearchCache& Instance() {
    static ConvSearchCache instance;
    return instance;
  }

  AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* GetForward() {
    return &forward_cache_;
  }
  AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>* GetBackwardData() {
    return &backward_data_cache_;
  }
  AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>* GetBackwardFilter() {
    return &backward_filter_cache_;
  }
  AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* GetConvFusion() {
    return &fusion_forward_cache_;
  }

 private:
  ConvSearchCache() {}
  ~ConvSearchCache() {}
  // Non-copyable, non-assignable singleton.
  ConvSearchCache(const ConvSearchCache&) = delete;
  ConvSearchCache& operator=(const ConvSearchCache&) = delete;

  AlgorithmsCache<cudnnConvolutionFwdAlgo_t> forward_cache_;
  AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t> backward_data_cache_;
  AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t> backward_filter_cache_;
  AlgorithmsCache<cudnnConvolutionFwdAlgo_t> fusion_forward_cache_;
};

}  // namespace framework
}  // namespace paddle
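For orientation, here is a minimal usage sketch (not part of this commit) of how a kernel is expected to consume the singleton above. It assumes AlgorithmsCache<T> from operator_kernel_configs.h exposes a GetAlgorithm(x_dims, w_dims, strides, paddings, dilations, flags, gen_func) overload that runs the generator only on a cache miss; the helper name and the lambda body are placeholders, not Paddle code.

// Hypothetical sketch: memoize a searched forward algorithm in the shared
// fusion cache.  PickFusionAlgo and the GetAlgorithm argument order are
// assumptions for illustration only.
#include <vector>
#include "paddle/fluid/framework/conv_search_cache.h"

namespace paddle {
namespace operators {

cudnnConvolutionFwdAlgo_t PickFusionAlgo(
    const std::vector<int64_t>& x_dims, const std::vector<int64_t>& w_dims,
    const std::vector<int>& strides, const std::vector<int>& paddings,
    const std::vector<int>& dilations, int dtype_flag) {
  // One process-wide cache, shared by every kernel that includes this header.
  auto& algo_cache = *(framework::ConvSearchCache::Instance().GetConvFusion());
  // The generator lambda runs only when no algorithm is cached for this
  // shape/attribute key; afterwards the cached result is reused directly.
  return algo_cache.GetAlgorithm(
      x_dims, w_dims, strides, paddings, dilations, dtype_flag, []() {
        // Placeholder: the real kernel performs an exhaustive cuDNN search
        // here (see conv_fusion_op.cu below).
        return CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
      });
}

}  // namespace operators
}  // namespace paddle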
paddle/fluid/operators/conv_cudnn_helper.h
@@ -18,10 +18,10 @@ limitations under the License. */
 #include <array>
 #include <memory>
 #include <vector>
+#include "paddle/fluid/framework/conv_search_cache.h"
 #include "paddle/fluid/framework/operator_kernel_configs.h"
 #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
 #include "paddle/fluid/platform/cudnn_desc.h"
-// #include "paddle/fluid/platform/device_context.h"

 namespace paddle {
 namespace operators {
@@ -90,43 +90,7 @@ std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
   return out;
 }

-// ConvSearchCache using framework::AlgorithmsCache to search
-// cudnnConvolutionFwdAlgo_t, cudnnConvolutionBwdDataAlgo_t or
-// cudnnConvolutionBwdFilterAlgo_t
-class ConvSearchCache {
- public:
-  static ConvSearchCache& Instance() {
-    static ConvSearchCache instance;
-    return instance;
-  }
-  framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* GetForward() {
-    return &forward_cache_;
-  }
-  framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>* GetBackwardData() {
-    return &backward_data_cache_;
-  }
-  framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>*
-  GetBackwardFilter() {
-    return &backward_filter_cache_;
-  }
-  framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* GetConvFusion() {
-    return &fusion_forward_cache_;
-  }
-
- private:
-  ConvSearchCache() {}
-  ~ConvSearchCache() {}
-  ConvSearchCache(const ConvSearchCache&) {}
-  ConvSearchCache& operator=(const ConvSearchCache&) {}
-
-  framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t> forward_cache_;
-  framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>
-      backward_data_cache_;
-  framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>
-      backward_filter_cache_;
-  framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t> fusion_forward_cache_;
-};
+using framework::ConvSearchCache;

 struct ConvArgs {
   cudnnHandle_t handle;
@@ -228,7 +192,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
     auto& temp = ctx.cuda_device_context();
     AlgorithmsCache<algo_t>& algo_cache =
-        *(ConvSearchCache::Instance().GetForward());
+        *(framework::ConvSearchCache::Instance().GetForward());
     auto x_dims = framework::vectorize(args.x->dims());
     auto w_dims = framework::vectorize(args.w->dims());
@@ -367,7 +331,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
     auto workspace_handle = dev_ctx.cudnn_workspace_handle();
     AlgorithmsCache<algo_t>& algo_cache =
-        *(ConvSearchCache::Instance().GetBackwardData());
+        *(framework::ConvSearchCache::Instance().GetBackwardData());
     auto x_dims = framework::vectorize(args.x->dims());
     auto w_dims = framework::vectorize(args.w->dims());
@@ -495,7 +459,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
         ctx.template device_context<platform::CUDADeviceContext>();
     auto workspace_handle = dev_ctx.cudnn_workspace_handle();
     AlgorithmsCache<algo_t>& algo_cache =
-        *(ConvSearchCache::Instance().GetBackwardFilter());
+        *(framework::ConvSearchCache::Instance().GetBackwardFilter());
     auto x_dims = framework::vectorize(args.x->dims());
     auto w_dims = framework::vectorize(args.w->dims());
...
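The class body removed above is a Meyers singleton: the function-local static in Instance() yields exactly one cache object per process, constructed thread-safely on first use, which is why centralizing the definition in a single shared header matters. A standalone illustration of that pattern (plain C++, no Paddle types; all names are illustrative), showing that repeated Instance() calls anywhere in the program touch the same object:

// Standalone demo of the function-local-static singleton pattern used by
// ConvSearchCache.
#include <cstdio>

class Cache {
 public:
  static Cache& Instance() {
    static Cache instance;  // constructed once, on first use, thread-safely
    return instance;
  }
  int hits = 0;

 private:
  Cache() = default;
  Cache(const Cache&) = delete;             // the singleton cannot be copied
  Cache& operator=(const Cache&) = delete;  // ...or assigned
};

int main() {
  Cache::Instance().hits += 1;
  Cache::Instance().hits += 1;
  std::printf("%d\n", Cache::Instance().hits);  // prints 2: same object
  return 0;
}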
conv_fusion_op.cu
@@ -13,11 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include <array>
+#include "paddle/fluid/framework/conv_search_cache.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/conv_cudnn_helper.h"
 #include "paddle/fluid/operators/conv_cudnn_op_cache.h"
 #include "paddle/fluid/operators/conv_op.h"
 #include "paddle/fluid/operators/math/padding.h"
+#include "paddle/fluid/platform/cudnn_helper.h"

 DECLARE_int64(cudnn_exhaustive_search_times);
@@ -32,6 +33,7 @@ using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
 using ScopedActivationDescriptor = platform::ScopedActivationDescriptor;
 using DataLayout = platform::DataLayout;
 using framework::AlgorithmsCache;
+using framework::ConvSearchCache;

 template <typename T>
 using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
@@ -233,7 +235,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
       return fwd_perf_stat[0].algo;
     };
     AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache =
-        *(ConvSearchCache::Instance().GetConvFusion());
+        *(framework::ConvSearchCache::Instance().GetConvFusion());
     int search_times = ctx.Attr<int>("search_times");
     search_times = std::max(
         static_cast<int>(FLAGS_cudnn_exhaustive_search_times), search_times);
...
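The context line "return fwd_perf_stat[0].algo;" above is the tail of the kernel's exhaustive-search generator. A generic sketch of that pattern against the public cuDNN API follows, with descriptor creation, error checking, and workspace management omitted; the function and variable names are illustrative, not the operator's own.

// Generic exhaustive forward-algorithm search: benchmark every algorithm on
// the real tensors and return the fastest one, as reported by cuDNN.
#include <array>
#include <cudnn.h>

cudnnConvolutionFwdAlgo_t ExhaustiveSearchFwd(
    cudnnHandle_t handle, cudnnTensorDescriptor_t x_desc, const void* x,
    cudnnFilterDescriptor_t w_desc, const void* w,
    cudnnConvolutionDescriptor_t conv_desc, cudnnTensorDescriptor_t y_desc,
    void* y, void* workspace, size_t workspace_bytes) {
  constexpr int kRequested = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
  std::array<cudnnConvolutionFwdAlgoPerf_t, kRequested> perf_stat;
  int returned = 0;
  // cudnnFindConvolutionForwardAlgorithmEx times each candidate algorithm on
  // the supplied tensors and sorts the results by execution time.
  cudnnFindConvolutionForwardAlgorithmEx(
      handle, x_desc, x, w_desc, w, conv_desc, y_desc, y, kRequested,
      &returned, perf_stat.data(), workspace, workspace_bytes);
  return perf_stat[0].algo;  // mirrors "return fwd_perf_stat[0].algo;" above
}

Because this search is expensive, its result is memoized per shape/attribute key in the shared ConvSearchCache, which is what the algo_cache lookup in the hunk above is for.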