Unverified commit d94481c2, authored by Xinjun-Wu (吴鑫俊), committed by GitHub

modify the ch_in & ch_out expression of the BasicBlock (#5075)

* modify the ch_in & ch_out of the BasicBlock

(cherry picked from commit da568e4741f9a3a440ba1f605b8d12299fe4ed23)

* add the type conversion of the ch_in and ch_out & update gitignore

(cherry picked from commit 5b55097c6dab81dbccc187ba13a92b724ffc7ef5)

* revert the changes of gitignore

(cherry picked from commit ffeb72a8cefca6f416c5483756cc67f26f40522f)
Parent fbf981d1
@@ -149,9 +149,14 @@ class BasicBlock(nn.Layer):
         super(BasicBlock, self).__init__()
+        assert ch_in == ch_out and (ch_in % 2) == 0, \
+            f"ch_in and ch_out should be the same even int, but the input \'ch_in is {ch_in}, \'ch_out is {ch_out}"
+        # example:
+        # --------------{conv1} --> {conv2}
+        # channel route: 10-->5 --> 5-->10
         self.conv1 = ConvBNLayer(
             ch_in=ch_in,
-            ch_out=ch_out,
+            ch_out=int(ch_out/2),
             filter_size=1,
             stride=1,
             padding=0,
@@ -160,8 +165,8 @@ class BasicBlock(nn.Layer):
             freeze_norm=freeze_norm,
             data_format=data_format)
         self.conv2 = ConvBNLayer(
-            ch_in=ch_out,
-            ch_out=ch_out * 2,
+            ch_in=int(ch_out/2),
+            ch_out=ch_out,
             filter_size=3,
             stride=1,
             padding=1,
@@ -215,7 +220,7 @@ class Blocks(nn.Layer):
             res_out = self.add_sublayer(
                 block_name,
                 BasicBlock(
-                    ch_out * 2,
+                    ch_out,
                     ch_out,
                     norm_type=norm_type,
                     norm_decay=norm_decay,
@@ -296,7 +301,7 @@ class DarkNet(nn.Layer):
                 name,
                 Blocks(
                     int(ch_in[i]),
-                    32 * (2**i),
+                    int(ch_in[i]),
                     stage,
                     norm_type=norm_type,
                     norm_decay=norm_decay,
@@ -305,14 +310,14 @@ class DarkNet(nn.Layer):
                     name=name))
             self.darknet_conv_block_list.append(conv_block)
             if i in return_idx:
-                self._out_channels.append(64 * (2**i))
+                self._out_channels.append(int(ch_in[i]))
         for i in range(num_stages - 1):
             down_name = 'stage.{}.downsample'.format(i)
             downsample = self.add_sublayer(
                 down_name,
                 DownSample(
-                    ch_in=32 * (2**(i + 1)),
-                    ch_out=32 * (2**(i + 2)),
+                    ch_in=int(ch_in[i]),
+                    ch_out=int(ch_in[i+1]),
                     norm_type=norm_type,
                     norm_decay=norm_decay,
                     freeze_norm=freeze_norm,
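For context, the sketch below (not the library code) illustrates the channel contract this commit gives BasicBlock: ch_in must equal ch_out and be even, conv1 squeezes to ch_out/2 with a 1x1 kernel, conv2 restores ch_out with a 3x3 kernel, so the residual add always matches shapes. It uses plain paddle.nn.Conv2D instead of the repository's ConvBNLayer, drops norm and activation, and the class name BasicBlockSketch is hypothetical.

# Hypothetical minimal sketch of the new BasicBlock channel routing
# (assumes PaddlePaddle is installed); ConvBNLayer is replaced by plain
# Conv2D to keep only the channel math visible.
import paddle
import paddle.nn as nn

class BasicBlockSketch(nn.Layer):
    def __init__(self, ch_in, ch_out):
        super(BasicBlockSketch, self).__init__()
        # same precondition the patched BasicBlock asserts
        assert ch_in == ch_out and (ch_in % 2) == 0
        # channel route: ch_in --> ch_out/2 --> ch_out (e.g. 10 --> 5 --> 10)
        self.conv1 = nn.Conv2D(ch_in, int(ch_out / 2), 1, stride=1, padding=0)
        self.conv2 = nn.Conv2D(int(ch_out / 2), ch_out, 3, stride=1, padding=1)

    def forward(self, x):
        out = self.conv2(self.conv1(x))
        # residual add works because ch_in == ch_out
        return paddle.add(x, out)

# usage: a 64-channel feature map goes 64 -> 32 -> 64, matching the residual branch
x = paddle.randn([2, 64, 32, 32])
print(BasicBlockSketch(64, 64)(x).shape)  # [2, 64, 32, 32]

Before this patch conv2 expanded to ch_out * 2, so Blocks had to construct BasicBlock with ch_out * 2 as its input width; with the new expression every stage simply uses int(ch_in[i]) throughout, which is what the DarkNet hunks above reflect.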