未验证 提交 80f23a85 编写于 作者: W WJJ1995 提交者: GitHub

fixed MaxUnpool bug (#662)

上级 90fdab16
...@@ -8,10 +8,9 @@ torch.nn.MaxUnpool1d(kernel_size, stride=None, padding=0) ...@@ -8,10 +8,9 @@ torch.nn.MaxUnpool1d(kernel_size, stride=None, padding=0)
```python ```python
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
TYPE_MAPPER = {"fp16": "float16", "fp32": "float32", "fp64": "float64"}
# 定义MaxUnpool1D # 定义MaxUnpool1D
class MaxUnpool1D(paddle.nn.Layer): class MaxUnpool1D(nn.Layer):
def __init__(self, kernel_size, stride=None, padding=0): def __init__(self, kernel_size, stride=None, padding=0):
super().__init__() super().__init__()
if isinstance(stride, int): if isinstance(stride, int):
...@@ -49,7 +48,7 @@ class MaxUnpool1D(paddle.nn.Layer): ...@@ -49,7 +48,7 @@ class MaxUnpool1D(paddle.nn.Layer):
flatten_indices = paddle.flatten(indices) flatten_indices = paddle.flatten(indices)
flatten_input = paddle.flatten(input) flatten_input = paddle.flatten(input)
for i in range(flatten_indices.shape[0]): for i in range(flatten_indices.shape[0]):
flatten_out[flatten_indices[i].tolist()] = flatten_input[i].tolist() flatten_out[int(flatten_indices[i])] = flatten_input[i]
out = paddle.reshape(flatten_out, out.shape) out = paddle.reshape(flatten_out, out.shape)
return out return out
......
...@@ -9,10 +9,9 @@ torch.nn.MaxUnpool2d(kernel_size, stride=None, padding=0) ...@@ -9,10 +9,9 @@ torch.nn.MaxUnpool2d(kernel_size, stride=None, padding=0)
```python ```python
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
TYPE_MAPPER = {"fp16": "float16", "fp32": "float32", "fp64": "float64"}
# 定义MaxUnpool2D # 定义MaxUnpool2D
class MaxUnpool2D(paddle.nn.Layer): class MaxUnpool2D(nn.Layer):
def __init__(self, kernel_size, stride=None, padding=0): def __init__(self, kernel_size, stride=None, padding=0):
super().__init__() super().__init__()
if isinstance(stride, int): if isinstance(stride, int):
...@@ -41,7 +40,6 @@ class MaxUnpool2D(paddle.nn.Layer): ...@@ -41,7 +40,6 @@ class MaxUnpool2D(paddle.nn.Layer):
if len(output_size) == len(self.kernel_size) + 2: if len(output_size) == len(self.kernel_size) + 2:
output_size = output_size[2:] output_size = output_size[2:]
t = str(input.dtype).lower().strip().split(".")[-1] t = str(input.dtype).lower().strip().split(".")[-1]
t = TYPE_MAPPER[t]
out = paddle.zeros(output_size, dtype=t) out = paddle.zeros(output_size, dtype=t)
flatten_out = paddle.flatten(out) flatten_out = paddle.flatten(out)
for i in range(indices.shape[0]): for i in range(indices.shape[0]):
...@@ -53,7 +51,7 @@ class MaxUnpool2D(paddle.nn.Layer): ...@@ -53,7 +51,7 @@ class MaxUnpool2D(paddle.nn.Layer):
flatten_indices = paddle.flatten(indices) flatten_indices = paddle.flatten(indices)
flatten_input = paddle.flatten(input) flatten_input = paddle.flatten(input)
for i in range(flatten_indices.shape[0]): for i in range(flatten_indices.shape[0]):
flatten_out[flatten_indices[i].tolist()] = flatten_input[i].tolist() flatten_out[int(flatten_indices[i])] = flatten_input[i]
out = paddle.reshape(flatten_out, out.shape) out = paddle.reshape(flatten_out, out.shape)
return out return out
......
...@@ -8,10 +8,9 @@ torch.nn.MaxUnpool3d(kernel_size, stride=None, padding=0) ...@@ -8,10 +8,9 @@ torch.nn.MaxUnpool3d(kernel_size, stride=None, padding=0)
```python ```python
import paddle import paddle
import paddle.nn as nn import paddle.nn as nn
TYPE_MAPPER = {"fp16": "float16", "fp32": "float32", "fp64": "float64"}
# 定义MaxUnpool3D # 定义MaxUnpool3D
class MaxUnpool3D(paddle.nn.Layer): class MaxUnpool3D(nn.Layer):
def __init__(self, kernel_size, stride=None, padding=0): def __init__(self, kernel_size, stride=None, padding=0):
super().__init__() super().__init__()
if isinstance(stride, int): if isinstance(stride, int):
...@@ -55,7 +54,7 @@ class MaxUnpool3D(paddle.nn.Layer): ...@@ -55,7 +54,7 @@ class MaxUnpool3D(paddle.nn.Layer):
flatten_indices = paddle.flatten(indices) flatten_indices = paddle.flatten(indices)
flatten_input = paddle.flatten(input) flatten_input = paddle.flatten(input)
for i in range(flatten_indices.shape[0]): for i in range(flatten_indices.shape[0]):
flatten_out[flatten_indices[i].tolist()] = flatten_input[i].tolist() flatten_out[int(flatten_indices[i])] = flatten_input[i]
out = paddle.reshape(flatten_out, out.shape) out = paddle.reshape(flatten_out, out.shape)
return out return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册