未验证 提交 eb7b0ddd 编写于 作者: Y Yang Zhang 提交者: GitHub

Force `cudnn` backend for depthwise convs when fp16 is enabled (#270)

上级 021a13c7
...@@ -19,6 +19,7 @@ from __future__ import print_function ...@@ -19,6 +19,7 @@ from __future__ import print_function
from paddle import fluid from paddle import fluid
from paddle.fluid.param_attr import ParamAttr from paddle.fluid.param_attr import ParamAttr
from ppdet.experimental import mixed_precision_global_state
from ppdet.core.workspace import register from ppdet.core.workspace import register
__all__ = ['BlazeNet'] __all__ = ['BlazeNet']
...@@ -151,6 +152,7 @@ class BlazeNet(object): ...@@ -151,6 +152,7 @@ class BlazeNet(object):
use_pool = not stride == 1 use_pool = not stride == 1
use_double_block = double_channels is not None use_double_block = double_channels is not None
act = 'relu' if use_double_block else None act = 'relu' if use_double_block else None
mixed_precision_enabled = mixed_precision_global_state() is not None
if use_5x5kernel: if use_5x5kernel:
conv_dw = self._conv_norm( conv_dw = self._conv_norm(
...@@ -160,7 +162,7 @@ class BlazeNet(object): ...@@ -160,7 +162,7 @@ class BlazeNet(object):
stride=stride, stride=stride,
padding=2, padding=2,
num_groups=in_channels, num_groups=in_channels,
use_cudnn=False, use_cudnn=mixed_precision_enabled,
name=name + "1_dw") name=name + "1_dw")
else: else:
conv_dw_1 = self._conv_norm( conv_dw_1 = self._conv_norm(
...@@ -170,7 +172,7 @@ class BlazeNet(object): ...@@ -170,7 +172,7 @@ class BlazeNet(object):
stride=1, stride=1,
padding=1, padding=1,
num_groups=in_channels, num_groups=in_channels,
use_cudnn=False, use_cudnn=mixed_precision_enabled,
name=name + "1_dw_1") name=name + "1_dw_1")
conv_dw = self._conv_norm( conv_dw = self._conv_norm(
input=conv_dw_1, input=conv_dw_1,
...@@ -179,7 +181,7 @@ class BlazeNet(object): ...@@ -179,7 +181,7 @@ class BlazeNet(object):
stride=stride, stride=stride,
padding=1, padding=1,
num_groups=in_channels, num_groups=in_channels,
use_cudnn=False, use_cudnn=mixed_precision_enabled,
name=name + "1_dw_2") name=name + "1_dw_2")
conv_pw = self._conv_norm( conv_pw = self._conv_norm(
...@@ -199,7 +201,7 @@ class BlazeNet(object): ...@@ -199,7 +201,7 @@ class BlazeNet(object):
num_filters=out_channels, num_filters=out_channels,
stride=1, stride=1,
padding=2, padding=2,
use_cudnn=False, use_cudnn=mixed_precision_enabled,
name=name + "2_dw") name=name + "2_dw")
else: else:
conv_dw_1 = self._conv_norm( conv_dw_1 = self._conv_norm(
...@@ -209,7 +211,7 @@ class BlazeNet(object): ...@@ -209,7 +211,7 @@ class BlazeNet(object):
stride=1, stride=1,
padding=1, padding=1,
num_groups=out_channels, num_groups=out_channels,
use_cudnn=False, use_cudnn=mixed_precision_enabled,
name=name + "2_dw_1") name=name + "2_dw_1")
conv_dw = self._conv_norm( conv_dw = self._conv_norm(
input=conv_dw_1, input=conv_dw_1,
...@@ -218,7 +220,7 @@ class BlazeNet(object): ...@@ -218,7 +220,7 @@ class BlazeNet(object):
stride=1, stride=1,
padding=1, padding=1,
num_groups=out_channels, num_groups=out_channels,
use_cudnn=False, use_cudnn=mixed_precision_enabled,
name=name + "2_dw_2") name=name + "2_dw_2")
conv_pw = self._conv_norm( conv_pw = self._conv_norm(
......
...@@ -20,6 +20,7 @@ from paddle import fluid ...@@ -20,6 +20,7 @@ from paddle import fluid
from paddle.fluid.param_attr import ParamAttr from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay from paddle.fluid.regularizer import L2Decay
from ppdet.experimental import mixed_precision_global_state
from ppdet.core.workspace import register from ppdet.core.workspace import register
__all__ = ['MobileNet'] __all__ = ['MobileNet']
...@@ -104,6 +105,7 @@ class MobileNet(object): ...@@ -104,6 +105,7 @@ class MobileNet(object):
stride, stride,
scale, scale,
name=None): name=None):
mixed_precision_enabled = mixed_precision_global_state() is not None
depthwise_conv = self._conv_norm( depthwise_conv = self._conv_norm(
input=input, input=input,
filter_size=3, filter_size=3,
...@@ -111,7 +113,7 @@ class MobileNet(object): ...@@ -111,7 +113,7 @@ class MobileNet(object):
stride=stride, stride=stride,
padding=1, padding=1,
num_groups=int(num_groups * scale), num_groups=int(num_groups * scale),
use_cudnn=False, use_cudnn=mixed_precision_enabled,
name=name + "_dw") name=name + "_dw")
pointwise_conv = self._conv_norm( pointwise_conv = self._conv_norm(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册