未验证 提交 8f77b383 编写于 作者: M michaelowenliu 提交者: GitHub

Merge pull request #395 from wuyefeilin/dygraph

update layer_libs.py
......@@ -20,18 +20,18 @@ from paddle.nn import Conv2d
from paddle.nn import SyncBatchNorm as BatchNorm
class ConvBNRelu(nn.Layer):
class ConvBNReLU(nn.Layer):
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
super(ConvBNRelu, self).__init__()
super(ConvBNReLU, self).__init__()
self.conv = Conv2d(in_channels, out_channels, kernel_size, **kwargs)
self._conv = Conv2d(in_channels, out_channels, kernel_size, **kwargs)
self.batch_norm = BatchNorm(out_channels)
self._batch_norm = BatchNorm(out_channels)
def forward(self, x):
x = self.conv(x)
x = self.batch_norm(x)
x = self._conv(x)
x = self._batch_norm(x)
x = F.relu(x)
return x
......@@ -40,14 +40,12 @@ class ConvBN(nn.Layer):
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
super(ConvBN, self).__init__()
self.conv = Conv2d(in_channels, out_channels, kernel_size, **kwargs)
self.batch_norm = BatchNorm(out_channels)
self._conv = Conv2d(in_channels, out_channels, kernel_size, **kwargs)
self._batch_norm = BatchNorm(out_channels)
def forward(self, x):
x = self.conv(x)
x = self.batch_norm(x)
x = self._conv(x)
x = self._batch_norm(x)
return x
......@@ -69,16 +67,16 @@ class ConvReluPool(nn.Layer):
return x
class DepthwiseConvBNRelu(nn.Layer):
class DepthwiseConvBNReLU(nn.Layer):
def __init__(self, in_channels, out_channels, kernel_size, **kwargs):
super(DepthwiseConvBNRelu, self).__init__()
super(DepthwiseConvBNReLU, self).__init__()
self.depthwise_conv = ConvBN(
in_channels,
out_channels=in_channels,
kernel_size=kernel_size,
groups=in_channels,
**kwargs)
self.piontwise_conv = ConvBNRelu(
self.piontwise_conv = ConvBNReLU(
in_channels, out_channels, kernel_size=1, groups=1)
def forward(self, x):
......@@ -105,7 +103,7 @@ class AuxLayer(nn.Layer):
dropout_prob=0.1):
super(AuxLayer, self).__init__()
self.conv_bn_relu = ConvBNRelu(
self.conv_bn_relu = ConvBNReLU(
in_channels=in_channels,
out_channels=inter_channels,
kernel_size=3,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册