未验证 提交 fd13584e 编写于 作者: C Chang Xu 提交者: GitHub

asymmetric_kernel (#969)

* asymmetric_kernel

* asymmetric_kernel
上级 f895aebe
...@@ -214,6 +214,10 @@ class SuperConv2D(nn.Conv2D): ...@@ -214,6 +214,10 @@ class SuperConv2D(nn.Conv2D):
setattr(self, name, param) setattr(self, name, param)
def get_active_filter(self, in_nc, out_nc, kernel_size): def get_active_filter(self, in_nc, out_nc, kernel_size):
### Unsupport for asymmetric kernels
if self._kernel_size[0] != self._kernel_size[1]:
return self.weight[:out_nc, :in_nc, :, :]
start, end = compute_start_end(self._kernel_size[0], kernel_size) start, end = compute_start_end(self._kernel_size[0], kernel_size)
### if NOT transform kernel, intercept a center filter with kernel_size from largest filter ### if NOT transform kernel, intercept a center filter with kernel_size from largest filter
filters = self.weight[:out_nc, :in_nc, start:end, start:end] filters = self.weight[:out_nc, :in_nc, start:end, start:end]
...@@ -288,9 +292,14 @@ class SuperConv2D(nn.Conv2D): ...@@ -288,9 +292,14 @@ class SuperConv2D(nn.Conv2D):
out_nc = int(channel) out_nc = int(channel)
else: else:
out_nc = self._out_channels out_nc = self._out_channels
ks = int(self._kernel_size[0]) if kernel_size == None else int( ks = int(self._kernel_size[0]) if kernel_size == None else int(
kernel_size) kernel_size)
if kernel_size is not None and self._kernel_size[
0] != self._kernel_size[1]:
_logger.error("Searching for asymmetric kernels is NOT supported")
groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc, groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc,
out_nc) out_nc)
...@@ -518,6 +527,9 @@ class SuperConv2DTranspose(nn.Conv2DTranspose): ...@@ -518,6 +527,9 @@ class SuperConv2DTranspose(nn.Conv2DTranspose):
setattr(self, name, param) setattr(self, name, param)
def get_active_filter(self, in_nc, out_nc, kernel_size): def get_active_filter(self, in_nc, out_nc, kernel_size):
### Unsupport for asymmetric kernels
if self._kernel_size[0] != self._kernel_size[1]:
return self.weight[:out_nc, :in_nc, :, :]
start, end = compute_start_end(self._kernel_size[0], kernel_size) start, end = compute_start_end(self._kernel_size[0], kernel_size)
filters = self.weight[:in_nc, :out_nc, start:end, start:end] filters = self.weight[:in_nc, :out_nc, start:end, start:end]
if self.transform_kernel != False and kernel_size < self._kernel_size[ if self.transform_kernel != False and kernel_size < self._kernel_size[
...@@ -600,6 +612,10 @@ class SuperConv2DTranspose(nn.Conv2DTranspose): ...@@ -600,6 +612,10 @@ class SuperConv2DTranspose(nn.Conv2DTranspose):
ks = int(self._kernel_size[0]) if kernel_size == None else int( ks = int(self._kernel_size[0]) if kernel_size == None else int(
kernel_size) kernel_size)
if kernel_size is not None and self._kernel_size[
0] != self._kernel_size[1]:
_logger.error("Searching for asymmetric kernels is NOT supported")
groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc, groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc,
out_nc) out_nc)
......
...@@ -213,6 +213,9 @@ class SuperConv2D(fluid.dygraph.Conv2D): ...@@ -213,6 +213,9 @@ class SuperConv2D(fluid.dygraph.Conv2D):
setattr(self, name, param) setattr(self, name, param)
def get_active_filter(self, in_nc, out_nc, kernel_size): def get_active_filter(self, in_nc, out_nc, kernel_size):
### Unsupport for asymmetric kernels
if self._filter_size[0] != self._filter_size[1]:
return self.weight[:out_nc, :in_nc, :, :]
start, end = compute_start_end(self._filter_size[0], kernel_size) start, end = compute_start_end(self._filter_size[0], kernel_size)
### if NOT transform kernel, intercept a center filter with kernel_size from largest filter ### if NOT transform kernel, intercept a center filter with kernel_size from largest filter
filters = self.weight[:out_nc, :in_nc, start:end, start:end] filters = self.weight[:out_nc, :in_nc, start:end, start:end]
...@@ -285,6 +288,10 @@ class SuperConv2D(fluid.dygraph.Conv2D): ...@@ -285,6 +288,10 @@ class SuperConv2D(fluid.dygraph.Conv2D):
ks = int(self._filter_size[0]) if kernel_size == None else int( ks = int(self._filter_size[0]) if kernel_size == None else int(
kernel_size) kernel_size)
if kernel_size is not None and self._filter_size[
0] != self._filter_size[1]:
_logger.error("Searching for asymmetric kernels is NOT supported")
groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc, groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc,
out_nc) out_nc)
...@@ -513,6 +520,9 @@ class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose): ...@@ -513,6 +520,9 @@ class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose):
setattr(self, name, param) setattr(self, name, param)
def get_active_filter(self, in_nc, out_nc, kernel_size): def get_active_filter(self, in_nc, out_nc, kernel_size):
### Unsupport for asymmetric kernels
if self._filter_size[0] != self._filter_size[1]:
return self.weight[:out_nc, :in_nc, :, :]
start, end = compute_start_end(self._filter_size[0], kernel_size) start, end = compute_start_end(self._filter_size[0], kernel_size)
filters = self.weight[:in_nc, :out_nc, start:end, start:end] filters = self.weight[:in_nc, :out_nc, start:end, start:end]
if self.transform_kernel != False and kernel_size < self._filter_size[ if self.transform_kernel != False and kernel_size < self._filter_size[
...@@ -584,6 +594,10 @@ class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose): ...@@ -584,6 +594,10 @@ class SuperConv2DTranspose(fluid.dygraph.Conv2DTranspose):
ks = int(self._filter_size[0]) if kernel_size == None else int( ks = int(self._filter_size[0]) if kernel_size == None else int(
kernel_size) kernel_size)
if kernel_size is not None and self._filter_size[
0] != self._filter_size[1]:
_logger.error("Searching for asymmetric kernels is NOT supported")
groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc, groups, weight_in_nc, weight_out_nc = self.get_groups_in_out_nc(in_nc,
out_nc) out_nc)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册