未验证 提交 bc47e7ac 编写于 作者: Y Yiqun Liu 提交者: GitHub

Enhance the implementation of some conv functions. (#47281)

上级 2f3ad5ab
...@@ -75,9 +75,9 @@ void ChooseAlgoByWorkspace(const std::vector<PerfT>& perf_results, ...@@ -75,9 +75,9 @@ void ChooseAlgoByWorkspace(const std::vector<PerfT>& perf_results,
SearchResult<AlgoT>* search_result) { SearchResult<AlgoT>* search_result) {
int best_algo_idx = -1; int best_algo_idx = -1;
for (size_t i = 0; i < perf_results.size(); ++i) { for (size_t i = 0; i < perf_results.size(); ++i) {
auto result = perf_results[i]; const auto& result = perf_results[i];
if (result.status == CUDNN_STATUS_SUCCESS && if (result.status == CUDNN_STATUS_SUCCESS &&
result.memory < workspace_limit) { result.memory <= workspace_limit) {
if (best_algo_idx == -1) { if (best_algo_idx == -1) {
// The algorithm which has minimize time cost and need a workspace_size // The algorithm which has minimize time cost and need a workspace_size
// fitting the workspace_limit constraint. // fitting the workspace_limit constraint.
...@@ -87,8 +87,10 @@ void ChooseAlgoByWorkspace(const std::vector<PerfT>& perf_results, ...@@ -87,8 +87,10 @@ void ChooseAlgoByWorkspace(const std::vector<PerfT>& perf_results,
break; break;
} }
} else { } else {
float best_algo_time = perf_results[best_algo_idx].time; // Compared to the next suboptimal algorithm, if the best one only has
if ((result.time - best_algo_time) / best_algo_time < 0.01) { // 1% performance difference, we'd like to pick the one which need less
// memory.
if (result.time < 1.01 * perf_results[best_algo_idx].time) {
best_algo_idx = (result.memory < perf_results[best_algo_idx].memory) best_algo_idx = (result.memory < perf_results[best_algo_idx].memory)
? i ? i
: best_algo_idx; : best_algo_idx;
...@@ -98,9 +100,15 @@ void ChooseAlgoByWorkspace(const std::vector<PerfT>& perf_results, ...@@ -98,9 +100,15 @@ void ChooseAlgoByWorkspace(const std::vector<PerfT>& perf_results,
} }
} }
if (best_algo_idx != -1) { if (best_algo_idx != -1) {
search_result->algo = perf_results[best_algo_idx].algo; const auto& result = perf_results[best_algo_idx];
search_result->time = perf_results[best_algo_idx].time; search_result->algo = result.algo;
search_result->workspace_size = perf_results[best_algo_idx].memory; search_result->time = result.time;
search_result->workspace_size = result.memory;
auto math_type_str = (result.mathType == CUDNN_TENSOR_OP_MATH) ? "T" : "F";
VLOG(3) << "Choose algo=" << result.algo
<< ", tensor_core=" << math_type_str << ", time=" << result.time
<< " ms, memory=" << ToMegaBytes(result.memory)
<< " MB, status=" << result.status;
} else { } else {
VLOG(3) << "Can not find an algorithm that requires memory < " VLOG(3) << "Can not find an algorithm that requires memory < "
<< ToMegaBytes(workspace_limit) << " MB"; << ToMegaBytes(workspace_limit) << " MB";
...@@ -626,7 +634,8 @@ struct SearchAlgorithmBase<cudnnConvolutionBwdFilterAlgoPerf_t> { ...@@ -626,7 +634,8 @@ struct SearchAlgorithmBase<cudnnConvolutionBwdFilterAlgoPerf_t> {
perf_results, perf_results,
perf_results.size(), perf_results.size(),
workspace_size_limit); workspace_size_limit);
ChooseAlgo(perf_results, workspace_size_limit, &result); ChooseAlgoByWorkspace<PerfT, AlgoT>(
perf_results, workspace_size_limit, &result);
} }
result.workspace_size = GetWorkspaceSize(args, result.algo); result.workspace_size = GetWorkspaceSize(args, result.algo);
...@@ -673,42 +682,6 @@ struct SearchAlgorithmBase<cudnnConvolutionBwdFilterAlgoPerf_t> { ...@@ -673,42 +682,6 @@ struct SearchAlgorithmBase<cudnnConvolutionBwdFilterAlgoPerf_t> {
return workspace_size_limit; return workspace_size_limit;
} }
} }
// Scans perf_results front to back and records, in algo_result, the algo and
// time of the first entry that both succeeded and needs no more than
// workspace_limit bytes of workspace.
//
// A Tensor Core entry is passed over when the entry right after it is the same
// algorithm with the same memory footprint, runs without Tensor Core math and
// is less than 1% slower; the scan then moves on to that next entry.
//
// If no entry qualifies, algo_result is left unmodified.
// NOTE(review): presumably perf_results is ordered by ascending time, as
// returned by the cuDNN find routines -- confirm against the caller.
static void ChooseAlgo(const std::vector<PerfT>& perf_results,
                       size_t workspace_limit,
                       SearchResult<AlgoT>* algo_result) {
  const size_t num_results = perf_results.size();
  for (size_t idx = 0; idx < num_results; ++idx) {
    const auto& candidate = perf_results[idx];

    // Ignore failed runs and runs that would exceed the workspace budget.
    const bool usable = candidate.status == CUDNN_STATUS_SUCCESS &&
                        (candidate.memory <= workspace_limit);
    if (!usable) {
      continue;
    }

    const bool uses_tensor_core =
        (candidate.mathType == CUDNN_TENSOR_OP_MATH);
    if (uses_tensor_core && (idx + 1 != num_results)) {
      const auto& follower = perf_results[idx + 1];
      // The Tensor Core variant brings no real benefit when its non-Tensor
      // Core twin is within 1% of its time, so prefer the twin.
      const bool follower_is_equivalent =
          follower.status == CUDNN_STATUS_SUCCESS &&
          follower.algo == candidate.algo &&
          follower.memory == candidate.memory &&
          follower.mathType != CUDNN_TENSOR_OP_MATH &&
          follower.time < 1.01 * candidate.time;
      if (follower_is_equivalent) {
        continue;
      }
    }

    algo_result->algo = candidate.algo;
    algo_result->time = candidate.time;
    const char* math_type_str = uses_tensor_core ? "1" : "0";
    VLOG(3) << " choose algo: " << candidate.algo
            << ", TC: " << math_type_str << ", time: " << candidate.time
            << " ms, wksp = " << candidate.memory
            << ", status = " << candidate.status;
    // The first acceptable entry wins; nothing else to examine.
    return;
  }
}
}; };
template <typename PerfT> template <typename PerfT>
...@@ -735,7 +708,7 @@ struct SearchAlgorithm : public SearchAlgorithmBase<PerfT> { ...@@ -735,7 +708,7 @@ struct SearchAlgorithm : public SearchAlgorithmBase<PerfT> {
// Auto tune is only enabled between specified range. // Auto tune is only enabled between specified range.
// 3. After auto-tune process, run cached algorithm if cached, run // 3. After auto-tune process, run cached algorithm if cached, run
// default mode for the rest. // default mode for the rest.
auto key = args.Convert2ConvCacheKey<T>(); auto key = args.ConvertToConvCacheKey<T>();
auto& cache = phi::autotune::AutoTuneCache::Instance().GetConv( auto& cache = phi::autotune::AutoTuneCache::Instance().GetConv(
SearchAlgorithmBase<PerfT>::kAlgoType); SearchAlgorithmBase<PerfT>::kAlgoType);
bool find_in_cache = cache.Find(key); bool find_in_cache = cache.Find(key);
...@@ -746,7 +719,6 @@ struct SearchAlgorithm : public SearchAlgorithmBase<PerfT> { ...@@ -746,7 +719,6 @@ struct SearchAlgorithm : public SearchAlgorithmBase<PerfT> {
result.exhaustive_search = t.exhaustive_search; result.exhaustive_search = t.exhaustive_search;
} }
if (!result.exhaustive_search) { if (!result.exhaustive_search) {
bool need_update_cache = false;
// In conv2d_transpose, enable_autotune is set to false because some // In conv2d_transpose, enable_autotune is set to false because some
// algorithms picked by the exhaustive search method produce wrong results. // algorithms picked by the exhaustive search method produce wrong results.
use_autotune = enable_autotune && use_autotune = enable_autotune &&
...@@ -757,17 +729,18 @@ struct SearchAlgorithm : public SearchAlgorithmBase<PerfT> { ...@@ -757,17 +729,18 @@ struct SearchAlgorithm : public SearchAlgorithmBase<PerfT> {
result = result =
SearchAlgorithmBase<PerfT>::template FindAlgoExhaustiveSearch<T>( SearchAlgorithmBase<PerfT>::template FindAlgoExhaustiveSearch<T>(
args, ctx); args, ctx);
need_update_cache = true; cache.Set(key,
phi::autotune::ConvAutoTuneResult(
static_cast<int64_t>(result.algo),
result.workspace_size,
true));
} else if (!find_in_cache) { } else if (!find_in_cache) {
result = SearchAlgorithmBase<PerfT>::FindAlgoHeuristic(args, ctx); result = SearchAlgorithmBase<PerfT>::FindAlgoHeuristic(args, ctx);
need_update_cache = true; cache.Set(key,
} phi::autotune::ConvAutoTuneResult(
if (need_update_cache) { static_cast<int64_t>(result.algo),
phi::autotune::ConvAutoTuneResult node( result.workspace_size,
static_cast<int64_t>(result.algo), false));
result.workspace_size,
exhaustive_search || use_autotune);
cache.Set(key, node);
} }
} }
} }
......
...@@ -69,10 +69,15 @@ static std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) { ...@@ -69,10 +69,15 @@ static std::ostream& operator<<(std::ostream& out, const std::vector<T>& v) {
template <typename HandleT, typename DataT> template <typename HandleT, typename DataT>
struct ConvArgsBase { struct ConvArgsBase {
HandleT handle; HandleT handle;
paddle::platform::TensorDescriptor idesc, odesc; paddle::platform::TensorDescriptor idesc;
paddle::platform::TensorDescriptor odesc;
paddle::platform::FilterDescriptor wdesc; paddle::platform::FilterDescriptor wdesc;
paddle::platform::ConvolutionDescriptor cdesc; paddle::platform::ConvolutionDescriptor cdesc;
const phi::DenseTensor *x, *w, *o;
const phi::DenseTensor* x = nullptr;
const phi::DenseTensor* w = nullptr;
const phi::DenseTensor* o = nullptr;
DataT cudnn_dtype; DataT cudnn_dtype;
// strides // strides
...@@ -88,7 +93,8 @@ struct ConvArgsBase { ...@@ -88,7 +93,8 @@ struct ConvArgsBase {
// data format // data format
GPUDNNDataLayout data_layout; GPUDNNDataLayout data_layout;
ConvArgsBase(const phi::DenseTensor* x, ConvArgsBase(const HandleT& h,
const phi::DenseTensor* x,
const phi::DenseTensor* w, const phi::DenseTensor* w,
const phi::DenseTensor* o, const phi::DenseTensor* o,
const std::vector<int> s, const std::vector<int> s,
...@@ -97,7 +103,8 @@ struct ConvArgsBase { ...@@ -97,7 +103,8 @@ struct ConvArgsBase {
DataT dtype, DataT dtype,
int g, int g,
GPUDNNDataLayout layout) GPUDNNDataLayout layout)
: x(x), : handle(h),
x(x),
w(w), w(w),
o(o), o(o),
s(s), s(s),
...@@ -108,7 +115,7 @@ struct ConvArgsBase { ...@@ -108,7 +115,7 @@ struct ConvArgsBase {
data_layout(layout) {} data_layout(layout) {}
template <typename T> template <typename T>
phi::autotune::ConvCacheKey Convert2ConvCacheKey() const { phi::autotune::ConvCacheKey ConvertToConvCacheKey() const {
auto x_shape = phi::vectorize(x->dims()); auto x_shape = phi::vectorize(x->dims());
auto w_shape = phi::vectorize(w->dims()); auto w_shape = phi::vectorize(w->dims());
VLOG(10) << "[ConvArgs] x_dims=" << x_shape << ", w_dims=" << w_shape VLOG(10) << "[ConvArgs] x_dims=" << x_shape << ", w_dims=" << w_shape
......
...@@ -257,7 +257,8 @@ void ConvCudnnGradGradKernel( ...@@ -257,7 +257,8 @@ void ConvCudnnGradGradKernel(
auto layout = paddle::platform::GetCudnnTensorFormat( auto layout = paddle::platform::GetCudnnTensorFormat(
paddle::platform::DataLayout::kNCHW); paddle::platform::DataLayout::kNCHW);
ConvArgs args1{&transformed_ddX, ConvArgs args1{handle,
&transformed_ddX,
W, W,
&transformed_ddO_channel, &transformed_ddO_channel,
strides, strides,
...@@ -266,7 +267,8 @@ void ConvCudnnGradGradKernel( ...@@ -266,7 +267,8 @@ void ConvCudnnGradGradKernel(
dtype, dtype,
groups, groups,
paddle::platform::DataLayout::kNCHW}; paddle::platform::DataLayout::kNCHW};
ConvArgs args2{&transformed_X, ConvArgs args2{handle,
&transformed_X,
ddW, ddW,
&transformed_ddO_channel, &transformed_ddO_channel,
strides, strides,
...@@ -275,7 +277,8 @@ void ConvCudnnGradGradKernel( ...@@ -275,7 +277,8 @@ void ConvCudnnGradGradKernel(
dtype, dtype,
groups, groups,
paddle::platform::DataLayout::kNCHW}; paddle::platform::DataLayout::kNCHW};
ConvArgs args3{&transformed_ddX, ConvArgs args3{handle,
&transformed_ddX,
dW, dW,
&transformed_dO_channel, &transformed_dO_channel,
strides, strides,
...@@ -284,7 +287,8 @@ void ConvCudnnGradGradKernel( ...@@ -284,7 +287,8 @@ void ConvCudnnGradGradKernel(
dtype, dtype,
groups, groups,
paddle::platform::DataLayout::kNCHW}; paddle::platform::DataLayout::kNCHW};
ConvArgs args4{&transformed_dX, ConvArgs args4{handle,
&transformed_dX,
ddW, ddW,
&transformed_dO_channel, &transformed_dO_channel,
strides, strides,
...@@ -314,7 +318,6 @@ void ConvCudnnGradGradKernel( ...@@ -314,7 +318,6 @@ void ConvCudnnGradGradKernel(
ddy = ddO->data<T>(); ddy = ddO->data<T>();
transformed_ddy_channel = transformed_ddO_channel.data<T>(); transformed_ddy_channel = transformed_ddO_channel.data<T>();
if (ddX) { if (ddX) {
args1.handle = handle;
args1.idesc.set(transformed_ddX, iwo_group); args1.idesc.set(transformed_ddX, iwo_group);
args1.wdesc.set(*W, layout, iwo_group); args1.wdesc.set(*W, layout, iwo_group);
args1.odesc.set(transformed_ddO_channel, iwo_group); args1.odesc.set(transformed_ddO_channel, iwo_group);
...@@ -339,7 +342,6 @@ void ConvCudnnGradGradKernel( ...@@ -339,7 +342,6 @@ void ConvCudnnGradGradKernel(
if (ddW) { if (ddW) {
ddw = ddW->data<T>(); ddw = ddW->data<T>();
args2.handle = handle;
args2.idesc.set(transformed_X, iwo_group); args2.idesc.set(transformed_X, iwo_group);
args2.wdesc.set(*ddW, layout, iwo_group); args2.wdesc.set(*ddW, layout, iwo_group);
args2.odesc.set(transformed_ddO_channel, iwo_group); args2.odesc.set(transformed_ddO_channel, iwo_group);
...@@ -367,7 +369,6 @@ void ConvCudnnGradGradKernel( ...@@ -367,7 +369,6 @@ void ConvCudnnGradGradKernel(
if (dW && ddX) { if (dW && ddX) {
dw = dW->data<T>(); dw = dW->data<T>();
args3.handle = handle;
args3.idesc.set(transformed_ddX, iwo_group); args3.idesc.set(transformed_ddX, iwo_group);
args3.wdesc.set(*dW, layout, iwo_group); args3.wdesc.set(*dW, layout, iwo_group);
args3.odesc.set(transformed_dO_channel, iwo_group); args3.odesc.set(transformed_dO_channel, iwo_group);
...@@ -395,7 +396,6 @@ void ConvCudnnGradGradKernel( ...@@ -395,7 +396,6 @@ void ConvCudnnGradGradKernel(
if (ddW && dX) { if (ddW && dX) {
transformed_dx = transformed_dX.data<T>(); transformed_dx = transformed_dX.data<T>();
args4.handle = handle;
args4.idesc.set(transformed_dX, iwo_group); args4.idesc.set(transformed_dX, iwo_group);
args4.wdesc.set(*ddW, layout, iwo_group); args4.wdesc.set(*ddW, layout, iwo_group);
args4.odesc.set(transformed_dO_channel, iwo_group); args4.odesc.set(transformed_dO_channel, iwo_group);
...@@ -444,13 +444,13 @@ void ConvCudnnGradGradKernel( ...@@ -444,13 +444,13 @@ void ConvCudnnGradGradKernel(
// ScalingParamType<T> beta = ctx.Attr<bool>("use_addto") ? 1.0f : // ScalingParamType<T> beta = ctx.Attr<bool>("use_addto") ? 1.0f :
// 0.0f; // 0.0f;
// VLOG(4) << "Conv_grad_grad: use_addto = " << ctx.Attr<bool>("use_addto"); // VLOG(4) << "Conv_grad_grad: use_addto = " << ctx.Attr<bool>("use_addto");
auto wkspace_handle = ctx.cudnn_workspace_handle(); auto workspace_handle = ctx.cudnn_workspace_handle();
if (ddO) { if (ddO) {
if (ddX) { if (ddX) {
ddx = transformed_ddX.data<T>(); ddx = transformed_ddX.data<T>();
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
wkspace_handle.RunFunc( workspace_handle.RunFunc(
[&](void* workspace_ptr) { [&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenConvolutionForward( paddle::platform::dynload::miopenConvolutionForward(
...@@ -471,7 +471,7 @@ void ConvCudnnGradGradKernel( ...@@ -471,7 +471,7 @@ void ConvCudnnGradGradKernel(
workspace_size); workspace_size);
#else #else
for (int i = 0; i < groups; i++) { for (int i = 0; i < groups; i++) {
wkspace_handle.RunFunc( workspace_handle.RunFunc(
[&](void* workspace_ptr) { [&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnConvolutionForward( paddle::platform::dynload::cudnnConvolutionForward(
...@@ -496,7 +496,7 @@ void ConvCudnnGradGradKernel( ...@@ -496,7 +496,7 @@ void ConvCudnnGradGradKernel(
if (ddW) { if (ddW) {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
// MIOPEN ONLY support beta to be 0.0f // MIOPEN ONLY support beta to be 0.0f
wkspace_handle.RunFunc( workspace_handle.RunFunc(
[&](void* workspace_ptr) { [&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenConvolutionForward( paddle::platform::dynload::miopenConvolutionForward(
...@@ -517,7 +517,7 @@ void ConvCudnnGradGradKernel( ...@@ -517,7 +517,7 @@ void ConvCudnnGradGradKernel(
workspace_size); workspace_size);
#else #else
for (int i = 0; i < groups; i++) { for (int i = 0; i < groups; i++) {
wkspace_handle.RunFunc( workspace_handle.RunFunc(
[&](void* workspace_ptr) { [&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnConvolutionForward( paddle::platform::dynload::cudnnConvolutionForward(
...@@ -547,7 +547,7 @@ void ConvCudnnGradGradKernel( ...@@ -547,7 +547,7 @@ void ConvCudnnGradGradKernel(
if (dW && ddX) { if (dW && ddX) {
ddx = transformed_ddX.data<T>(); ddx = transformed_ddX.data<T>();
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
wkspace_handle.RunFunc( workspace_handle.RunFunc(
[&](void* workspace_ptr) { [&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenConvolutionBackwardWeights( paddle::platform::dynload::miopenConvolutionBackwardWeights(
...@@ -568,7 +568,7 @@ void ConvCudnnGradGradKernel( ...@@ -568,7 +568,7 @@ void ConvCudnnGradGradKernel(
workspace_size); workspace_size);
#else #else
for (int i = 0; i < groups; i++) { for (int i = 0; i < groups; i++) {
wkspace_handle.RunFunc( workspace_handle.RunFunc(
[&](void* workspace_ptr) { [&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnConvolutionBackwardFilter( paddle::platform::dynload::cudnnConvolutionBackwardFilter(
...@@ -594,7 +594,7 @@ void ConvCudnnGradGradKernel( ...@@ -594,7 +594,7 @@ void ConvCudnnGradGradKernel(
if (dX && ddW) { if (dX && ddW) {
ddw = ddW->data<T>(); ddw = ddW->data<T>();
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
wkspace_handle.RunFunc( workspace_handle.RunFunc(
[&](void* workspace_ptr) { [&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::miopenConvolutionBackwardData( paddle::platform::dynload::miopenConvolutionBackwardData(
...@@ -615,7 +615,7 @@ void ConvCudnnGradGradKernel( ...@@ -615,7 +615,7 @@ void ConvCudnnGradGradKernel(
workspace_size); workspace_size);
#else #else
for (int i = 0; i < groups; i++) { for (int i = 0; i < groups; i++) {
wkspace_handle.RunFunc( workspace_handle.RunFunc(
[&](void* workspace_ptr) { [&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
paddle::platform::dynload::cudnnConvolutionBackwardData( paddle::platform::dynload::cudnnConvolutionBackwardData(
......
...@@ -251,12 +251,14 @@ void ConvCudnnGradKernel(const Context& ctx, ...@@ -251,12 +251,14 @@ void ConvCudnnGradKernel(const Context& ctx,
T* input_grad_data = nullptr; T* input_grad_data = nullptr;
T* transformed_input_grad_data = nullptr; T* transformed_input_grad_data = nullptr;
auto handle = ctx.cudnn_handle();
paddle::platform::DataLayout layout = paddle::platform::DataLayout layout =
compute_format == paddle::platform::DataLayout::kNHWC compute_format == paddle::platform::DataLayout::kNHWC
? paddle::platform::DataLayout::kNHWC ? paddle::platform::DataLayout::kNHWC
: paddle::platform::DataLayout::kNCHW; : paddle::platform::DataLayout::kNCHW;
ConvArgs args1{&transformed_input_grad, ConvArgs args1{handle,
&transformed_input_grad,
&transformed_filter_channel, &transformed_filter_channel,
&transformed_output_grad_channel, &transformed_output_grad_channel,
strides, strides,
...@@ -265,7 +267,8 @@ void ConvCudnnGradKernel(const Context& ctx, ...@@ -265,7 +267,8 @@ void ConvCudnnGradKernel(const Context& ctx,
dtype, dtype,
groups, groups,
layout}; layout};
ConvArgs args2{&transformed_input, ConvArgs args2{handle,
&transformed_input,
&transformed_filter_grad_channel, &transformed_filter_grad_channel,
&transformed_output_grad_channel, &transformed_output_grad_channel,
strides, strides,
...@@ -275,7 +278,6 @@ void ConvCudnnGradKernel(const Context& ctx, ...@@ -275,7 +278,6 @@ void ConvCudnnGradKernel(const Context& ctx,
groups, groups,
layout}; layout};
auto handle = ctx.cudnn_handle();
// TODO(phlrain): replace paddle::platform::DataLayout with phi::DataLayout // TODO(phlrain): replace paddle::platform::DataLayout with phi::DataLayout
if (transformed_input.dims().size() == 5) { if (transformed_input.dims().size() == 5) {
...@@ -332,10 +334,7 @@ void ConvCudnnGradKernel(const Context& ctx, ...@@ -332,10 +334,7 @@ void ConvCudnnGradKernel(const Context& ctx,
SearchResult<cudnnConvolutionBwdDataAlgo_t> bwd_result; SearchResult<cudnnConvolutionBwdDataAlgo_t> bwd_result;
SearchResult<cudnnConvolutionBwdFilterAlgo_t> filter_result; SearchResult<cudnnConvolutionBwdFilterAlgo_t> filter_result;
#endif #endif
// input data workspace_size size_t workspace_size = 0;
size_t workspace_size_d = 0;
// weight workspace_size
size_t workspace_size_w = 0;
int iwo_groups = groups; int iwo_groups = groups;
int c_groups = 1; int c_groups = 1;
...@@ -350,7 +349,6 @@ void ConvCudnnGradKernel(const Context& ctx, ...@@ -350,7 +349,6 @@ void ConvCudnnGradKernel(const Context& ctx,
input_grad_data = input_grad->data<T>(); input_grad_data = input_grad->data<T>();
transformed_input_grad_data = transformed_input_grad.data<T>(); transformed_input_grad_data = transformed_input_grad.data<T>();
args1.handle = handle;
args1.idesc.set(transformed_input_grad, layout_tensor); args1.idesc.set(transformed_input_grad, layout_tensor);
args1.wdesc.set(transformed_filter_channel, layout_tensor, iwo_groups); args1.wdesc.set(transformed_filter_channel, layout_tensor, iwo_groups);
args1.odesc.set(transformed_output_grad_channel, layout_tensor); args1.odesc.set(transformed_output_grad_channel, layout_tensor);
...@@ -363,21 +361,20 @@ void ConvCudnnGradKernel(const Context& ctx, ...@@ -363,21 +361,20 @@ void ConvCudnnGradKernel(const Context& ctx,
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
using search1 = SearchAlgorithm<miopenConvBwdDataAlgorithm_t>; using search1 = SearchAlgorithm<miopenConvBwdDataAlgorithm_t>;
workspace_size_d = workspace_size = std::max(workspace_size, search1::GetWorkspaceSize(args1));
std::max(workspace_size_d, search1::GetWorkspaceSize(args1));
bwd_result.algo = search1::Find<T>( bwd_result.algo = search1::Find<T>(
args1, exhaustive_search, deterministic, workspace_size_d, ctx); args1, exhaustive_search, deterministic, workspace_size, ctx);
#else #else
using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>; using search1 = SearchAlgorithm<cudnnConvolutionBwdDataAlgoPerf_t>;
bwd_result = search1::Find<T>(ctx, args1, exhaustive_search, deterministic); bwd_result = search1::Find<T>(ctx, args1, exhaustive_search, deterministic);
workspace_size_d = std::max(workspace_size_d, bwd_result.workspace_size); workspace_size = std::max(workspace_size, bwd_result.workspace_size);
#endif #endif
} }
if (filter_grad) { if (filter_grad) {
// ------------------- cudnn descriptors --------------------- // ------------------- cudnn descriptors ---------------------
filter_grad_data = transformed_filter_grad_channel.data<T>(); filter_grad_data = transformed_filter_grad_channel.data<T>();
args2.handle = handle;
args2.idesc.set(transformed_input, layout_tensor); args2.idesc.set(transformed_input, layout_tensor);
args2.wdesc.set(transformed_filter_grad_channel, layout_tensor, iwo_groups); args2.wdesc.set(transformed_filter_grad_channel, layout_tensor, iwo_groups);
args2.odesc.set(transformed_output_grad_channel, layout_tensor); args2.odesc.set(transformed_output_grad_channel, layout_tensor);
...@@ -389,17 +386,16 @@ void ConvCudnnGradKernel(const Context& ctx, ...@@ -389,17 +386,16 @@ void ConvCudnnGradKernel(const Context& ctx,
c_groups); c_groups);
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
using search2 = SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t>; using search2 = SearchAlgorithm<miopenConvBwdWeightsAlgorithm_t>;
workspace_size_w = workspace_size = std::max(workspace_size, search2::GetWorkspaceSize(args2));
std::max(workspace_size_w, search2::GetWorkspaceSize(args2));
filter_result.algo = search2::Find<T>( filter_result.algo = search2::Find<T>(
args2, exhaustive_search, deterministic, workspace_size_w, ctx); args2, exhaustive_search, deterministic, workspace_size, ctx);
#else #else
using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>; using search2 = SearchAlgorithm<cudnnConvolutionBwdFilterAlgoPerf_t>;
filter_result = filter_result =
search2::Find<T>(ctx, args2, exhaustive_search, deterministic); search2::Find<T>(ctx, args2, exhaustive_search, deterministic);
VLOG(3) << "filter algo: " << filter_result.algo << ", time " VLOG(3) << "filter algo: " << filter_result.algo << ", time "
<< filter_result.time; << filter_result.time;
workspace_size_w = std::max(workspace_size_w, filter_result.workspace_size); workspace_size = std::max(workspace_size, filter_result.workspace_size);
#endif #endif
} }
...@@ -438,9 +434,9 @@ void ConvCudnnGradKernel(const Context& ctx, ...@@ -438,9 +434,9 @@ void ConvCudnnGradKernel(const Context& ctx,
args1.idesc.desc(), args1.idesc.desc(),
temp_tensor_data, temp_tensor_data,
cudnn_workspace_ptr, cudnn_workspace_ptr,
workspace_size_d)); workspace_size));
}, },
workspace_size_d); workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::miopenOpTensor( PADDLE_ENFORCE_GPU_SUCCESS(paddle::platform::dynload::miopenOpTensor(
handle, handle,
miopenTensorOpAdd, miopenTensorOpAdd,
...@@ -470,9 +466,9 @@ void ConvCudnnGradKernel(const Context& ctx, ...@@ -470,9 +466,9 @@ void ConvCudnnGradKernel(const Context& ctx,
args1.idesc.desc(), args1.idesc.desc(),
transformed_input_grad_data, transformed_input_grad_data,
cudnn_workspace_ptr, cudnn_workspace_ptr,
workspace_size_d)); workspace_size));
}, },
workspace_size_d); workspace_size);
} }
#else #else
...@@ -490,12 +486,12 @@ void ConvCudnnGradKernel(const Context& ctx, ...@@ -490,12 +486,12 @@ void ConvCudnnGradKernel(const Context& ctx,
args1.cdesc.desc(), args1.cdesc.desc(),
bwd_result.algo, bwd_result.algo,
cudnn_workspace_ptr, cudnn_workspace_ptr,
workspace_size_d, workspace_size,
&beta, &beta,
args1.idesc.desc(), args1.idesc.desc(),
transformed_input_grad_data + i * group_offset_in)); transformed_input_grad_data + i * group_offset_in));
}, },
workspace_size_d); workspace_size);
} }
#endif #endif
if (!is_sys_pad) { if (!is_sys_pad) {
...@@ -551,9 +547,9 @@ void ConvCudnnGradKernel(const Context& ctx, ...@@ -551,9 +547,9 @@ void ConvCudnnGradKernel(const Context& ctx,
args2.wdesc.desc(), args2.wdesc.desc(),
filter_grad_data, filter_grad_data,
cudnn_workspace_ptr, cudnn_workspace_ptr,
workspace_size_w)); workspace_size));
}, },
workspace_size_w); workspace_size);
#else #else
for (int i = 0; i < groups; i++) { for (int i = 0; i < groups; i++) {
workspace_handle.RunFunc( workspace_handle.RunFunc(
...@@ -569,12 +565,12 @@ void ConvCudnnGradKernel(const Context& ctx, ...@@ -569,12 +565,12 @@ void ConvCudnnGradKernel(const Context& ctx,
args2.cdesc.desc(), args2.cdesc.desc(),
filter_result.algo, filter_result.algo,
cudnn_workspace_ptr, cudnn_workspace_ptr,
workspace_size_w, workspace_size,
&beta_filter, &beta_filter,
args2.wdesc.desc(), args2.wdesc.desc(),
filter_grad_data + i * group_offset_filter)); filter_grad_data + i * group_offset_filter));
}, },
workspace_size_w); workspace_size);
} }
#endif #endif
......
...@@ -201,11 +201,14 @@ void ConvCudnnKernel(const Context& ctx, ...@@ -201,11 +201,14 @@ void ConvCudnnKernel(const Context& ctx,
} }
const T* input_data = transformed_input.data<T>(); const T* input_data = transformed_input.data<T>();
const T* filter_data = transformed_filter_channel.data<T>(); const T* filter_data = transformed_filter_channel.data<T>();
auto handle = ctx.cudnn_handle();
auto workspace_handle = ctx.cudnn_workspace_handle();
// ------------------- cudnn descriptors --------------------- // ------------------- cudnn descriptors ---------------------
ConvArgs args{&transformed_input, ConvArgs args{handle,
&transformed_input,
&transformed_filter_channel, &transformed_filter_channel,
&transformed_output, &transformed_output,
strides, strides,
...@@ -215,8 +218,6 @@ void ConvCudnnKernel(const Context& ctx, ...@@ -215,8 +218,6 @@ void ConvCudnnKernel(const Context& ctx,
groups, groups,
compute_format}; compute_format};
auto handle = ctx.cudnn_handle();
auto workspace_handle = ctx.cudnn_workspace_handle();
paddle::platform::DataLayout layout = paddle::platform::DataLayout layout =
compute_format == paddle::platform::DataLayout::kNHWC compute_format == paddle::platform::DataLayout::kNHWC
? paddle::platform::DataLayout::kNHWC ? paddle::platform::DataLayout::kNHWC
...@@ -228,8 +229,6 @@ void ConvCudnnKernel(const Context& ctx, ...@@ -228,8 +229,6 @@ void ConvCudnnKernel(const Context& ctx,
} }
auto layout_format = paddle::platform::GetCudnnTensorFormat(layout); auto layout_format = paddle::platform::GetCudnnTensorFormat(layout);
args.handle = handle;
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
// MIOPEN need to set groups in cdesc in miopen_desc.h // MIOPEN need to set groups in cdesc in miopen_desc.h
args.cdesc.set(dtype, args.cdesc.set(dtype,
......
...@@ -172,8 +172,10 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx, ...@@ -172,8 +172,10 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
#endif #endif
auto dtype = paddle::platform::CudnnDataType<T>::type; auto dtype = paddle::platform::CudnnDataType<T>::type;
auto handle = ctx.cudnn_handle();
ConvArgs args1{&transformed_dout, ConvArgs args1{handle,
&transformed_dout,
&filter, &filter,
&x_transpose, &x_transpose,
strides, strides,
...@@ -182,7 +184,8 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx, ...@@ -182,7 +184,8 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
dtype, dtype,
groups, groups,
layout}; layout};
ConvArgs args2{&transformed_dout, ConvArgs args2{handle,
&transformed_dout,
&filter, &filter,
&x_transpose, &x_transpose,
strides, strides,
...@@ -202,14 +205,13 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx, ...@@ -202,14 +205,13 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
auto layout_tensor = paddle::platform::GetCudnnTensorFormat(layout); auto layout_tensor = paddle::platform::GetCudnnTensorFormat(layout);
size_t workspace_size = 0; size_t workspace_size = 0;
auto handle = ctx.cudnn_handle();
bool deterministic = FLAGS_cudnn_deterministic; bool deterministic = FLAGS_cudnn_deterministic;
T* dx_data = nullptr; T* dx_data = nullptr;
T* dfilter_data = nullptr; T* dfilter_data = nullptr;
if (dx) { if (dx) {
dx_data = ctx.template Alloc<T>(dx); dx_data = ctx.template Alloc<T>(dx);
args1.handle = handle;
args1.idesc.set(transformed_dout, iwo_groups); args1.idesc.set(transformed_dout, iwo_groups);
args1.wdesc.set(filter, layout_tensor, iwo_groups); args1.wdesc.set(filter, layout_tensor, iwo_groups);
args1.odesc.set(x_transpose, iwo_groups); args1.odesc.set(x_transpose, iwo_groups);
...@@ -234,7 +236,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx, ...@@ -234,7 +236,7 @@ void ConvTransposeGradRawGPUDNNKernel(const Context& ctx,
if (dfilter) { if (dfilter) {
dfilter_data = ctx.template Alloc<T>(dfilter); dfilter_data = ctx.template Alloc<T>(dfilter);
args2.handle = handle;
args2.idesc.set(transformed_dout, iwo_groups); args2.idesc.set(transformed_dout, iwo_groups);
args2.wdesc.set(*dfilter, layout_tensor, iwo_groups); args2.wdesc.set(*dfilter, layout_tensor, iwo_groups);
args2.odesc.set(x_transpose, iwo_groups); args2.odesc.set(x_transpose, iwo_groups);
...@@ -625,7 +627,8 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( ...@@ -625,7 +627,8 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
auto handle = ctx.cudnn_handle(); auto handle = ctx.cudnn_handle();
auto layout = paddle::platform::GetCudnnTensorFormat(GPUDNNDataLayout::kNCHW); auto layout = paddle::platform::GetCudnnTensorFormat(GPUDNNDataLayout::kNCHW);
ConvArgs args1{&transformed_ddout_channel, ConvArgs args1{handle,
&transformed_ddout_channel,
&filter, &filter,
&transformed_ddx, &transformed_ddx,
strides, strides,
...@@ -634,7 +637,8 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( ...@@ -634,7 +637,8 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
dtype, dtype,
groups, groups,
GPUDNNDataLayout::kNCHW}; GPUDNNDataLayout::kNCHW};
ConvArgs args2{&transformed_ddout_channel, ConvArgs args2{handle,
&transformed_ddout_channel,
&ddfilter, &ddfilter,
&transformed_x, &transformed_x,
strides, strides,
...@@ -644,7 +648,8 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( ...@@ -644,7 +648,8 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
groups, groups,
GPUDNNDataLayout::kNCHW}; GPUDNNDataLayout::kNCHW};
ConvArgs args3{&transformed_dout, ConvArgs args3{handle,
&transformed_dout,
dfilter, dfilter,
&transformed_ddx_channel, &transformed_ddx_channel,
strides, strides,
...@@ -653,7 +658,8 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( ...@@ -653,7 +658,8 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
dtype, dtype,
groups, groups,
GPUDNNDataLayout::kNCHW}; GPUDNNDataLayout::kNCHW};
ConvArgs args4{&transformed_dout, ConvArgs args4{handle,
&transformed_dout,
&ddfilter, &ddfilter,
&transformed_dx_channel, &transformed_dx_channel,
strides, strides,
...@@ -683,7 +689,6 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( ...@@ -683,7 +689,6 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
ddout_ = ddout->data<T>(); ddout_ = ddout->data<T>();
transformed_ddout_channel_ = transformed_ddout_channel.data<T>(); transformed_ddout_channel_ = transformed_ddout_channel.data<T>();
args1.handle = handle;
args1.idesc.set(transformed_ddout_channel, iwo_group); args1.idesc.set(transformed_ddout_channel, iwo_group);
args1.wdesc.set(filter, layout, iwo_group); args1.wdesc.set(filter, layout, iwo_group);
args1.odesc.set(transformed_ddx, iwo_group); args1.odesc.set(transformed_ddx, iwo_group);
...@@ -730,7 +735,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( ...@@ -730,7 +735,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
if (dfilter) { if (dfilter) {
dfilter_ = dfilter->data<T>(); dfilter_ = dfilter->data<T>();
args3.handle = handle;
args3.idesc.set(transformed_dout, iwo_group); args3.idesc.set(transformed_dout, iwo_group);
args3.wdesc.set(*dfilter, layout, iwo_group); args3.wdesc.set(*dfilter, layout, iwo_group);
args3.odesc.set(transformed_ddx_channel, iwo_group); args3.odesc.set(transformed_ddx_channel, iwo_group);
...@@ -806,13 +811,13 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( ...@@ -806,13 +811,13 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
ScalingParamType<T> alpha = 1.0f; ScalingParamType<T> alpha = 1.0f;
ScalingParamType<T> beta = 0.0f; ScalingParamType<T> beta = 0.0f;
auto wkspace_handle = ctx.cudnn_workspace_handle(); auto workspace_handle = ctx.cudnn_workspace_handle();
if (ddout) { if (ddout) {
ddx_ = transformed_ddx.data<T>(); ddx_ = transformed_ddx.data<T>();
for (int i = 0; i < groups; i++) { for (int i = 0; i < groups; i++) {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
wkspace_handle.RunFunc( workspace_handle.RunFunc(
[&](void* workspace_ptr) { [&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenConvolutionBackwardData( PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenConvolutionBackwardData(
handle, handle,
...@@ -831,7 +836,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( ...@@ -831,7 +836,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
}, },
workspace_size); workspace_size);
#else // PADDLE_WITH_HIP #else // PADDLE_WITH_HIP
wkspace_handle.RunFunc( workspace_handle.RunFunc(
[&](void* workspace_ptr) { [&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cudnnConvolutionBackwardData( PADDLE_ENFORCE_GPU_SUCCESS(dynload::cudnnConvolutionBackwardData(
handle, handle,
...@@ -858,7 +863,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( ...@@ -858,7 +863,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
DenseTensor conv_x_ddfilter(dout.type()); DenseTensor conv_x_ddfilter(dout.type());
conv_x_ddfilter.Resize(transformed_ddout_channel.dims()); conv_x_ddfilter.Resize(transformed_ddout_channel.dims());
T* conv_x_ddfilter_data = ctx.template Alloc<T>(&conv_x_ddfilter); T* conv_x_ddfilter_data = ctx.template Alloc<T>(&conv_x_ddfilter);
wkspace_handle.RunFunc( workspace_handle.RunFunc(
[&](void* workspace_ptr) { [&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenConvolutionBackwardData( PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenConvolutionBackwardData(
handle, handle,
...@@ -889,7 +894,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( ...@@ -889,7 +894,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
args2.idesc.desc(), args2.idesc.desc(),
transformed_ddout_channel_ + i * group_offset_out)); transformed_ddout_channel_ + i * group_offset_out));
#else // PADDLE_WITH_HIP #else // PADDLE_WITH_HIP
wkspace_handle.RunFunc( workspace_handle.RunFunc(
[&](void* workspace_ptr) { [&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cudnnConvolutionBackwardData( PADDLE_ENFORCE_GPU_SUCCESS(dynload::cudnnConvolutionBackwardData(
handle, handle,
...@@ -944,7 +949,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( ...@@ -944,7 +949,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
ddx_ = transformed_ddx_channel.data<T>(); ddx_ = transformed_ddx_channel.data<T>();
for (int i = 0; i < groups; i++) { for (int i = 0; i < groups; i++) {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
wkspace_handle.RunFunc( workspace_handle.RunFunc(
[&](void* workspace_ptr) { [&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
dynload::miopenConvolutionBackwardWeights( dynload::miopenConvolutionBackwardWeights(
...@@ -964,7 +969,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( ...@@ -964,7 +969,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
}, },
workspace_size); workspace_size);
#else // PADDLE_WITH_HIP #else // PADDLE_WITH_HIP
wkspace_handle.RunFunc( workspace_handle.RunFunc(
[&](void* workspace_ptr) { [&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cudnnConvolutionBackwardFilter( PADDLE_ENFORCE_GPU_SUCCESS(dynload::cudnnConvolutionBackwardFilter(
handle, handle,
...@@ -990,7 +995,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( ...@@ -990,7 +995,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
ddfilter_ = ddfilter.data<T>(); ddfilter_ = ddfilter.data<T>();
for (int i = 0; i < groups; i++) { for (int i = 0; i < groups; i++) {
#ifdef PADDLE_WITH_HIP #ifdef PADDLE_WITH_HIP
wkspace_handle.RunFunc( workspace_handle.RunFunc(
[&](void* workspace_ptr) { [&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenConvolutionForward( PADDLE_ENFORCE_GPU_SUCCESS(dynload::miopenConvolutionForward(
handle, handle,
...@@ -1009,7 +1014,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel( ...@@ -1009,7 +1014,7 @@ void Conv2dTransposeDoubleGradGPUDNNKernel(
}, },
workspace_size); workspace_size);
#else // PADDLE_WITH_HIP #else // PADDLE_WITH_HIP
wkspace_handle.RunFunc( workspace_handle.RunFunc(
[&](void* workspace_ptr) { [&](void* workspace_ptr) {
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cudnnConvolutionForward( PADDLE_ENFORCE_GPU_SUCCESS(dynload::cudnnConvolutionForward(
handle, handle,
......
...@@ -199,7 +199,8 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx, ...@@ -199,7 +199,8 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx,
auto dtype = paddle::platform::CudnnDataType<T>::type; auto dtype = paddle::platform::CudnnDataType<T>::type;
// ------------------- cudnn descriptors --------------------- // ------------------- cudnn descriptors ---------------------
ConvArgs args{&transformed_out, ConvArgs args{handle,
&transformed_out,
&filter, &filter,
&transformed_x, &transformed_x,
strides, strides,
...@@ -208,7 +209,6 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx, ...@@ -208,7 +209,6 @@ void ConvTransposeRawGPUDNNKernel(const Context& ctx,
dtype, dtype,
groups, groups,
data_layout}; data_layout};
args.handle = handle;
args.idesc.set(transformed_out, iwo_groups); args.idesc.set(transformed_out, iwo_groups);
args.wdesc.set(filter, layout_tensor, iwo_groups); args.wdesc.set(filter, layout_tensor, iwo_groups);
args.odesc.set(transformed_x, iwo_groups); args.odesc.set(transformed_x, iwo_groups);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册