未验证 提交 17ec1620 编写于 作者: X xiongkun 提交者: GitHub

Revert "make bilinear interpolate stable. (#48644)" (#49307)

This reverts commit e1e8bf72.
上级 a9533953
...@@ -25,8 +25,6 @@ ...@@ -25,8 +25,6 @@
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/primitive/datamover_primitives.h" #include "paddle/phi/kernels/primitive/datamover_primitives.h"
DECLARE_bool(cudnn_deterministic);
namespace phi { namespace phi {
template <typename T> template <typename T>
...@@ -1039,12 +1037,6 @@ static void Interpolate2DCUDABwd( ...@@ -1039,12 +1037,6 @@ static void Interpolate2DCUDABwd(
#endif #endif
if (optimize_flag & is_nchw) { if (optimize_flag & is_nchw) {
if (FLAGS_cudnn_deterministic) {
VLOG(2)
<< "Run grad kernel of bilinear interpolate 2d with single thread.";
config.block_per_grid = 1;
config.thread_per_block = 1;
}
KeBilinearInterpBwShareMemory<T><<<config.block_per_grid, KeBilinearInterpBwShareMemory<T><<<config.block_per_grid,
config.thread_per_block, config.thread_per_block,
0, 0,
...@@ -1063,27 +1055,21 @@ static void Interpolate2DCUDABwd( ...@@ -1063,27 +1055,21 @@ static void Interpolate2DCUDABwd(
} else if (!optimize_flag & is_nchw) { } else if (!optimize_flag & is_nchw) {
const int num_kernels = n * c * out_h * out_w; const int num_kernels = n * c * out_h * out_w;
const int num_threads = std::min(dev_ctx.GetMaxThreadsPerBlock(), 1024); const int num_threads = std::min(dev_ctx.GetMaxThreadsPerBlock(), 1024);
int block_per_grid = backends::gpu::DivUp(num_kernels, num_threads);
int thread_per_block = num_threads;
if (FLAGS_cudnn_deterministic) {
VLOG(2)
<< "Run grad kernel of bilinear interpolate 2d with single thread.";
block_per_grid = 1;
thread_per_block = 1;
}
KeBilinearInterpNCHWBw<T> KeBilinearInterpNCHWBw<T>
<<<block_per_grid, thread_per_block, 0, dev_ctx.stream()>>>( <<<backends::gpu::DivUp(num_kernels, num_threads),
input_grad_data, num_threads,
in_h, 0,
in_w, dev_ctx.stream()>>>(input_grad_data,
out_h, in_h,
out_w, in_w,
n, out_h,
c, out_w,
ratio_h, n,
ratio_w, c,
output_grad_data, ratio_h,
align_type_value); ratio_w,
output_grad_data,
align_type_value);
} else { } else {
int64_t cw = c * out_w; int64_t cw = c * out_w;
auto interp_divmods = funcs::FastDivModForInterpolate(c, out_chw, cw); auto interp_divmods = funcs::FastDivModForInterpolate(c, out_chw, cw);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册