未验证 提交 20a611cf 编写于 作者: Y Yang Zhang 提交者: GitHub

Fix mixed precision training of senet backbone (#8)

上级 382a8027
...@@ -21,6 +21,7 @@ import math ...@@ -21,6 +21,7 @@ import math
from paddle import fluid from paddle import fluid
from paddle.fluid.param_attr import ParamAttr from paddle.fluid.param_attr import ParamAttr
from ppdet.experimental import mixed_precision_global_state
from ppdet.core.workspace import register, serializable from ppdet.core.workspace import register, serializable
from .resnext import ResNeXt from .resnext import ResNeXt
...@@ -72,12 +73,13 @@ class SENet(ResNeXt): ...@@ -72,12 +73,13 @@ class SENet(ResNeXt):
self.dcn_v2_stages = dcn_v2_stages self.dcn_v2_stages = dcn_v2_stages
def _squeeze_excitation(self, input, num_channels, name=None): def _squeeze_excitation(self, input, num_channels, name=None):
mixed_precision_enabled = mixed_precision_global_state() is not None
pool = fluid.layers.pool2d( pool = fluid.layers.pool2d(
input=input, input=input,
pool_size=0, pool_size=0,
pool_type='avg', pool_type='avg',
global_pooling=True, global_pooling=True,
use_cudnn=False) use_cudnn=mixed_precision_enabled)
stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0) stdv = 1.0 / math.sqrt(pool.shape[1] * 1.0)
squeeze = fluid.layers.fc( squeeze = fluid.layers.fc(
input=pool, input=pool,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册