未验证 提交 b4b6763a 编写于 作者: Z zhongpu 提交者: GitHub

fix bug for exhaustive_search in conv_fusion_op, test=develop (#23727)

上级 d77bc12e
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/platform/cudnn_helper.h"
namespace paddle {
namespace framework {
using framework::AlgorithmsCache;
// ConvSearchCache is a singleton that owns one framework::AlgorithmsCache per
// cuDNN convolution algorithm kind: cudnnConvolutionFwdAlgo_t,
// cudnnConvolutionBwdDataAlgo_t and cudnnConvolutionBwdFilterAlgo_t, plus a
// separate forward-algorithm cache handed out through GetConvFusion().
//
// Every getter returns a non-owning pointer to a member of the singleton; the
// pointer is never null and callers must not delete it.
class ConvSearchCache {
 public:
  // Returns the shared instance, constructed on first use.
  static ConvSearchCache& Instance() {
    static ConvSearchCache instance;
    return instance;
  }

  // Cache of forward convolution algorithms.
  AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* GetForward() {
    return &forward_cache_;
  }

  // Cache of backward-data convolution algorithms.
  AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>* GetBackwardData() {
    return &backward_data_cache_;
  }

  // Cache of backward-filter convolution algorithms.
  AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>* GetBackwardFilter() {
    return &backward_filter_cache_;
  }

  // Forward-algorithm cache kept apart from the one returned by GetForward(),
  // so the two never share entries.
  AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* GetConvFusion() {
    return &fusion_forward_cache_;
  }

  // Copying is disabled at compile time. The previous private, empty-bodied
  // versions still compiled when used inside the class, and the assignment
  // operator fell off its end without returning a value (undefined behaviour).
  ConvSearchCache(const ConvSearchCache&) = delete;
  ConvSearchCache& operator=(const ConvSearchCache&) = delete;

 private:
  // Construction and destruction are private: the only instance is the
  // function-local static inside Instance().
  ConvSearchCache() = default;
  ~ConvSearchCache() = default;

  AlgorithmsCache<cudnnConvolutionFwdAlgo_t> forward_cache_;
  AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t> backward_data_cache_;
  AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t> backward_filter_cache_;
  AlgorithmsCache<cudnnConvolutionFwdAlgo_t> fusion_forward_cache_;
};
} // namespace framework
} // namespace paddle
......@@ -18,10 +18,10 @@ limitations under the License. */
#include <array>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/conv_search_cache.h"
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/cudnn_desc.h"
// #include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
......@@ -90,43 +90,7 @@ std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
return out;
}
// ConvSearchCache is a singleton that owns one framework::AlgorithmsCache per
// cuDNN convolution algorithm kind: cudnnConvolutionFwdAlgo_t,
// cudnnConvolutionBwdDataAlgo_t and cudnnConvolutionBwdFilterAlgo_t, plus a
// separate forward-algorithm cache handed out through GetConvFusion().
//
// Every getter returns a non-owning pointer to a member of the singleton; the
// pointer is never null and callers must not delete it.
class ConvSearchCache {
 public:
  // Returns the shared instance, constructed on first use.
  static ConvSearchCache& Instance() {
    static ConvSearchCache instance;
    return instance;
  }

  // Cache of forward convolution algorithms.
  framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* GetForward() {
    return &forward_cache_;
  }

  // Cache of backward-data convolution algorithms.
  framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>* GetBackwardData() {
    return &backward_data_cache_;
  }

  // Cache of backward-filter convolution algorithms.
  framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>*
  GetBackwardFilter() {
    return &backward_filter_cache_;
  }

  // Forward-algorithm cache kept apart from the one returned by GetForward(),
  // so the two never share entries.
  framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t>* GetConvFusion() {
    return &fusion_forward_cache_;
  }

  // Copying is disabled at compile time. The previous private, empty-bodied
  // versions still compiled when used inside the class, and the assignment
  // operator fell off its end without returning a value (undefined behaviour).
  ConvSearchCache(const ConvSearchCache&) = delete;
  ConvSearchCache& operator=(const ConvSearchCache&) = delete;

 private:
  // Construction and destruction are private: the only instance is the
  // function-local static inside Instance().
  ConvSearchCache() = default;
  ~ConvSearchCache() = default;

  framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t> forward_cache_;
  framework::AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t>
      backward_data_cache_;
  framework::AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t>
      backward_filter_cache_;
  framework::AlgorithmsCache<cudnnConvolutionFwdAlgo_t> fusion_forward_cache_;
};
using framework::ConvSearchCache;
struct ConvArgs {
cudnnHandle_t handle;
......@@ -228,7 +192,7 @@ struct SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t> {
auto& temp = ctx.cuda_device_context();
AlgorithmsCache<algo_t>& algo_cache =
*(ConvSearchCache::Instance().GetForward());
*(framework::ConvSearchCache::Instance().GetForward());
auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims());
......@@ -367,7 +331,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t> {
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
AlgorithmsCache<algo_t>& algo_cache =
*(ConvSearchCache::Instance().GetBackwardData());
*(framework::ConvSearchCache::Instance().GetBackwardData());
auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims());
......@@ -495,7 +459,7 @@ struct SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t> {
ctx.template device_context<platform::CUDADeviceContext>();
auto workspace_handle = dev_ctx.cudnn_workspace_handle();
AlgorithmsCache<algo_t>& algo_cache =
*(ConvSearchCache::Instance().GetBackwardFilter());
*(framework::ConvSearchCache::Instance().GetBackwardFilter());
auto x_dims = framework::vectorize(args.x->dims());
auto w_dims = framework::vectorize(args.w->dims());
......
......@@ -13,11 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <array>
#include "paddle/fluid/framework/conv_search_cache.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/conv_cudnn_helper.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/operators/math/padding.h"
#include "paddle/fluid/platform/cudnn_helper.h"
DECLARE_int64(cudnn_exhaustive_search_times);
......@@ -32,6 +33,7 @@ using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using ScopedActivationDescriptor = platform::ScopedActivationDescriptor;
using DataLayout = platform::DataLayout;
using framework::AlgorithmsCache;
using framework::ConvSearchCache;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
......@@ -233,7 +235,7 @@ class CUDNNConvFusionOpKernel : public framework::OpKernel<T> {
return fwd_perf_stat[0].algo;
};
AlgorithmsCache<cudnnConvolutionFwdAlgo_t>& algo_cache =
*(ConvSearchCache::Instance().GetConvFusion());
*(framework::ConvSearchCache::Instance().GetConvFusion());
int search_times = ctx.Attr<int>("search_times");
search_times = std::max(
static_cast<int>(FLAGS_cudnn_exhaustive_search_times), search_times);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册