Unverified commit effd51c7 authored by MarDino, committed by GitHub

Fix bilinear interp fp16 diff (#49095)

* Cast to a higher-precision type to prevent the fp16 diff problem

* Fix bilinear backward and add more unit test cases
Parent 6d88fed0
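The diff below applies one pattern throughout the bilinear kernels: tensors stay in fp16 in global memory, while the interpolation weights and partial sums are computed in the higher-precision type given by phi::dtype::MPTypeTrait, casting back to fp16 only at the final store or atomic add. A minimal sketch of that pattern, assuming MPTypeTrait maps float16 to float; the element-wise kernel below is a hypothetical stand-in, not part of this commit:

// Hypothetical kernel (not part of this commit); header assumed:
//   paddle/phi/common/amp_type_traits.h  (phi::dtype::MPTypeTrait)
template <typename T>
__global__ void ScaleAddSketch(
    const T* x, const T* y, T* out, int n, float alpha) {
  using MT = typename phi::dtype::MPTypeTrait<T>::Type;  // float when T is float16
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Do the arithmetic in MT; only the final store rounds back to fp16.
    MT acc =
        static_cast<MT>(alpha) * static_cast<MT>(x[i]) + static_cast<MT>(y[i]);
    out[i] = static_cast<T>(acc);
  }
}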
@@ -251,7 +251,8 @@ __global__ void KeBilinearInterpBwShareMemory(T* in,
float ratio_w,
const float align_type_value,
bool is_nchw) {
__shared__ T s_data[2][1024];
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
__shared__ MT s_data[2][1024];
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int in_chw = in_h * in_w * num_channels;
@@ -263,18 +264,18 @@ __global__ void KeBilinearInterpBwShareMemory(T* in,
int out_id_w = tid % out_chw;
const int in_img_size = in_h * in_w;
const int out_img_size = out_h * out_w;
T value = out[out_id_h * out_chw + out_id_w];
MT value = static_cast<MT>(out[out_id_h * out_chw + out_id_w]);
int channel_id = out_id_w / out_img_size;
int out_img_idy = (out_id_w % out_img_size) / out_w;
int out_img_idx = tid % out_w;
int in_img_idx, in_img_idy, w_id, h_id;
T w1lambda, h1lambda, w2lambda, h2lambda;
T src_w = static_cast<T>(ratio_w * (out_img_idx + align_type_value) -
align_type_value);
T src_h = static_cast<T>(ratio_h * (out_img_idy + align_type_value) -
align_type_value);
MT w1lambda, h1lambda, w2lambda, h2lambda;
MT src_w = static_cast<MT>(ratio_w * (out_img_idx + align_type_value) -
align_type_value);
MT src_h = static_cast<MT>(ratio_h * (out_img_idy + align_type_value) -
align_type_value);
PreCalculatorForLinearInterpInputIndex(
&in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_w);
@@ -289,8 +290,8 @@ __global__ void KeBilinearInterpBwShareMemory(T* in,
int bot_right_index = input_index + h_id * in_w + w_id;
int in_top_min_index, in_bot_min_index;
s_data[0][threadIdx.x] = static_cast<T>(0);
s_data[1][threadIdx.x] = static_cast<T>(0);
s_data[0][threadIdx.x] = static_cast<MT>(0);
s_data[1][threadIdx.x] = static_cast<MT>(0);
int remain = nthreads - (tid & (-blockDim.x));
int in_top_max_index =
phi::funcs::blockReduceMax(top_right_index, FINAL_MASK);
@@ -327,9 +328,9 @@ __global__ void KeBilinearInterpBwShareMemory(T* in,
if (threadIdx.x <= upper_limit_share_idx) {
phi::CudaAtomicAdd(&in[in_top_min_index + threadIdx.x],
s_data[0][threadIdx.x]);
static_cast<T>(s_data[0][threadIdx.x]));
phi::CudaAtomicAdd(&in[in_bot_min_index + threadIdx.x],
s_data[1][threadIdx.x]);
static_cast<T>(s_data[1][threadIdx.x]));
}
}
}
@@ -358,6 +359,7 @@ __global__ void KeBilinearInterpNCHWBw(T* in,
int stride = blockDim.x * gridDim.x;
int num_out = n * num_channels * out_h * out_w;
int num_in = n * num_channels * in_h * in_w;
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
for (; index < num_out; index += stride) {
int index_tmp = index;
@@ -367,29 +369,29 @@ __global__ void KeBilinearInterpNCHWBw(T* in,
int nc = index_tmp / out_h;
int h1, y_id;
T h1lambda, h0lambda;
T src_y =
static_cast<T>(ratio_h * (h2 + align_type_value) - align_type_value);
MT h1lambda, h0lambda;
MT src_y =
static_cast<MT>(ratio_h * (h2 + align_type_value) - align_type_value);
PreCalculatorForLinearInterpInputIndex(
&h1, &y_id, &h1lambda, &h0lambda, src_y, in_h);
int w1, x_id;
T w1lambda, w0lambda;
T src_x =
static_cast<T>(ratio_w * (w2 + align_type_value) - align_type_value);
MT w1lambda, w0lambda;
MT src_x =
static_cast<MT>(ratio_w * (w2 + align_type_value) - align_type_value);
PreCalculatorForLinearInterpInputIndex(
&w1, &x_id, &w1lambda, &w0lambda, src_x, in_w);
T d2val = out[index];
MT d2val = static_cast<MT>(out[index]);
phi::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1),
h0lambda * w0lambda * d2val);
static_cast<T>(h0lambda * w0lambda * d2val));
phi::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1, w1 + x_id),
h0lambda * w1lambda * d2val);
static_cast<T>(h0lambda * w1lambda * d2val));
phi::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1),
h1lambda * w0lambda * d2val);
static_cast<T>(h1lambda * w0lambda * d2val));
phi::CudaAtomicAdd(in + GetInputIndex(nc, in_h, in_w, h1 + y_id, w1 + x_id),
h1lambda * w1lambda * d2val);
static_cast<T>(h1lambda * w1lambda * d2val));
}
}
@@ -411,6 +413,7 @@ __global__ void KeBilinearInterpBw(T* in,
int stride = blockDim.x * gridDim.x;
int in_chw = in_h * in_w * num_channels;
int nthreads = n * out_chw;
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
for (; tid < nthreads; tid += stride) {
auto out_id_divmod = divmods.output_w_div.Divmod(tid);
@@ -424,28 +427,28 @@ __global__ void KeBilinearInterpBw(T* in,
divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];
int in_img_idx, in_img_idy, w_id, h_id;
T w1lambda, h1lambda, w2lambda, h2lambda;
T src_w = static_cast<T>(ratio_w * (out_img_idx + align_type_value) -
align_type_value);
T src_h = static_cast<T>(ratio_h * (out_img_idy + align_type_value) -
align_type_value);
MT w1lambda, h1lambda, w2lambda, h2lambda;
MT src_w = static_cast<MT>(ratio_w * (out_img_idx + align_type_value) -
align_type_value);
MT src_h = static_cast<MT>(ratio_h * (out_img_idy + align_type_value) -
align_type_value);
PreCalculatorForLinearInterpInputIndex(
&in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_w);
PreCalculatorForLinearInterpInputIndex(
&in_img_idy, &h_id, &h1lambda, &h2lambda, src_h, in_h);
T value = out[tid];
MT value = static_cast<MT>(out[tid]);
T* in_pos = &in[out_id_h * in_chw + in_img_idy * in_w * num_channels +
in_img_idx * num_channels + channel_id];
phi::CudaAtomicAdd(&in_pos[0], h2lambda * w2lambda * value);
phi::CudaAtomicAdd(&in_pos[0], static_cast<T>(h2lambda * w2lambda * value));
phi::CudaAtomicAdd(&in_pos[w_id * num_channels],
h2lambda * w1lambda * value);
static_cast<T>(h2lambda * w1lambda * value));
phi::CudaAtomicAdd(&in_pos[h_id * in_w * num_channels],
h1lambda * w2lambda * value);
static_cast<T>(h1lambda * w2lambda * value));
phi::CudaAtomicAdd(
&in_pos[h_id * in_w * num_channels + w_id * num_channels],
h1lambda * w1lambda * value);
static_cast<T>(h1lambda * w1lambda * value));
}
}
......
@@ -209,6 +209,7 @@ __global__ void KeBilinearInterpFw(const T* in,
const float ratio_w,
const float align_type_value,
funcs::FastDivModForInterpolate divmods) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
@@ -225,11 +226,11 @@ __global__ void KeBilinearInterpFw(const T* in,
divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];
int in_img_idx, in_img_idy, h_id, w_id;
T h1lambda, w1lambda, h2lambda, w2lambda;
T src_w = static_cast<T>(ratio_w * (out_img_idx + align_type_value) -
align_type_value);
T src_h = static_cast<T>(ratio_h * (out_img_idy + align_type_value) -
align_type_value);
MT h1lambda, w1lambda, h2lambda, w2lambda;
MT src_w = static_cast<MT>(ratio_w * (out_img_idx + align_type_value) -
align_type_value);
MT src_h = static_cast<MT>(ratio_h * (out_img_idy + align_type_value) -
align_type_value);
PreCalculatorForLinearInterpInputIndex(
&in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_img_w);
@@ -241,12 +242,13 @@ __global__ void KeBilinearInterpFw(const T* in,
&in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
out[tid] =
h2lambda *
(w2lambda * in_pos[0] + w1lambda * in_pos[w_id * num_channels]) +
h2lambda * (w2lambda * static_cast<MT>(in_pos[0]) +
w1lambda * static_cast<MT>(in_pos[w_id * num_channels])) +
h1lambda *
(w2lambda * in_pos[h_id * in_img_w * num_channels] +
w1lambda *
in_pos[h_id * in_img_w * num_channels + w_id * num_channels]);
(w2lambda *
static_cast<MT>(in_pos[h_id * in_img_w * num_channels]) +
w1lambda * static_cast<MT>(in_pos[h_id * in_img_w * num_channels +
w_id * num_channels]));
}
}
@@ -261,17 +263,18 @@ __global__ void KeBilinearInterpNCHWFw(const T* in,
const float ratio_h,
const float ratio_w,
const float align_type_value) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
int out_img_idx = threadIdx.x + blockIdx.x * blockDim.x;
int out_img_idy = threadIdx.y + blockIdx.y * blockDim.y;
int nc_id = threadIdx.z + blockIdx.z * blockDim.z;
int nc_stride = blockDim.z * gridDim.z;
int in_img_idx, in_img_idy, h_id, w_id;
T h1lambda, w1lambda, h2lambda, w2lambda;
T src_w = static_cast<T>(ratio_w * (out_img_idx + align_type_value) -
align_type_value);
T src_h = static_cast<T>(ratio_h * (out_img_idy + align_type_value) -
align_type_value);
MT h1lambda, w1lambda, h2lambda, w2lambda;
MT src_w = static_cast<MT>(ratio_w * (out_img_idx + align_type_value) -
align_type_value);
MT src_h = static_cast<MT>(ratio_h * (out_img_idy + align_type_value) -
align_type_value);
PreCalculatorForLinearInterpInputIndex(
&in_img_idx, &w_id, &w1lambda, &w2lambda, src_w, in_img_w);
@@ -288,10 +291,12 @@ __global__ void KeBilinearInterpNCHWFw(const T* in,
if (out_img_idx < out_img_w && out_img_idy < out_img_h) {
while (nc_id < nc) {
const T* in_pos = &in[in_index];
out[out_index] =
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[w_id]) +
h1lambda * (w2lambda * in_pos[h_id * in_img_w] +
w1lambda * in_pos[h_id * in_img_w + w_id]);
out[out_index] = static_cast<T>(
h2lambda * (w2lambda * static_cast<MT>(in_pos[0]) +
w1lambda * static_cast<MT>(in_pos[w_id])) +
h1lambda *
(w2lambda * static_cast<MT>(in_pos[h_id * in_img_w]) +
w1lambda * static_cast<MT>(in_pos[h_id * in_img_w + w_id])));
in_index += in_index_stride;
out_index += out_index_stride;
......
@@ -733,7 +733,7 @@ class TestBilinearInterpOpAPI_dy4(unittest.TestCase):
@unittest.skipIf(
not fluid.core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestBilinearInterpOpForFloat16(unittest.TestCase):
class TestBilinearInterpOpZoomOutForFloat16(unittest.TestCase):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 5, 5]
@@ -768,8 +768,52 @@ class TestBilinearInterpOpForFloat16(unittest.TestCase):
y_np_1, x_g_np_1 = self.check_main(x_np, 'float16')
y_np_2, x_g_np_2 = self.check_main(x_np, 'float32')
np.testing.assert_allclose(y_np_1, y_np_2)
np.testing.assert_allclose(x_g_np_1, x_g_np_2)
np.testing.assert_allclose(y_np_1, y_np_2, atol=1e-3, rtol=1e-3)
# Since atomicAdd on half introduces some extra diff, the tolerance is relaxed to 1e-2 here.
np.testing.assert_allclose(x_g_np_1, x_g_np_2, atol=1e-2, rtol=1e-2)
@unittest.skipIf(
not fluid.core.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestBilinearInterpOpZoomInForFloat16(unittest.TestCase):
def init_test_case(self):
self.interp_method = 'bilinear'
self.input_shape = [2, 3, 5, 5]
self.out_size = np.array([10, 10]).astype("int32")
self.align_corners = True
self.align_mode = 1
self.data_layout = 'NCHW'
def check_main(self, x_np, dtype):
paddle.disable_static()
x_np = x_np.astype(dtype)
x = paddle.to_tensor(x_np)
x.stop_gradient = False
y = interpolate(
x,
size=self.out_size.tolist(),
mode=self.interp_method,
align_mode=self.align_mode,
align_corners=self.align_corners,
data_format=self.data_layout,
)
x_g = paddle.grad(y, x)
y_np = y[0].numpy().astype('float32')
x_g_np = x_g[0].numpy().astype('float32')
paddle.enable_static()
return y_np, x_g_np
def test_main(self):
self.init_test_case()
x_np = np.random.random(self.input_shape).astype("float16")
y_np_1, x_g_np_1 = self.check_main(x_np, 'float16')
y_np_2, x_g_np_2 = self.check_main(x_np, 'float32')
np.testing.assert_allclose(y_np_1, y_np_2, atol=1e-3, rtol=1e-3)
# Since atomicAdd on half introduces some extra diff, the tolerance is relaxed to 1e-2 here.
np.testing.assert_allclose(x_g_np_1, x_g_np_2, atol=1e-2, rtol=1e-2)
if __name__ == "__main__":
......
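On the backward side, the kernels above combine the interpolation weights and the upstream gradient in float, round the product back to fp16 once, and then call phi::CudaAtomicAdd on the half buffer; the atomic accumulation itself still rounds on every add, which is why the new tests compare fp16 gradients against fp32 with atol/rtol of 1e-2. A minimal sketch of that backward pattern, with a hypothetical scatter kernel standing in for the real bilinear gradient math:

// Hypothetical kernel (not in this commit); headers assumed:
//   paddle/phi/common/amp_type_traits.h       (phi::dtype::MPTypeTrait)
//   paddle/phi/backends/gpu/gpu_primitives.h  (phi::CudaAtomicAdd)
template <typename T>
__global__ void ScatterGradSketch(
    T* grad_in, const T* grad_out, const int* idx, int n, float weight) {
  using MT = typename phi::dtype::MPTypeTrait<T>::Type;
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // Combine the weight and the upstream gradient in float, round to T only
    // once, then let the atomic add accumulate into the fp16 buffer.
    MT g = static_cast<MT>(weight) * static_cast<MT>(grad_out[i]);
    phi::CudaAtomicAdd(&grad_in[idx[i]], static_cast<T>(g));
  }
}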