未验证 提交 b937cdc5 编写于 作者: L limingshu 提交者: GitHub

Autotune the workspace_size_limit in conv. (#40338)

* Using the maximum workspace_size of all alogirhms to limit the workspace size in exhaustive search mode.

* Use the system cudaMalloc and cudaFree to allocate workspace during searching.

* Enable switch of two kind of workspace setting methods.
Co-authored-by: NLiu Yiqun <liuyiqun01@baidu.com>
上级 e1792a31
......@@ -16,7 +16,6 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
namespace paddle {
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <array>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/conv_search_cache.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DataLayout = platform::DataLayout;
using framework::AlgorithmsCache;
using framework::ConvSearchCache;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
// As the basic for SearchAlgorithm struct.
template <typename PerfT>
struct SearchAlgorithm {};
// As the container of searchAlgorithm::Find() result.
template <typename AlgoT>
struct SearchResult {
public:
AlgoT algo = static_cast<AlgoT>(0);
float time = -1.f;
size_t workspace_size = 0;
};
// As the container of conv relevant descriptors.
template <typename HandleT, typename DataT>
struct ConvArgsBase {
HandleT handle;
platform::TensorDescriptor idesc, odesc;
platform::FilterDescriptor wdesc;
platform::ConvolutionDescriptor cdesc;
const framework::Tensor *x, *w, *o;
DataT cudnn_dtype;
// strides
std::vector<int> s;
// paddings
std::vector<int> p;
// dilations
std::vector<int> d;
ConvArgsBase(const framework::Tensor* x, const framework::Tensor* w,
const framework::Tensor* o, const std::vector<int> s,
const std::vector<int> p, const std::vector<int> d, DataT dtype)
: x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {}
};
static inline void GetNCDHW(const framework::DDim& dims,
const DataLayout& layout, int* N, int* C, int* D,
int* H, int* W) {
*N = dims[0];
*C = layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1];
int i = layout == DataLayout::kNCHW ? 0 : 1;
if (dims.size() == 5) {
*D = dims[2 - i];
*H = dims[3 - i];
*W = dims[4 - i];
} else {
*D = 1;
*H = dims[2 - i];
*W = dims[3 - i];
}
}
template <typename T>
static std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
out << "[";
for (auto const& tmp : v) out << tmp << ",";
out << "]";
return out;
}
} // namespace operators
} // namespace paddle
......@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
DECLARE_uint64(conv_workspace_size_limit);
DECLARE_int64(conv_workspace_size_limit);
DECLARE_bool(cudnn_exhaustive_search);
DECLARE_int64(cudnn_exhaustive_search_times);
......
......@@ -14,42 +14,12 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <array>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/conv_search_cache.h"
#include "paddle/fluid/framework/operator_kernel_configs.h"
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/fluid/operators/conv_base_helper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DataLayout = platform::DataLayout;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
using framework::AlgorithmsCache;
static inline void GetNCDHW(const framework::DDim& dims,
const DataLayout& layout, int* N, int* C, int* D,
int* H, int* W) {
*N = dims[0];
*C = layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1];
int i = layout == DataLayout::kNCHW ? 0 : 1;
if (dims.size() == 5) {
*D = dims[2 - i];
*H = dims[3 - i];
*W = dims[4 - i];
} else {
*D = 1;
*H = dims[2 - i];
*W = dims[3 - i];
}
}
using ConvArgs = ConvArgsBase<miopenHandle_t, miopenDataType_t>;
template <typename DeviceContext, typename T, size_t D>
static void RemovePaddingSlice(const phi::GPUContext& context,
......@@ -66,9 +36,8 @@ static void RemovePaddingSlice(const phi::GPUContext& context,
extents[i] = new_out_dims[i];
}
int start;
for (size_t i = 0; i < axes.size(); ++i) {
start = starts[i];
int start = starts[i];
if (start < 0) {
start = (start + in_dims[axes[i]]);
}
......@@ -85,41 +54,6 @@ static void RemovePaddingSlice(const phi::GPUContext& context,
out_t.device(place) = in_t.slice(offsets, extents);
}
template <typename T>
std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
out << "[";
for (auto const& tmp : v) out << tmp << ",";
out << "]";
return out;
}
using framework::ConvSearchCache;
struct ConvArgs {
miopenHandle_t handle;
platform::TensorDescriptor idesc, odesc;
platform::FilterDescriptor wdesc;
platform::ConvolutionDescriptor cdesc;
const framework::Tensor *x, *w, *o;
miopenDataType_t cudnn_dtype;
// strides
std::vector<int> s;
// paddings
std::vector<int> p;
// dilations
std::vector<int> d;
ConvArgs(const framework::Tensor* x, const framework::Tensor* w,
const framework::Tensor* o, const std::vector<int> s,
const std::vector<int> p, const std::vector<int> d,
miopenDataType_t dtype)
: x(x), w(w), o(o), s(s), p(p), d(d), cudnn_dtype(dtype) {}
};
template <typename algo_t>
struct SearchAlgorithm {};
template <>
struct SearchAlgorithm<miopenConvFwdAlgorithm_t> {
using perf_t = miopenConvAlgoPerf_t;
......
......@@ -16,8 +16,6 @@ limitations under the License. */
#include "paddle/fluid/operators/conv_cudnn_op_cache.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
DECLARE_uint64(conv_workspace_size_limit);
namespace paddle {
namespace operators {
......
......@@ -188,6 +188,8 @@ class RecordedGpuMallocHelper {
if (UNLIKELY(malloc_managed_memory)) {
result = cudaMallocManaged(ptr, size);
} else {
VLOG(10) << "[cudaMalloc] size=" << static_cast<double>(size) / (1 << 20)
<< " MB";
result = cudaMalloc(ptr, size);
}
#endif
......@@ -226,6 +228,8 @@ class RecordedGpuMallocHelper {
if (err != hipErrorDeinitialized) {
#else
auto err = cudaFree(ptr);
VLOG(10) << "[cudaFree] size=" << static_cast<double>(size) / (1 << 20)
<< " MB";
if (err != cudaErrorCudartUnloading) {
#endif
PADDLE_ENFORCE_GPU_SUCCESS(err);
......
......@@ -522,8 +522,8 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : phi::GPUContext(place) {
cuda_stream_.reset(new stream::CUDAStream(phi::GPUContext::stream(), place));
auto& instance = memory::allocation::AllocatorFacade::Instance();
instance.SetDefaultStream(place, phi::GPUContext::stream());
workspace_.reset(
new phi::DnnWorkspaceHandle(instance.GetAllocator(place).get()));
workspace_.reset(new phi::DnnWorkspaceHandle(
instance.GetAllocator(place).get(), stream()));
}
CUDADeviceContext::~CUDADeviceContext() = default;
......@@ -623,7 +623,8 @@ phi::DnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
return phi::DnnWorkspaceHandle(
memory::allocation::AllocatorFacade::Instance()
.GetAllocator(GetPlace())
.get());
.get(),
stream());
}
return phi::GPUContext::cudnn_workspace_handle();
}
......
......@@ -161,10 +161,9 @@ PADDLE_DEFINE_EXPORTED_bool(
* increased.
* Users need to balance memory and speed.
*/
PADDLE_DEFINE_EXPORTED_uint64(
conv_workspace_size_limit,
paddle::platform::kDefaultConvWorkspaceSizeLimitMB,
"cuDNN convolution workspace limit in MB unit.");
PADDLE_DEFINE_EXPORTED_int64(conv_workspace_size_limit,
paddle::platform::kDefaultConvWorkspaceSizeLimitMB,
"cuDNN convolution workspace limit in MB unit.");
/**
* CUDNN related FLAG
......
......@@ -12,6 +12,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_context.h"
#include <algorithm>
#include <array>
......@@ -155,6 +156,39 @@ static void StreamCallbackFunc(gpuStream_t stream,
} // namespace internal
void DnnWorkspaceHandle::RunFuncSync(
const std::function<void(void*)>& cudnn_func,
size_t required_workspace_bytes,
bool use_cached_allocation) {
bool need_realloc = required_workspace_bytes > WorkspaceSize();
if (need_realloc && !use_cached_allocation) {
void* workspace_ptr = nullptr;
size_t size = ((required_workspace_bytes + 255) >> 8) << 8;
std::lock_guard<std::mutex> guard(*mtx_);
#ifdef PADDLE_WITH_HIP
auto status = hipMalloc(&workspace_ptr, size);
#else
auto status = cudaMalloc(&workspace_ptr, size);
#endif
if (status == gpuSuccess) {
cudnn_func(workspace_ptr);
phi::backends::gpu::GpuStreamSync(stream_);
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipFree(workspace_ptr));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaFree(workspace_ptr));
#endif
return;
}
}
RunFunc(cudnn_func, required_workspace_bytes);
if (need_realloc) {
// Release the workspace allocated in this running.
ResetWorkspace();
}
}
void DnnWorkspaceHandle::ResetWorkspace() { allocation_ = nullptr; }
void DnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
......@@ -295,13 +329,13 @@ struct GPUContext::Impl {
void InitDnnWorkspace() {
PD_CHECK(allocator_ != nullptr,
"the device allocator for gpu context is nullptr.");
workspace_ = new DnnWorkspaceHandle(allocator_);
workspace_ = new DnnWorkspaceHandle(allocator_, stream_);
}
void DestoryInternalWorkspace() {
if (owned_ && workspace_ != nullptr) {
delete workspace_;
stream_ = nullptr;
workspace_ = nullptr;
}
}
......@@ -313,7 +347,7 @@ struct GPUContext::Impl {
DnnWorkspaceHandle GetDnnWorkspace() {
PD_CHECK(allocator_ != nullptr,
"the device allocator for gpu context is nullptr.");
return DnnWorkspaceHandle(allocator_);
return DnnWorkspaceHandle(allocator_, stream_);
}
void InitStream() {
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/forwards.h"
#include "paddle/phi/backends/gpu/gpu_decls.h"
#include "paddle/phi/backends/gpu/gpu_helper.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/device_context.h"
......@@ -28,8 +29,8 @@ namespace phi {
class DnnWorkspaceHandle {
public:
explicit inline DnnWorkspaceHandle(Allocator* allocator)
: allocator_(allocator) {
inline DnnWorkspaceHandle(Allocator* allocator, gpuStream_t stream)
: allocator_(allocator), stream_(stream) {
mtx_.reset(new std::mutex());
}
......@@ -48,11 +49,9 @@ class DnnWorkspaceHandle {
* running the function. Currently this function is only used when cudnn
* exhaustive searching and callers have to guarantee that the input function
* is host blocking */
inline void RunFuncSync(const std::function<void(void*)>& cudnn_func,
size_t required_workspace_bytes) {
RunFunc(cudnn_func, required_workspace_bytes);
ResetWorkspace();
}
void RunFuncSync(const std::function<void(void*)>& cudnn_func,
size_t required_workspace_bytes,
bool use_cached_allocation = true);
inline size_t WorkspaceSize() {
if (allocation_ == nullptr) {
......@@ -70,7 +69,8 @@ class DnnWorkspaceHandle {
private:
Allocator::AllocationPtr allocation_{nullptr};
Allocator* allocator_{nullptr};
Allocator* allocator_{nullptr}; // Not owned
gpuStream_t stream_{nullptr}; // Not owned
std::unique_ptr<std::mutex> mtx_;
};
......
if (WITH_GPU)
nv_test(gpu_timer_test SRCS gpu_timer_test.cu DEPS gtest)
nv_test(auto_tune_test SRCS auto_tune_test.cu DEPS gtest)
nv_test(gpu_timer_test SRCS gpu_timer_test.cu DEPS gtest)
nv_test(auto_tune_test SRCS auto_tune_test.cu DEPS gtest)
elseif (WITH_ROCM)
hip_test(gpu_timer_test SRCS gpu_timer_test.cu DEPS gtest)
hip_test(auto_tune_test SRCS auto_tune_test.cu DEPS gtest)
......
......@@ -289,21 +289,17 @@ void ConvCudnnGradGradKernel(
dtype};
#ifdef PADDLE_WITH_HIP
miopenConvFwdAlgorithm_t fwd_algo1 = static_cast<miopenConvFwdAlgorithm_t>(0);
miopenConvFwdAlgorithm_t fwd_algo2 = static_cast<miopenConvFwdAlgorithm_t>(0);
miopenConvBwdDataAlgorithm_t data_algo =
static_cast<miopenConvBwdDataAlgorithm_t>(0);
miopenConvBwdWeightsAlgorithm_t filter_algo =
static_cast<miopenConvBwdWeightsAlgorithm_t>(0);
paddle::operators::SearchResult<miopenConvFwdAlgorithm_t> fwd_result1;
paddle::operators::SearchResult<miopenConvFwdAlgorithm_t> fwd_result2;
paddle::operators::SearchResult<miopenConvBwdDataAlgorithm_t> data_result;
paddle::operators::SearchResult<miopenConvBwdWeightsAlgorithm_t>
filter_result;
#else
cudnnConvolutionFwdAlgo_t fwd_algo1 =
static_cast<cudnnConvolutionFwdAlgo_t>(0);
cudnnConvolutionFwdAlgo_t fwd_algo2 =
static_cast<cudnnConvolutionFwdAlgo_t>(0);
cudnnConvolutionBwdDataAlgo_t data_algo =
static_cast<cudnnConvolutionBwdDataAlgo_t>(0);
cudnnConvolutionBwdFilterAlgo_t filter_algo =
static_cast<cudnnConvolutionBwdFilterAlgo_t>(0);
paddle::operators::SearchResult<cudnnConvolutionFwdAlgo_t> fwd_result1;
paddle::operators::SearchResult<cudnnConvolutionFwdAlgo_t> fwd_result2;
paddle::operators::SearchResult<cudnnConvolutionBwdDataAlgo_t> data_result;
paddle::operators::SearchResult<cudnnConvolutionBwdFilterAlgo_t>
filter_result;
#endif
auto layout = paddle::platform::GetCudnnTensorFormat(
......@@ -332,13 +328,13 @@ void ConvCudnnGradGradKernel(
using search1 =
paddle::operators::SearchAlgorithm<miopenConvFwdAlgorithm_t>;
workspace_size = search1::GetWorkspaceSize(args1);
fwd_algo1 = search1::Find<T>(
fwd_result1.algo = search1::Find<T>(
args1, exhaustive_search, false, workspace_size, ctx);
#else
using search1 =
paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_algo1 = search1::Find<T>(args1, exhaustive_search, false, ctx);
workspace_size = search1::GetWorkspaceSize(args1, fwd_algo1);
fwd_result1 = search1::Find<T>(args1, exhaustive_search, false, ctx);
workspace_size = search1::GetWorkspaceSize(args1, fwd_result1.algo);
#endif
}
......@@ -360,14 +356,14 @@ void ConvCudnnGradGradKernel(
paddle::operators::SearchAlgorithm<miopenConvFwdAlgorithm_t>;
workspace_size =
std::max(workspace_size, search2::GetWorkspaceSize(args2));
fwd_algo2 = search2::Find<T>(
fwd_result2.algo = search2::Find<T>(
args2, exhaustive_search, false, workspace_size, ctx);
#else
using search2 =
paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
fwd_algo2 = search2::Find<T>(args2, exhaustive_search, false, ctx);
workspace_size =
std::max(workspace_size, search2::GetWorkspaceSize(args2, fwd_algo2));
fwd_result2 = search2::Find<T>(args2, exhaustive_search, false, ctx);
workspace_size = std::max(
workspace_size, search2::GetWorkspaceSize(args2, fwd_result2.algo));
#endif
}
}
......@@ -389,15 +385,15 @@ void ConvCudnnGradGradKernel(
using search3 =
paddle::operators::SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t>;
workspace_size = std::max(workspace_size, search3::GetWorkspaceSize(args3));
filter_algo = search3::Find<T>(
filter_result.algo = search3::Find<T>(
args3, exhaustive_search, deterministic, workspace_size, ctx);
#else
using search3 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo =
filter_result =
search3::Find<T>(args3, exhaustive_search, deterministic, ctx);
workspace_size =
std::max(workspace_size, search3::GetWorkspaceSize(args3, filter_algo));
workspace_size = std::max(
workspace_size, search3::GetWorkspaceSize(args3, filter_result.algo));
#endif
}
......@@ -419,14 +415,15 @@ void ConvCudnnGradGradKernel(
using search4 =
paddle::operators::SearchAlgorithm<miopenConvBwdDataAlgorithm_t>;
workspace_size = std::max(workspace_size, search4::GetWorkspaceSize(args4));
data_algo = search4::Find<T>(
data_result.algo = search4::Find<T>(
args4, exhaustive_search, deterministic, workspace_size, ctx);
#else
using search4 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_algo = search4::Find<T>(args4, exhaustive_search, deterministic, ctx);
workspace_size =
std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo));
data_result =
search4::Find<T>(args4, exhaustive_search, deterministic, ctx);
workspace_size = std::max(
workspace_size, search4::GetWorkspaceSize(args4, data_result.algo));
#endif
}
......@@ -471,7 +468,7 @@ void ConvCudnnGradGradKernel(
args1.wdesc.desc(),
w,
args1.cdesc.desc(),
fwd_algo1,
fwd_result1.algo,
&beta,
args1.odesc.desc(),
transformed_ddy_channel,
......@@ -492,7 +489,7 @@ void ConvCudnnGradGradKernel(
args1.wdesc.desc(),
w + i * group_offset_filter,
args1.cdesc.desc(),
fwd_algo1,
fwd_result1.algo,
workspace_ptr,
workspace_size,
&beta,
......@@ -517,7 +514,7 @@ void ConvCudnnGradGradKernel(
args2.wdesc.desc(),
ddw,
args2.cdesc.desc(),
fwd_algo2,
fwd_result2.algo,
&beta,
args2.odesc.desc(),
transformed_ddy_channel,
......@@ -538,7 +535,7 @@ void ConvCudnnGradGradKernel(
args2.wdesc.desc(),
ddw + i * group_offset_filter,
args2.cdesc.desc(),
fwd_algo2,
fwd_result2.algo,
workspace_ptr,
workspace_size,
&alpha,
......@@ -568,7 +565,7 @@ void ConvCudnnGradGradKernel(
args3.idesc.desc(),
ddx,
args3.cdesc.desc(),
filter_algo,
filter_result.algo,
&beta,
args3.wdesc.desc(),
dw,
......@@ -589,7 +586,7 @@ void ConvCudnnGradGradKernel(
args3.odesc.desc(),
transformed_dy_channel + i * group_offset_out,
args3.cdesc.desc(),
filter_algo,
filter_result.algo,
workspace_ptr,
workspace_size,
&beta,
......@@ -615,7 +612,7 @@ void ConvCudnnGradGradKernel(
args4.wdesc.desc(),
ddw,
args4.cdesc.desc(),
data_algo,
data_result.algo,
&beta,
args4.idesc.desc(),
transformed_dx,
......@@ -636,7 +633,7 @@ void ConvCudnnGradGradKernel(
args4.odesc.desc(),
transformed_dy_channel + i * group_offset_out,
args4.cdesc.desc(),
data_algo,
data_result.algo,
workspace_ptr,
workspace_size,
&beta,
......
......@@ -322,17 +322,16 @@ void ConvCudnnGradKernel(const Context& ctx,
int group_offset_in = i_c / groups * i_h * i_w * i_d;
int group_offset_out = o_c / groups * o_h * o_w * o_d;
int group_offset_filter = transformed_filter_channel.numel() / groups;
// ------------------- cudnn backward algorithm ---------------------
#ifdef PADDLE_WITH_HIP
miopenConvBwdDataAlgorithm_t data_algo =
static_cast<miopenConvBwdDataAlgorithm_t>(0);
miopenConvBwdWeightsAlgorithm_t filter_algo =
static_cast<miopenConvBwdWeightsAlgorithm_t>(0);
paddle::operators::SearchResult<miopenConvBwdDataAlgorithm_t> bwd_result;
paddle::operators::SearchResult<miopenConvBwdWeightsAlgorithm_t>
filter_result;
#else
cudnnConvolutionBwdDataAlgo_t data_algo =
static_cast<cudnnConvolutionBwdDataAlgo_t>(0);
cudnnConvolutionBwdFilterAlgo_t filter_algo =
static_cast<cudnnConvolutionBwdFilterAlgo_t>(0);
paddle::operators::SearchResult<cudnnConvolutionBwdDataAlgo_t> bwd_result;
paddle::operators::SearchResult<cudnnConvolutionBwdFilterAlgo_t>
filter_result;
#endif
// input data workspace_size
size_t workspace_size_d = 0;
......@@ -368,14 +367,14 @@ void ConvCudnnGradKernel(const Context& ctx,
paddle::operators::SearchAlgorithm<miopenConvBwdDataAlgorithm_t>;
workspace_size_d =
std::max(workspace_size_d, search1::GetWorkspaceSize(args1));
data_algo = search1::Find<T>(
bwd_result.algo = search1::Find<T>(
args1, exhaustive_search, deterministic, workspace_size_d, ctx);
#else
using search1 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
data_algo = search1::Find<T>(args1, exhaustive_search, deterministic, ctx);
workspace_size_d =
std::max(workspace_size_d, search1::GetWorkspaceSize(args1, data_algo));
bwd_result = search1::Find<T>(args1, exhaustive_search, deterministic, ctx);
workspace_size_d = std::max(
workspace_size_d, search1::GetWorkspaceSize(args1, bwd_result.algo));
#endif
}
......@@ -397,15 +396,17 @@ void ConvCudnnGradKernel(const Context& ctx,
paddle::operators::SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t>;
workspace_size_w =
std::max(workspace_size_w, search2::GetWorkspaceSize(args2));
filter_algo = search2::Find<T>(
filter_result.algo = search2::Find<T>(
args2, exhaustive_search, deterministic, workspace_size_w, ctx);
#else
using search2 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo =
filter_result =
search2::Find<T>(args2, exhaustive_search, deterministic, ctx);
workspace_size_w = std::max(workspace_size_w,
search2::GetWorkspaceSize(args2, filter_algo));
VLOG(3) << "filter algo: " << filter_result.algo << ", time "
<< filter_result.time;
workspace_size_w = std::max(
workspace_size_w, search2::GetWorkspaceSize(args2, filter_result.algo));
#endif
}
......@@ -439,7 +440,7 @@ void ConvCudnnGradKernel(const Context& ctx,
args1.wdesc.desc(),
filter_data,
args1.cdesc.desc(),
data_algo,
bwd_result.algo,
&beta,
args1.idesc.desc(),
temp_tensor_data,
......@@ -471,7 +472,7 @@ void ConvCudnnGradKernel(const Context& ctx,
args1.wdesc.desc(),
filter_data,
args1.cdesc.desc(),
data_algo,
bwd_result.algo,
&beta,
args1.idesc.desc(),
transformed_input_grad_data,
......@@ -494,7 +495,7 @@ void ConvCudnnGradKernel(const Context& ctx,
args1.odesc.desc(),
output_grad_data + i * group_offset_out,
args1.cdesc.desc(),
data_algo,
bwd_result.algo,
cudnn_workspace_ptr,
workspace_size_d,
&beta,
......@@ -554,7 +555,7 @@ void ConvCudnnGradKernel(const Context& ctx,
args2.idesc.desc(),
input_data,
args2.cdesc.desc(),
filter_algo,
filter_result.algo,
&beta,
args2.wdesc.desc(),
filter_grad_data,
......@@ -575,7 +576,7 @@ void ConvCudnnGradKernel(const Context& ctx,
args2.odesc.desc(),
output_grad_data + i * group_offset_out,
args2.cdesc.desc(),
filter_algo,
filter_result.algo,
cudnn_workspace_ptr,
workspace_size_w,
&beta_filter,
......
......@@ -18,7 +18,6 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/framework/eigen.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/operators/conv_miopen_helper.h"
#else
......@@ -68,7 +67,6 @@ void ConvCudnnKernel(const Context& ctx,
"FLAGS_cudnn_deterministic True at same time."));
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
auto dtype = paddle::platform::CudnnDataType<T>::type;
#ifdef PADDLE_WITH_HIP
......@@ -309,17 +307,17 @@ void ConvCudnnKernel(const Context& ctx,
size_t workspace_size = 0; // final workspace to allocate.
// ------------------- cudnn conv algorithm ---------------------
#ifdef PADDLE_WITH_HIP
miopenConvFwdAlgorithm_t algo{};
paddle::operators::SearchResult<miopenConvFwdAlgorithm_t> fwd_result;
using search = paddle::operators::SearchAlgorithm<miopenConvFwdAlgorithm_t>;
workspace_size = search::GetWorkspaceSize(args);
algo = search::Find<T>(
fwd_result.algo = search::Find<T>(
args, exhaustive_search, deterministic, workspace_size, ctx);
#else
cudnnConvolutionFwdAlgo_t algo{};
paddle::operators::SearchResult<cudnnConvolutionFwdAlgo_t> fwd_result;
using search =
paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
algo = search::Find<T>(args, exhaustive_search, deterministic, ctx);
workspace_size = search::GetWorkspaceSize(args, algo);
fwd_result = search::Find<T>(args, exhaustive_search, deterministic, ctx);
workspace_size = search::GetWorkspaceSize(args, fwd_result.algo);
#endif
#if defined(PADDLE_WITH_CUDA) && CUDNN_VERSION_MIN(7, 0, 1)
......@@ -328,7 +326,7 @@ void ConvCudnnKernel(const Context& ctx,
// in forward computation, so change the algorithm to CUDNN_CONVOLUTION_\
// FWD_ALGO_IMPLICIT_GEMM manually.
if (groups > 1) {
algo = static_cast<cudnnConvolutionFwdAlgo_t>(0);
fwd_result.algo = static_cast<cudnnConvolutionFwdAlgo_t>(0);
}
#endif
......@@ -352,7 +350,7 @@ void ConvCudnnKernel(const Context& ctx,
args.wdesc.desc(),
filter_data,
args.cdesc.desc(),
algo,
fwd_result.algo,
&beta,
args.odesc.desc(),
output_data,
......@@ -373,7 +371,7 @@ void ConvCudnnKernel(const Context& ctx,
args.wdesc.desc(),
filter_data + i * group_offset_filter,
args.cdesc.desc(),
algo,
fwd_result.algo,
workspace_ptr,
workspace_size,
&beta,
......
......@@ -188,11 +188,13 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
dtype};
#ifdef PADDLE_WITH_HIP
miopenConvFwdAlgorithm_t data_algo{};
miopenConvBwdWeightsAlgorithm_t filter_algo{};
paddle::operators::SearchResult<miopenConvFwdAlgorithm_t> fwd_result;
paddle::operators::SearchResult<miopenConvBwdWeightsAlgorithm_t>
filter_result;
#else
cudnnConvolutionFwdAlgo_t data_algo{};
cudnnConvolutionBwdFilterAlgo_t filter_algo{};
paddle::operators::SearchResult<cudnnConvolutionFwdAlgo_t> fwd_result;
paddle::operators::SearchResult<cudnnConvolutionBwdFilterAlgo_t>
filter_result;
#endif
auto layout_tensor = paddle::platform::GetCudnnTensorFormat(layout);
......@@ -218,14 +220,14 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
using search1 =
paddle::operators::SearchAlgorithm<miopenConvFwdAlgorithm_t>;
workspace_size = std::max(workspace_size, search1::GetWorkspaceSize(args1));
data_algo =
fwd_result.algo =
search1::Find<T>(args1, false, deterministic, workspace_size, ctx);
#else
using search1 =
paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
data_algo = search1::Find<T>(args1, false, deterministic, ctx);
workspace_size =
std::max(workspace_size, search1::GetWorkspaceSize(args1, data_algo));
fwd_result = search1::Find<T>(args1, false, deterministic, ctx);
workspace_size = std::max(
workspace_size, search1::GetWorkspaceSize(args1, fwd_result.algo));
#endif
}
......@@ -245,14 +247,14 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
using search2 =
paddle::operators::SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t>;
workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2));
filter_algo =
filter_result.algo =
search2::Find<T>(args2, false, deterministic, workspace_size, ctx);
#else
using search2 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo = search2::Find<T>(args2, false, deterministic, ctx);
workspace_size =
std::max(workspace_size, search2::GetWorkspaceSize(args2, filter_algo));
filter_result = search2::Find<T>(args2, false, deterministic, ctx);
workspace_size = std::max(
workspace_size, search2::GetWorkspaceSize(args2, filter_result.algo));
#endif
}
......@@ -278,7 +280,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
args1.wdesc.desc(),
filter_data + filter_offset * g,
args1.cdesc.desc(),
data_algo,
fwd_result.algo,
&beta,
args1.odesc.desc(),
dx_data + x_offset * g,
......@@ -295,7 +297,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
args1.wdesc.desc(),
filter_data + filter_offset * g,
args1.cdesc.desc(),
data_algo,
fwd_result.algo,
cudnn_workspace,
workspace_size,
&beta,
......@@ -338,7 +340,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
args2.idesc.desc(),
dout_data + dout_offset * g,
args2.cdesc.desc(),
filter_algo,
filter_result.algo,
&beta,
args2.wdesc.desc(),
dfilter_data + filter_offset * g,
......@@ -355,7 +357,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
args2.odesc.desc(),
x_data + x_offset * g,
args2.cdesc.desc(),
filter_algo,
filter_result.algo,
cudnn_workspace,
workspace_size,
&beta,
......@@ -653,22 +655,17 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
dilations_,
dtype};
#ifdef PADDLE_WITH_HIP
miopenConvBwdDataAlgorithm_t bwd_algo1 =
static_cast<miopenConvBwdDataAlgorithm_t>(0);
miopenConvBwdDataAlgorithm_t bwd_algo2 =
static_cast<miopenConvBwdDataAlgorithm_t>(0);
miopenConvFwdAlgorithm_t data_algo = static_cast<miopenConvFwdAlgorithm_t>(0);
miopenConvBwdWeightsAlgorithm_t filter_algo =
static_cast<miopenConvBwdWeightsAlgorithm_t>(0);
paddle::operators::SearchResult<miopenConvBwdDataAlgorithm_t> bwd_result1;
paddle::operators::SearchResult<miopenConvBwdDataAlgorithm_t> bwd_result2;
paddle::operators::SearchResult<miopenConvBwdWeightsAlgorithm_t>
filter_result;
paddle::operators::SearchResult<miopenConvFwdAlgorithm_t> fwd_result;
#else
cudnnConvolutionBwdDataAlgo_t bwd_algo1 =
static_cast<cudnnConvolutionBwdDataAlgo_t>(0);
cudnnConvolutionBwdDataAlgo_t bwd_algo2 =
static_cast<cudnnConvolutionBwdDataAlgo_t>(0);
cudnnConvolutionFwdAlgo_t data_algo =
static_cast<cudnnConvolutionFwdAlgo_t>(0);
cudnnConvolutionBwdFilterAlgo_t filter_algo =
static_cast<cudnnConvolutionBwdFilterAlgo_t>(0);
paddle::operators::SearchResult<cudnnConvolutionBwdDataAlgo_t> bwd_result1;
paddle::operators::SearchResult<cudnnConvolutionBwdDataAlgo_t> bwd_result2;
paddle::operators::SearchResult<cudnnConvolutionBwdFilterAlgo_t>
filter_result;
paddle::operators::SearchResult<cudnnConvolutionFwdAlgo_t> fwd_result;
#endif
auto layout = paddle::platform::GetCudnnTensorFormat(GPUDNNDataLayout::kNCHW);
......@@ -696,13 +693,13 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
using search1 =
paddle::operators::SearchAlgorithm<miopenConvBwdDataAlgorithm_t>;
workspace_size = search1::GetWorkspaceSize(args1);
bwd_algo1 =
bwd_result1.algo =
search1::Find<T>(args1, false, deterministic, workspace_size, ctx);
#else
using search1 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
bwd_algo1 = search1::Find<T>(args1, false, deterministic, ctx);
workspace_size = search1::GetWorkspaceSize(args1, bwd_algo1);
bwd_result1 = search1::Find<T>(args1, false, deterministic, ctx);
workspace_size = search1::GetWorkspaceSize(args1, bwd_result1.algo);
#endif
ddfilter_ = ddfilter.data<T>();
......@@ -720,14 +717,14 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
using search2 =
paddle::operators::SearchAlgorithm<miopenConvBwdDataAlgorithm_t>;
workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2));
bwd_algo2 =
bwd_result2.algo =
search2::Find<T>(args2, false, deterministic, workspace_size, ctx);
#else
using search2 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
bwd_algo2 = search2::Find<T>(args2, false, deterministic, ctx);
workspace_size =
std::max(workspace_size, search2::GetWorkspaceSize(args2, bwd_algo2));
bwd_result2 = search2::Find<T>(args2, false, deterministic, ctx);
workspace_size = std::max(
workspace_size, search2::GetWorkspaceSize(args2, bwd_result2.algo));
#endif
}
......@@ -736,9 +733,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
args3.handle = handle;
args3.idesc.set(transformed_dout, iwo_group);
args3.wdesc.set(*dfilter, layout, iwo_group);
args3.odesc.set(transformed_ddx_channel, iwo_group);
args3.cdesc.set(dtype,
padding_common,
strides,
......@@ -749,14 +744,14 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
using search3 =
paddle::operators::SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t>;
workspace_size = std::max(workspace_size, search3::GetWorkspaceSize(args3));
filter_algo =
filter_result.algo =
search3::Find<T>(args3, false, deterministic, workspace_size, ctx);
#else
using search3 =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_algo = search3::Find<T>(args3, false, deterministic, ctx);
workspace_size =
std::max(workspace_size, search3::GetWorkspaceSize(args3, filter_algo));
filter_result = search3::Find<T>(args3, false, deterministic, ctx);
workspace_size = std::max(
workspace_size, search3::GetWorkspaceSize(args3, filter_result.algo));
#endif
}
......@@ -777,14 +772,14 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
using search4 =
paddle::operators::SearchAlgorithm<miopenConvFwdAlgorithm_t>;
workspace_size = std::max(workspace_size, search4::GetWorkspaceSize(args4));
data_algo =
fwd_result.algo =
search4::Find<T>(args4, false, deterministic, workspace_size, ctx);
#else
using search4 =
paddle::operators::SearchAlgorithm<cudnnConvolutionFwdAlgoPerf_t>;
data_algo = search4::Find<T>(args4, false, deterministic, ctx);
workspace_size =
std::max(workspace_size, search4::GetWorkspaceSize(args4, data_algo));
fwd_result = search4::Find<T>(args4, false, deterministic, ctx);
workspace_size = std::max(
workspace_size, search4::GetWorkspaceSize(args4, fwd_result.algo));
#endif
}
......@@ -831,7 +826,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
args1.wdesc.desc(),
filter_ + i * group_offset_filter,
args1.cdesc.desc(),
bwd_algo1,
bwd_result1.algo,
&beta,
args1.idesc.desc(),
transformed_ddout_channel_ + i * group_offset_out,
......@@ -850,7 +845,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
args1.odesc.desc(),
ddx_ + i * group_offset_in,
args1.cdesc.desc(),
bwd_algo1,
bwd_result1.algo,
workspace_ptr,
workspace_size,
&beta,
......@@ -877,7 +872,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
args2.wdesc.desc(),
ddfilter_ + i * group_offset_filter,
args2.cdesc.desc(),
bwd_algo2,
bwd_result2.algo,
&beta,
args2.idesc.desc(),
conv_x_ddfilter_data + i * group_offset_out,
......@@ -908,7 +903,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
args2.odesc.desc(),
x_ + i * group_offset_in,
args2.cdesc.desc(),
bwd_algo2,
bwd_result2.algo,
workspace_ptr,
workspace_size,
&alpha,
......@@ -964,7 +959,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
args3.idesc.desc(),
transformed_dout_channel_ + i * group_offset_out,
args3.cdesc.desc(),
filter_algo,
filter_result.algo,
&beta,
args3.wdesc.desc(),
dfilter_ + i * group_offset_filter,
......@@ -983,7 +978,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
args3.odesc.desc(),
ddx_ + i * group_offset_in,
args3.cdesc.desc(),
filter_algo,
filter_result.algo,
workspace_ptr,
workspace_size,
&beta,
......@@ -1009,7 +1004,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
args4.wdesc.desc(),
ddfilter_ + i * group_offset_filter,
args4.cdesc.desc(),
data_algo,
fwd_result.algo,
&beta,
args4.odesc.desc(),
transformed_dx_ + i * group_offset_in,
......@@ -1028,7 +1023,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
args4.wdesc.desc(),
ddfilter_ + i * group_offset_filter,
args4.cdesc.desc(),
data_algo,
fwd_result.algo,
workspace_ptr,
workspace_size,
&beta,
......
......@@ -217,16 +217,19 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx,
c_groups);
#ifdef PADDLE_WITH_HIP
paddle::operators::SearchResult<miopenConvBwdDataAlgorithm_t> bwd_result;
using search =
paddle::operators::SearchAlgorithm<miopenConvBwdDataAlgorithm_t>;
workspace_size = std::max(workspace_size, search::GetWorkspaceSize(args));
algo = search::Find<T>(args, false, deterministic, workspace_size, ctx);
bwd_result.algo =
search::Find<T>(args, false, deterministic, workspace_size, ctx);
#else
paddle::operators::SearchResult<cudnnConvolutionBwdDataAlgo_t> bwd_result;
using search =
paddle::operators::SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
algo = search::Find<T>(args, false, deterministic, ctx);
bwd_result = search::Find<T>(args, false, deterministic, ctx);
workspace_size =
std::max(workspace_size, search::GetWorkspaceSize(args, algo));
std::max(workspace_size, search::GetWorkspaceSize(args, bwd_result.algo));
#endif
// ------------------- cudnn conv transpose forward ---------------------
......@@ -247,7 +250,7 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx,
args.wdesc.desc(),
filter_data + filter_offset * g,
args.cdesc.desc(),
algo,
bwd_result.algo,
&beta,
args.idesc.desc(),
transformed_out_data + out_offset * g,
......@@ -264,7 +267,7 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx,
args.odesc.desc(),
x_data + x_offset * g,
args.cdesc.desc(),
algo,
bwd_result.algo,
cudnn_workspace,
workspace_size,
&beta,
......
......@@ -36,7 +36,7 @@
#include "paddle/phi/kernels/funcs/batch_norm_utils.h"
DECLARE_bool(cudnn_deterministic);
DECLARE_uint64(conv_workspace_size_limit);
DECLARE_int64(conv_workspace_size_limit);
DECLARE_bool(cudnn_exhaustive_search);
namespace phi {
......
......@@ -43,6 +43,16 @@ def static_program(net, data):
return loss
def set_flags(enable_autotune):
if paddle.is_compiled_with_cuda():
if enable_autotune:
paddle.set_flags({'FLAGS_conv_workspace_size_limit': -1})
paddle.set_flags({'FLAGS_cudnn_exhaustive_search': 1})
else:
paddle.set_flags({'FLAGS_conv_workspace_size_limit': 512})
paddle.set_flags({'FLAGS_cudnn_exhaustive_search': 0})
class TestAutoTune(unittest.TestCase):
def test_autotune(self):
paddle.fluid.core.disable_autotune()
......@@ -61,6 +71,7 @@ class TestAutoTune(unittest.TestCase):
class TestDygraphAutoTuneStatus(TestAutoTune):
def run_program(self, enable_autotune):
set_flags(enable_autotune)
if enable_autotune:
paddle.fluid.core.enable_autotune()
else:
......@@ -107,6 +118,7 @@ class TestDygraphAutoTuneStatus(TestAutoTune):
class TestStaticAutoTuneStatus(TestAutoTune):
def run_program(self, enable_autotune):
paddle.enable_static()
set_flags(enable_autotune)
if enable_autotune:
paddle.fluid.core.enable_autotune()
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册