Unverified commit fa67fb9f, authored by shangliang Xu, committed by GitHub

[dev] fix export model bug in DETR (#7120)

Parent: 6d6573b1
@@ -42,9 +42,11 @@ from utils import argsparser, Timer, get_current_memory_mb, multiclass_nms, coco
 SUPPORT_MODELS = {
     'YOLO', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet', 'S2ANet', 'JDE',
     'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet', 'TOOD', 'RetinaNet',
-    'StrongBaseline', 'STGCN', 'YOLOX', 'PPHGNet', 'PPLCNet'
+    'StrongBaseline', 'STGCN', 'YOLOX', 'PPHGNet', 'PPLCNet', 'DETR'
 }
+
+TUNED_TRT_DYNAMIC_MODELS = {'DETR'}

 def bench_log(detector, img_list, model_info, batch_size=1, name=None):
     mems = {
@@ -103,6 +105,7 @@ class Detector(object):
         self.pred_config = self.set_config(model_dir)
         self.predictor, self.config = load_predictor(
             model_dir,
+            self.pred_config.arch,
             run_mode=run_mode,
             batch_size=batch_size,
             min_subgraph_size=self.pred_config.min_subgraph_size,
@@ -775,6 +778,7 @@ class PredictConfig():
 def load_predictor(model_dir,
+                   arch,
                    run_mode='paddle',
                    batch_size=1,
                    device='CPU',
@@ -787,7 +791,8 @@ def load_predictor(model_dir,
                    cpu_threads=1,
                    enable_mkldnn=False,
                    enable_mkldnn_bfloat16=False,
-                   delete_shuffle_pass=False):
+                   delete_shuffle_pass=False,
+                   tuned_trt_shape_file="shape_range_info.pbtxt"):
     """set AnalysisConfig, generate AnalysisPredictor
     Args:
         model_dir (str): root path of __model__ and __params__
@@ -854,6 +859,8 @@ def load_predictor(model_dir,
         'trt_fp16': Config.Precision.Half
     }
     if run_mode in precision_map.keys():
+        if arch in TUNED_TRT_DYNAMIC_MODELS:
+            config.collect_shape_range_info(tuned_trt_shape_file)
         config.enable_tensorrt_engine(
             workspace_size=(1 << 25) * batch_size,
             max_batch_size=batch_size,
@@ -861,6 +868,9 @@ def load_predictor(model_dir,
             precision_mode=precision_map[run_mode],
             use_static=False,
             use_calib_mode=trt_calib_mode)
+        if arch in TUNED_TRT_DYNAMIC_MODELS:
+            config.enable_tuned_tensorrt_dynamic_shape(tuned_trt_shape_file,
+                                                       True)

         if use_dynamic_shape:
             min_input_shape = {
......
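For models with data-dependent shapes such as DETR, the tuned dynamic-shape path above does two things on one config: `collect_shape_range_info` records the min/max/opt shape of every intermediate tensor into a protobuf text file as inputs flow through, and `enable_tuned_tensorrt_dynamic_shape` lets TensorRT consume those recorded ranges instead of hand-written min/max/opt shapes. A minimal standalone sketch of the same setup; the model file names are illustrative assumptions, not taken from this diff:

```python
# Sketch of the tuned dynamic-shape setup added above (assumed file paths).
from paddle.inference import Config, create_predictor

config = Config("model.pdmodel", "model.pdiparams")  # hypothetical exported files
config.enable_use_gpu(200, 0)

# Record min/max/opt shapes of every intermediate tensor while running.
config.collect_shape_range_info("shape_range_info.pbtxt")

config.enable_tensorrt_engine(
    workspace_size=(1 << 25),
    max_batch_size=1,
    min_subgraph_size=3,  # matches the TRT_MIN_SUBGRAPH['DETR'] entry below
    precision_mode=Config.Precision.Float32,
    use_static=False,
    use_calib_mode=False)

# Read the recorded ranges; True allows rebuilding the engine at runtime
# if an input shape outside the recorded range appears.
config.enable_tuned_tensorrt_dynamic_shape("shape_range_info.pbtxt", True)

predictor = create_predictor(config)
```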
@@ -50,6 +50,7 @@ TRT_MIN_SUBGRAPH = {
     'TOOD': 5,
     'YOLOX': 8,
     'METRO_Body': 3,
+    'DETR': 3,
 }
 KEYPOINT_ARCH = ['HigherHRNet', 'TopDownHRNet']
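`TRT_MIN_SUBGRAPH` sets, per architecture, the smallest op count a Paddle subgraph must have before TensorRT is allowed to absorb it; the exporter writes this value into the model's `infer_cfg.yml`, and deploy-time code forwards it as `min_subgraph_size=self.pred_config.min_subgraph_size` (see the Detector hunk above). A hedged sketch of that hand-off, with an assumed export path:

```python
# Sketch of how the new 'DETR': 3 entry reaches deployment; the export
# directory is a hypothetical example, the field name follows infer_cfg.yml.
import yaml

with open("output_inference/detr_r50/infer_cfg.yml") as f:  # assumed path
    infer_cfg = yaml.safe_load(f)

# Written by the exporter from TRT_MIN_SUBGRAPH[arch]; consumed by
# load_predictor(..., min_subgraph_size=min_subgraph_size).
min_subgraph_size = infer_cfg['min_subgraph_size']
```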
@@ -134,7 +135,6 @@ def _dump_infer_config(config, path, image_shape, model):
     export_onnx = config.get('export_onnx', False)
     export_eb = config.get('export_eb', False)
     infer_arch = config['architecture']
-
     if 'RCNN' in infer_arch and export_onnx:
         logger.warning(
@@ -142,7 +142,6 @@ def _dump_infer_config(config, path, image_shape, model):
         infer_cfg['export_onnx'] = True
         infer_cfg['export_eb'] = export_eb
-
     if infer_arch in MOT_ARCH:
         if infer_arch == 'DeepSORT':
             tracker_cfg = config['DeepSORTTracker']
......
@@ -27,17 +27,20 @@ __all__ = ['DETR']
 class DETR(BaseArch):
     __category__ = 'architecture'
     __inject__ = ['post_process']
+    __shared__ = ['exclude_post_process']

     def __init__(self,
                  backbone,
                  transformer,
                  detr_head,
-                 post_process='DETRBBoxPostProcess'):
+                 post_process='DETRBBoxPostProcess',
+                 exclude_post_process=False):
         super(DETR, self).__init__()
         self.backbone = backbone
         self.transformer = transformer
         self.detr_head = detr_head
         self.post_process = post_process
+        self.exclude_post_process = exclude_post_process

     @classmethod
     def from_config(cls, cfg, *args, **kwargs):
@@ -65,15 +68,20 @@ class DETR(BaseArch):
         body_feats = self.backbone(self.inputs)

         # Transformer
-        out_transformer = self.transformer(body_feats, self.inputs['pad_mask'])
+        pad_mask = self.inputs['pad_mask'] if self.training else None
+        out_transformer = self.transformer(body_feats, pad_mask)

         # DETR Head
         if self.training:
             return self.detr_head(out_transformer, body_feats, self.inputs)
         else:
             preds = self.detr_head(out_transformer, body_feats)
-            bbox, bbox_num = self.post_process(preds, self.inputs['im_shape'],
-                                               self.inputs['scale_factor'])
+            if self.exclude_post_process:
+                bboxes, logits, masks = preds
+                return bboxes, logits
+            else:
+                bbox, bbox_num = self.post_process(
+                    preds, self.inputs['im_shape'], self.inputs['scale_factor'])
             return bbox, bbox_num

     def get_loss(self, ):
......
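With `exclude_post_process=True`, inference stops at the raw head outputs (decoded boxes, class logits, optional masks) instead of the scaled `[N, 6]` detections, which suits exporters that attach their own decoding. A hedged sketch of the two output contracts; the call pattern and tensor shapes follow the usual DETR conventions and are assumptions, not read from this diff:

```python
# Sketch of the two inference-time output contracts introduced above.
import paddle

model.eval()
with paddle.no_grad():
    out = model(inputs)  # inputs: dict with 'image', 'im_shape', 'scale_factor'

if model.exclude_post_process:
    # Raw head outputs: boxes plus per-query class logits,
    # e.g. [bs, num_queries, 4] and [bs, num_queries, num_classes] (assumed).
    bboxes, logits = out
else:
    # Post-processed detections: [N, 6] rows of
    # (class_id, score, x1, y1, x2, y2), plus a per-image box count.
    bbox, bbox_num = out
```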
@@ -479,9 +479,9 @@ class DETRBBoxPostProcess(object):
         bbox_pred = bbox_cxcywh_to_xyxy(bboxes)
         origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
-        img_h, img_w = origin_shape.unbind(1)
-        origin_shape = paddle.stack(
-            [img_w, img_h, img_w, img_h], axis=-1).unsqueeze(0)
+        img_h, img_w = paddle.split(origin_shape, 2, axis=-1)
+        origin_shape = paddle.concat(
+            [img_w, img_h, img_w, img_h], axis=-1).reshape([-1, 1, 4])
         bbox_pred *= origin_shape
         scores = F.sigmoid(logits) if self.use_focal_loss else F.softmax(
......
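The `unbind`/`stack` to `split`/`concat` rewrite is what makes this hunk export-friendly: `split` returns size-1 slices along the split axis (where `unbind` squeezes that axis away), so `concat` rebuilds the `[bs, 4]` scale tensor, and `reshape([-1, 1, 4])` lets it broadcast per image over the query axis. A short equivalence sketch with assumed shapes:

```python
# Sketch of the split/concat path above; input values are assumed.
import paddle

origin_shape = paddle.to_tensor([[800., 1333.], [640., 640.]])  # [bs, 2] as (h, w)

img_h, img_w = paddle.split(origin_shape, 2, axis=-1)  # each [bs, 1]
scale = paddle.concat([img_w, img_h, img_w, img_h],
                      axis=-1).reshape([-1, 1, 4])      # [bs, 1, 4]

bbox_pred = paddle.rand([2, 100, 4])  # [bs, num_queries, 4], assumed shape
bbox_pred = bbox_pred * scale         # broadcasts per image over all queries
```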
@@ -69,8 +69,6 @@ class TransformerEncoderLayer(nn.Layer):
         return tensor if pos_embed is None else tensor + pos_embed

     def forward(self, src, src_mask=None, pos_embed=None):
-        src_mask = _convert_attention_mask(src_mask, src.dtype)
-
         residual = src
         if self.normalize_before:
             src = self.norm1(src)
@@ -99,8 +97,6 @@ class TransformerEncoder(nn.Layer):
         self.norm = norm

     def forward(self, src, src_mask=None, pos_embed=None):
-        src_mask = _convert_attention_mask(src_mask, src.dtype)
-
         output = src
         for layer in self.layers:
             output = layer(output, src_mask=src_mask, pos_embed=pos_embed)
@@ -158,7 +154,6 @@ class TransformerDecoderLayer(nn.Layer):
                 pos_embed=None,
                 query_pos_embed=None):
         tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
-        memory_mask = _convert_attention_mask(memory_mask, memory.dtype)

         residual = tgt
         if self.normalize_before:
@@ -209,7 +204,6 @@ class TransformerDecoder(nn.Layer):
                 pos_embed=None,
                 query_pos_embed=None):
         tgt_mask = _convert_attention_mask(tgt_mask, tgt.dtype)
-        memory_mask = _convert_attention_mask(memory_mask, memory.dtype)

         output = tgt
         intermediate = []
@@ -298,6 +292,9 @@ class DETRTransformer(nn.Layer):
             'backbone_num_channels': [i.channels for i in input_shape][-1],
         }

+    def _convert_attention_mask(self, mask):
+        return (mask - 1.0) * 1e9
+
     def forward(self, src, src_mask=None):
         r"""
         Applies a Transformer model on the inputs.
@@ -321,20 +318,21 @@ class DETRTransformer(nn.Layer):
         """
         # use last level feature map
         src_proj = self.input_proj(src[-1])
-        bs, c, h, w = src_proj.shape
+        bs, c, h, w = paddle.shape(src_proj)
         # flatten [B, C, H, W] to [B, HxW, C]
         src_flatten = src_proj.flatten(2).transpose([0, 2, 1])
         if src_mask is not None:
-            src_mask = F.interpolate(
-                src_mask.unsqueeze(0).astype(src_flatten.dtype),
-                size=(h, w))[0].astype('bool')
+            src_mask = F.interpolate(src_mask.unsqueeze(0), size=(h, w))[0]
         else:
-            src_mask = paddle.ones([bs, h, w], dtype='bool')
+            src_mask = paddle.ones([bs, h, w])
         pos_embed = self.position_embedding(src_mask).flatten(2).transpose(
             [0, 2, 1])

-        src_mask = _convert_attention_mask(src_mask, src_flatten.dtype)
-        src_mask = src_mask.reshape([bs, 1, 1, -1])
+        if self.training:
+            src_mask = self._convert_attention_mask(src_mask)
+            src_mask = src_mask.reshape([bs, 1, 1, h * w])
+        else:
+            src_mask = None

         memory = self.encoder(
             src_flatten, src_mask=src_mask, pos_embed=pos_embed)
@@ -349,5 +347,10 @@ class DETRTransformer(nn.Layer):
             pos_embed=pos_embed,
             query_pos_embed=query_pos_embed)

+        if self.training:
+            src_mask = src_mask.reshape([bs, 1, 1, h, w])
+        else:
+            src_mask = None
+
         return (output, memory.transpose([0, 2, 1]).reshape([bs, c, h, w]),
-                src_proj, src_mask.reshape([bs, 1, 1, h, w]))
+                src_proj, src_mask)
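The new `_convert_attention_mask` turns a {0, 1} validity mask into an additive bias: valid positions map to 0 and padded positions to -1e9, which drives their softmax weight to zero. A tiny numeric check (logit values assumed for the demo):

```python
# Numeric check of the additive-mask trick: (mask - 1) * 1e9 sends padded
# positions to -1e9 before the softmax.
import paddle
import paddle.nn.functional as F

mask = paddle.to_tensor([1., 1., 0.])        # last position is padding
bias = (mask - 1.0) * 1e9                    # [0, 0, -1e9]

scores = paddle.to_tensor([0.5, 0.2, 0.9])   # raw attention logits (assumed)
attn = F.softmax(scores + bias)              # padded position gets ~0 weight
print(attn.numpy())                          # approx. [0.574, 0.426, 0.000]
```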
@@ -65,11 +65,9 @@ class PositionEmbedding(nn.Layer):
         Returns:
             pos (Tensor): [B, C, H, W]
         """
-        assert mask.dtype == paddle.bool
         if self.embed_type == 'sine':
-            mask = mask.astype('float32')
-            y_embed = mask.cumsum(1, dtype='float32')
-            x_embed = mask.cumsum(2, dtype='float32')
+            y_embed = mask.cumsum(1)
+            x_embed = mask.cumsum(2)
             if self.normalize:
                 y_embed = (y_embed + self.offset) / (
                     y_embed[:, -1:, :] + self.eps) * self.scale
@@ -101,8 +99,7 @@ class PositionEmbedding(nn.Layer):
                     x_emb.unsqueeze(0).repeat(h, 1, 1),
                     y_emb.unsqueeze(1).repeat(1, w, 1),
                 ],
-                axis=-1).transpose([2, 0, 1]).unsqueeze(0).tile(mask.shape[0],
-                                                                1, 1, 1)
+                axis=-1).transpose([2, 0, 1]).unsqueeze(0)
             return pos
         else:
             raise ValueError(f"not supported {self.embed_type}")
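The sine embedding derives per-pixel coordinates from the validity mask: `cumsum` along H yields a row index and along W a column index, nonzero only where the mask is 1. The bool assert and explicit float casts can go because the mask now arrives as float. A minimal illustration with assumed values:

```python
# Illustration of the cumsum-as-coordinates step on a [1, 2, 3] float mask
# whose last column is padding (values assumed for the demo).
import paddle

mask = paddle.to_tensor([[[1., 1., 0.],
                          [1., 1., 0.]]])    # [B, H, W]

y_embed = mask.cumsum(1)  # row index per valid pixel:    [[1,1,0],[2,2,0]]
x_embed = mask.cumsum(2)  # column index per valid pixel: [[1,2,2],[1,2,2]]
```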
@@ -38,15 +38,15 @@ def _get_clones(module, N):

 def bbox_cxcywh_to_xyxy(x):
-    x_c, y_c, w, h = x.unbind(-1)
+    x_c, y_c, w, h = x.split(4, axis=-1)
     b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
-    return paddle.stack(b, axis=-1)
+    return paddle.concat(b, axis=-1)


 def bbox_xyxy_to_cxcywh(x):
-    x0, y0, x1, y1 = x.unbind(-1)
+    x0, y0, x1, y1 = x.split(4, axis=-1)
     b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
-    return paddle.stack(b, axis=-1)
+    return paddle.concat(b, axis=-1)


 def sigmoid_focal_loss(logit, label, normalizer=1.0, alpha=0.25, gamma=2.0):
......
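Same rewrite as in the post-processor: `x.split(4, axis=-1)` yields four `[..., 1]` slices where `unbind(-1)` yielded four `[...]` tensors, so the rebuild switches from `stack` to `concat` and the output shape is unchanged. Usage sketch (input values assumed):

```python
# Usage sketch for the rewritten helper: shapes are preserved end to end.
import paddle

def bbox_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.split(4, axis=-1)  # four [..., 1] slices
    b = [x_c - 0.5 * w, y_c - 0.5 * h, x_c + 0.5 * w, y_c + 0.5 * h]
    return paddle.concat(b, axis=-1)      # back to [..., 4]

boxes = paddle.to_tensor([[0.5, 0.5, 0.2, 0.4]])  # (cx, cy, w, h), assumed
print(bbox_cxcywh_to_xyxy(boxes).numpy())         # [[0.4, 0.3, 0.6, 0.7]]
```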