未验证 提交 eb7b0ddd 作者: Yang Zhang 提交者: GitHub

Force `cudnn` backend for depthwise convs when fp16 is enabled (#270)

上级 021a13c7
......@@ -19,6 +19,7 @@ from __future__ import print_function
from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from ppdet.experimental import mixed_precision_global_state
from ppdet.core.workspace import register
__all__ = ['BlazeNet']
......@@ -151,6 +152,7 @@ class BlazeNet(object):
use_pool = not stride == 1
use_double_block = double_channels is not None
act = 'relu' if use_double_block else None
mixed_precision_enabled = mixed_precision_global_state() is not None
if use_5x5kernel:
conv_dw = self._conv_norm(
......@@ -160,7 +162,7 @@ class BlazeNet(object):
stride=stride,
padding=2,
num_groups=in_channels,
use_cudnn=False,
use_cudnn=mixed_precision_enabled,
name=name + "1_dw")
else:
conv_dw_1 = self._conv_norm(
......@@ -170,7 +172,7 @@ class BlazeNet(object):
stride=1,
padding=1,
num_groups=in_channels,
use_cudnn=False,
use_cudnn=mixed_precision_enabled,
name=name + "1_dw_1")
conv_dw = self._conv_norm(
input=conv_dw_1,
......@@ -179,7 +181,7 @@ class BlazeNet(object):
stride=stride,
padding=1,
num_groups=in_channels,
use_cudnn=False,
use_cudnn=mixed_precision_enabled,
name=name + "1_dw_2")
conv_pw = self._conv_norm(
......@@ -199,7 +201,7 @@ class BlazeNet(object):
num_filters=out_channels,
stride=1,
padding=2,
use_cudnn=False,
use_cudnn=mixed_precision_enabled,
name=name + "2_dw")
else:
conv_dw_1 = self._conv_norm(
......@@ -209,7 +211,7 @@ class BlazeNet(object):
stride=1,
padding=1,
num_groups=out_channels,
use_cudnn=False,
use_cudnn=mixed_precision_enabled,
name=name + "2_dw_1")
conv_dw = self._conv_norm(
input=conv_dw_1,
......@@ -218,7 +220,7 @@ class BlazeNet(object):
stride=1,
padding=1,
num_groups=out_channels,
use_cudnn=False,
use_cudnn=mixed_precision_enabled,
name=name + "2_dw_2")
conv_pw = self._conv_norm(
......
......@@ -20,6 +20,7 @@ from paddle import fluid
from paddle.fluid.param_attr import ParamAttr
from paddle.fluid.regularizer import L2Decay
from ppdet.experimental import mixed_precision_global_state
from ppdet.core.workspace import register
__all__ = ['MobileNet']
......@@ -104,6 +105,7 @@ class MobileNet(object):
stride,
scale,
name=None):
mixed_precision_enabled = mixed_precision_global_state() is not None
depthwise_conv = self._conv_norm(
input=input,
filter_size=3,
......@@ -111,7 +113,7 @@ class MobileNet(object):
stride=stride,
padding=1,
num_groups=int(num_groups * scale),
use_cudnn=False,
use_cudnn=mixed_precision_enabled,
name=name + "_dw")
pointwise_conv = self._conv_norm(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册