未验证 提交 42516643 编写于 作者: T topduke 提交者: GitHub

Fix grid_sample data type bug when using fp16 (#9930)

* fix grid_sample data type bug when using fp16

* fix grid_sample data type bug when using fp16

* fix v4rec batchsize
上级 24ff4def
......@@ -8,7 +8,7 @@ Global:
save_epoch_step: 10
eval_batch_step: [0, 2000]
cal_metric_during_train: true
pretrained_model:
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: false
......@@ -101,7 +101,7 @@ Train:
sampler:
name: MultiScaleSampler
scales: [[320, 32], [320, 48], [320, 64]]
first_bs: &bs 128
first_bs: &bs 192
fix_bs: false
divided_factor: [8, 16] # w, h
is_training: True
......
......@@ -280,5 +280,13 @@ class GA_SPIN_Transformer(nn.Layer):
x = self.sp_net(x, sp_weight, offsets, lambda_color)
if self.stn:
is_fp16 = False
if build_P_prime_reshape.dtype != paddle.float32:
data_type = build_P_prime_reshape.dtype
x = x.cast(paddle.float32)
build_P_prime_reshape = build_P_prime_reshape.cast(paddle.float32)
is_fp16 = True
x = F.grid_sample(x=x, grid=build_P_prime_reshape, padding_mode='border')
if is_fp16:
x = x.cast(data_type)
return x
......@@ -304,5 +304,14 @@ class TPS(nn.Layer):
batch_P_prime = self.grid_generator(batch_C_prime, image.shape[2:])
batch_P_prime = batch_P_prime.reshape(
[-1, image.shape[2], image.shape[3], 2])
is_fp16 = False
if batch_P_prime.dtype != paddle.float32:
data_type = batch_P_prime.dtype
image = image.cast(paddle.float32)
batch_P_prime = batch_P_prime.cast(paddle.float32)
is_fp16 = True
batch_I_r = F.grid_sample(x=image, grid=batch_P_prime)
if is_fp16:
batch_I_r = batch_I_r.cast(data_type)
return batch_I_r
......@@ -29,12 +29,28 @@ import itertools
def grid_sample(input, grid, canvas=None):
    """Sample `input` at the locations given by `grid`, tolerating fp16 inputs.

    `F.grid_sample` only supports float32 here, so when `grid` arrives in a
    lower precision (e.g. float16 under AMP) both tensors are up-cast to
    float32 for sampling and the result is cast back to the original dtype.

    Args:
        input: feature map tensor to sample from; made trainable
            (`stop_gradient = False`) so gradients flow through sampling.
        grid: sampling coordinates; its dtype decides whether casting is
            needed.
        canvas: optional background tensor. When given, the sampled output is
            composited onto it using a sampled all-ones mask, so regions the
            grid does not cover keep the canvas values.

    Returns:
        The sampled tensor, or the canvas-composited tensor when `canvas`
        is provided, in the original dtype of `grid`.
    """
    input.stop_gradient = False

    # Up-cast once; keep everything in float32 until all sampling is done.
    # (The previous version cast `grid` back to fp16 after the first sample
    # and immediately re-cast it to float32 for the mask sample — a redundant,
    # lossless round-trip that is dropped here.)
    orig_dtype = grid.dtype
    needs_cast = orig_dtype != paddle.float32
    if needs_cast:
        input = input.cast(paddle.float32)
        grid = grid.cast(paddle.float32)

    output = F.grid_sample(input, grid)
    if needs_cast:
        output = output.cast(orig_dtype)

    if canvas is None:
        return output

    # Sample an all-ones mask to find which output pixels were actually
    # covered by the grid, then blend: covered pixels from `output`,
    # uncovered pixels from `canvas`.
    input_mask = paddle.ones(shape=input.shape)
    if needs_cast:
        input_mask = input_mask.cast(paddle.float32)
    output_mask = F.grid_sample(input_mask, grid)
    if needs_cast:
        output_mask = output_mask.cast(orig_dtype)
    padded_output = output * output_mask + canvas * (1 - output_mask)
    return padded_output
......
......@@ -187,7 +187,7 @@ def export_single_model(model,
shape=[None] + infer_shape, dtype="float32")
])
if arch_config["Backbone"]["name"] == "LCNetv3":
if arch_config["Backbone"]["name"] == "PPLCNetV3":
# for rep lcnetv3
for layer in model.sublayers():
if hasattr(layer, "rep") and not getattr(layer, "is_repped"):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册