提交 100387be 编写于 作者: Y Yang Zhang 提交者: GitHub

Use fp16 for RCNN bbox head (#3475)

上级 435e6f54
...@@ -27,6 +27,7 @@ from paddle.fluid.initializer import MSRA ...@@ -27,6 +27,7 @@ from paddle.fluid.initializer import MSRA
from ppdet.modeling.ops import MultiClassNMS from ppdet.modeling.ops import MultiClassNMS
from ppdet.modeling.ops import ConvNorm from ppdet.modeling.ops import ConvNorm
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
from ppdet.experimental import mixed_precision_global_state
__all__ = ['BBoxHead', 'TwoFCHead', 'XConvNormHead'] __all__ = ['BBoxHead', 'TwoFCHead', 'XConvNormHead']
...@@ -120,6 +121,12 @@ class TwoFCHead(object): ...@@ -120,6 +121,12 @@ class TwoFCHead(object):
def __call__(self, roi_feat): def __call__(self, roi_feat):
fan = roi_feat.shape[1] * roi_feat.shape[2] * roi_feat.shape[3] fan = roi_feat.shape[1] * roi_feat.shape[2] * roi_feat.shape[3]
mixed_precision_enabled = mixed_precision_global_state() is not None
if mixed_precision_enabled:
roi_feat = fluid.layers.cast(roi_feat, 'float16')
fc6 = fluid.layers.fc(input=roi_feat, fc6 = fluid.layers.fc(input=roi_feat,
size=self.mlp_dim, size=self.mlp_dim,
act='relu', act='relu',
...@@ -141,6 +148,10 @@ class TwoFCHead(object): ...@@ -141,6 +148,10 @@ class TwoFCHead(object):
name='fc7_b', name='fc7_b',
learning_rate=2., learning_rate=2.,
regularizer=L2Decay(0.))) regularizer=L2Decay(0.)))
if mixed_precision_enabled:
head_feat = fluid.layers.cast(head_feat, 'float32')
return head_feat return head_feat
...@@ -280,7 +291,7 @@ class BBoxHead(object): ...@@ -280,7 +291,7 @@ class BBoxHead(object):
number of input images, each element consists of im_height, number of input images, each element consists of im_height,
im_width, im_scale. im_width, im_scale.
im_shape (Variable): Actual shape of original image with shape im_shape (Variable): Actual shape of original image with shape
[B, 3]. B is the number of images, each element consists of [B, 3]. B is the number of images, each element consists of
original_height, original_width, 1 original_height, original_width, 1
Returns: Returns:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册