conv.py 1.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68
import paddle
from paddle import nn
from paddle.nn import functional as F


class ConvBNRelu(nn.Layer):
    """Conv2D -> BatchNorm2D -> optional identity shortcut -> ReLU.

    Args:
        cin: number of input channels passed to ``nn.Conv2D``.
        cout: number of output channels (also the BatchNorm2D width).
        kernel_size: forwarded positionally to ``nn.Conv2D``.
        stride: forwarded positionally to ``nn.Conv2D``.
        padding: forwarded positionally to ``nn.Conv2D``.
        residual: when true, the block input ``x`` is added to the
            normalized conv output before the activation. This only works
            if the conv preserves the input shape (NOTE(review): presumably
            cin == cout and a shape-preserving stride/padding; not checked
            here).
        *args, **kwargs: forwarded to ``nn.Layer.__init__``.
    """

    def __init__(self,
                 cin,
                 cout,
                 kernel_size,
                 stride,
                 padding,
                 residual=False,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        # Build conv then norm in the same order as before, then wrap them in a
        # Sequential stored under ``conv_block`` so parameter names are stable.
        conv = nn.Conv2D(cin, cout, kernel_size, stride, padding)
        norm = nn.BatchNorm2D(cout)
        self.conv_block = nn.Sequential(conv, norm)
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        """Apply conv + BN, add ``x`` when ``residual`` is set, then ReLU."""
        features = self.conv_block(x)
        if not self.residual:
            return self.act(features)
        # Shortcut is added before the activation, not after.
        features += x
        return self.act(features)


class NonNormConv2d(nn.Layer):
    """Conv2D followed by LeakyReLU, with no normalization layer.

    Args:
        cin: number of input channels passed to ``nn.Conv2D``.
        cout: number of output channels passed to ``nn.Conv2D``.
        kernel_size: forwarded positionally to ``nn.Conv2D``.
        stride: forwarded positionally to ``nn.Conv2D``.
        padding: forwarded positionally to ``nn.Conv2D``.
        residual: accepted for signature parity with ``ConvBNRelu`` but
            ignored; no shortcut connection is added.
        *args, **kwargs: forwarded to ``nn.Layer.__init__``.
    """

    def __init__(self,
                 cin,
                 cout,
                 kernel_size,
                 stride,
                 padding,
                 residual=False,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
            nn.Conv2D(cin, cout, kernel_size, stride, padding), )
        # Paddle's LeakyReLU signature is (negative_slope=0.01, name=None).
        # Unlike torch.nn.LeakyReLU it has no ``inplace`` argument, so the
        # previous ``inplace=True`` raised TypeError at construction time.
        self.act = nn.LeakyReLU(0.01)

    def forward(self, x):
        """Apply the convolution, then LeakyReLU with slope 0.01."""
        out = self.conv_block(x)
        return self.act(out)


class Conv2dTranspseRelu(nn.Layer):
    """Transposed Conv2D -> BatchNorm2D -> ReLU (upsampling block).

    The class name keeps its historical spelling ("Transpse") so existing
    imports keep working.

    Args:
        cin: number of input channels.
        cout: number of output channels (also the BatchNorm2D width).
        kernel_size: forwarded positionally to ``nn.Conv2DTranspose``.
        stride: forwarded positionally to ``nn.Conv2DTranspose``.
        padding: forwarded positionally to ``nn.Conv2DTranspose``.
        output_padding: extra size added to one side of the output,
            forwarded positionally to ``nn.Conv2DTranspose``.
        *args, **kwargs: forwarded to ``nn.Layer.__init__``.
    """

    def __init__(self,
                 cin,
                 cout,
                 kernel_size,
                 stride,
                 padding,
                 output_padding=0,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        # Paddle 2.x names this layer ``Conv2DTranspose``. There is no
        # ``nn.ConvTranspose2D``, so the previous spelling raised
        # AttributeError. Its positional order is (in_channels, out_channels,
        # kernel_size, stride, padding, output_padding, ...), matching below.
        self.conv_block = nn.Sequential(
            nn.Conv2DTranspose(cin, cout, kernel_size, stride, padding,
                               output_padding), nn.BatchNorm2D(cout))
        self.act = nn.ReLU()

    def forward(self, x):
        """Apply transposed conv + BatchNorm, then ReLU."""
        out = self.conv_block(x)
        return self.act(out)