import paddle import paddle.nn as nn import paddle.nn.functional as F class TempConv(nn.Layer): def __init__(self, in_planes, out_planes, kernel_size=(1,3,3), stride=(1,1,1), padding=(0,1,1) ): super(TempConv, self).__init__() self.conv3d = nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding) self.bn = nn.BatchNorm( out_planes ) def forward(self, x): return F.elu( self.bn(self.conv3d(x))) class Upsample(nn.Layer): def __init__(self, in_planes, out_planes, scale_factor=(1,2,2)): super(Upsample, self).__init__() self.scale_factor = scale_factor self.conv3d = nn.Conv3d( in_planes, out_planes, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1) ) self.bn = nn.BatchNorm( out_planes ) def forward(self, x): out_size = x.shape[2:] for i in range(3): out_size[i] = self.scale_factor[i] * out_size[i] return F.elu( self.bn( self.conv3d( F.interpolate(x, size=out_size, mode='trilinear', align_corners=False, data_format='NCDHW', align_mode=0)))) class UpsampleConcat(nn.Layer): def __init__(self, in_planes_up, in_planes_flat, out_planes): super(UpsampleConcat, self).__init__() self.conv3d = TempConv( in_planes_up + in_planes_flat, out_planes, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1) ) def forward(self, x1, x2): scale_factor=(1,2,2) out_size = x1.shape[2:] for i in range(3): out_size[i] = scale_factor[i] * out_size[i] x1 = F.interpolate(x1, size=out_size, mode='trilinear', align_corners=False, data_format='NCDHW', align_mode=0) x = paddle.concat([x1, x2], axis=1) return self.conv3d(x) class SourceReferenceAttention(paddle.fluid.dygraph.Layer): """ Source-Reference Attention Layer """ def __init__(self, in_planes_s, in_planes_r): """ Parameters ---------- in_planes_s: int Number of input source feature vector channels. in_planes_r: int Number of input reference feature vector channels. """ super(SourceReferenceAttention,self).__init__() self.query_conv = nn.Conv3d(in_channels=in_planes_s, out_channels=in_planes_s//8, kernel_size=1 ) self.key_conv = nn.Conv3d(in_channels=in_planes_r, out_channels=in_planes_r//8, kernel_size=1 ) self.value_conv = nn.Conv3d(in_channels=in_planes_r, out_channels=in_planes_r, kernel_size=1 ) self.gamma = self.create_parameter(shape=[1], dtype=self.query_conv.weight.dtype, default_initializer=paddle.fluid.initializer.Constant(0.0)) def forward(self, source, reference): s_batchsize, sC, sT, sH, sW = source.shape r_batchsize, rC, rT, rH, rW = reference.shape proj_query = paddle.reshape(self.query_conv(source), [s_batchsize,-1,sT*sH*sW]) proj_query = paddle.transpose(proj_query, [0, 2, 1]) proj_key = paddle.reshape(self.key_conv(reference), [r_batchsize,-1,rT*rW*rH]) energy = paddle.bmm( proj_query, proj_key ) attention = F.softmax(energy) proj_value = paddle.reshape(self.value_conv(reference), [r_batchsize,-1,rT*rH*rW]) out = paddle.bmm(proj_value,paddle.transpose(attention, [0,2,1])) out = paddle.reshape(out, [s_batchsize, sC, sT, sH, sW]) out = self.gamma*out + source return out, attention class NetworkR( nn.Layer ): def __init__(self): super(NetworkR, self).__init__() self.layers = nn.Sequential( nn.ReplicationPad3d((1,1,1,1,1,1)), TempConv( 1, 64, kernel_size=(3,3,3), stride=(1,2,2), padding=(0,0,0) ), TempConv( 64, 128, kernel_size=(3,3,3), padding=(1,1,1) ), TempConv( 128, 128, kernel_size=(3,3,3), padding=(1,1,1) ), TempConv( 128, 256, kernel_size=(3,3,3), stride=(1,2,2), padding=(1,1,1) ), TempConv( 256, 256, kernel_size=(3,3,3), padding=(1,1,1) ), TempConv( 256, 256, kernel_size=(3,3,3), padding=(1,1,1) ), TempConv( 256, 256, kernel_size=(3,3,3), padding=(1,1,1) ), TempConv( 256, 256, kernel_size=(3,3,3), padding=(1,1,1) ), Upsample( 256, 128 ), TempConv( 128, 64, kernel_size=(3,3,3), padding=(1,1,1) ), TempConv( 64, 64, kernel_size=(3,3,3), padding=(1,1,1) ), Upsample( 64, 16 ), nn.Conv3d( 16, 1, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1) ) ) def forward(self, x): return paddle.clip((x + paddle.fluid.layers.tanh( self.layers( ((x * 1).detach())-0.4462414 ) )), 0.0, 1.0) class NetworkC( nn.Layer ): def __init__(self): super(NetworkC, self).__init__() self.down1 = nn.Sequential( nn.ReplicationPad3d((1,1,1,1,0,0)), TempConv( 1, 64, stride=(1,2,2), padding=(0,0,0) ), TempConv( 64, 128 ), TempConv( 128, 128 ), TempConv( 128, 256, stride=(1,2,2) ), TempConv( 256, 256 ), TempConv( 256, 256 ), TempConv( 256, 512, stride=(1,2,2) ), TempConv( 512, 512 ), TempConv( 512, 512 ) ) self.flat = nn.Sequential( TempConv( 512, 512 ), TempConv( 512, 512 ) ) self.down2 = nn.Sequential( TempConv( 512, 512, stride=(1,2,2) ), TempConv( 512, 512 ), ) self.stattn1 = SourceReferenceAttention( 512, 512 ) # Source-Reference Attention self.stattn2 = SourceReferenceAttention( 512, 512 ) # Source-Reference Attention self.selfattn1 = SourceReferenceAttention( 512, 512 ) # Self Attention self.conv1 = TempConv( 512, 512 ) self.up1 = UpsampleConcat( 512, 512, 512 ) # 1/8 self.selfattn2 = SourceReferenceAttention( 512, 512 ) # Self Attention self.conv2 = TempConv( 512, 256, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1) ) self.up2 = nn.Sequential( Upsample( 256, 128 ), # 1/4 TempConv( 128, 64, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1) ) ) self.up3 = nn.Sequential( Upsample( 64, 32 ), # 1/2 TempConv( 32, 16, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1) ) ) self.up4 = nn.Sequential( Upsample( 16, 8 ), # 1/1 nn.Conv3d( 8, 2, kernel_size=(3,3,3), stride=(1,1,1), padding=(1,1,1) ) ) self.reffeatnet1 = nn.Sequential( TempConv( 3, 64, stride=(1,2,2) ), TempConv( 64, 128 ), TempConv( 128, 128 ), TempConv( 128, 256, stride=(1,2,2) ), TempConv( 256, 256 ), TempConv( 256, 256 ), TempConv( 256, 512, stride=(1,2,2) ), TempConv( 512, 512 ), TempConv( 512, 512 ), ) self.reffeatnet2 = nn.Sequential( TempConv( 512, 512, stride=(1,2,2) ), TempConv( 512, 512 ), TempConv( 512, 512 ), ) def forward(self, x, x_refs=None): x1 = self.down1( x - 0.4462414 ) if x_refs is not None: x_refs = paddle.transpose(x_refs, [0, 2, 1, 3, 4]) # [B,T,C,H,W] --> [B,C,T,H,W] reffeat = self.reffeatnet1( x_refs-0.48 ) x1, _ = self.stattn1( x1, reffeat ) x2 = self.flat( x1 ) out = self.down2( x1 ) if x_refs is not None: reffeat2 = self.reffeatnet2( reffeat ) out, _ = self.stattn2( out, reffeat2 ) out = self.conv1( out ) out, _ = self.selfattn1( out, out ) out = self.up1( out, x2 ) out, _ = self.selfattn2( out, out ) out = self.conv2( out ) out = self.up2( out ) out = self.up3( out ) out = self.up4( out ) return F.sigmoid( out )