提交 bebb2cf4 编写于 作者: M Megvii Engine Team

Merge pull request #428 from P2Oileen:fix-pad

GitOrigin-RevId: f33ea46ad62025f6531d98311df3d39462e7d862
...@@ -1708,7 +1708,7 @@ def sliding_window_transpose( ...@@ -1708,7 +1708,7 @@ def sliding_window_transpose(
def pad( def pad(
src: Tensor, src: Tensor,
pad_witdth: Tuple[Tuple[int, int], ...], pad_width: Tuple[Tuple[int, int], ...],
mode: str = "constant", mode: str = "constant",
constant_value: float = 0.0, constant_value: float = 0.0,
) -> Tensor: ) -> Tensor:
...@@ -1723,9 +1723,9 @@ def pad( ...@@ -1723,9 +1723,9 @@ def pad(
if mode.lower() == "edge": if mode.lower() == "edge":
mode = "replicate" mode = "replicate"
for i in range(0, len(pad_witdth)): for i in range(0, len(pad_width)):
p_offsets[i * 2] = pad_witdth[i][0] p_offsets[i * 2] = pad_width[i][0]
p_offsets[i * 2 + 1] = pad_witdth[i][1] p_offsets[i * 2 + 1] = pad_width[i][1]
op = builtin.Padding( op = builtin.Padding(
front_offset_dim0=p_offsets[0], front_offset_dim0=p_offsets[0],
......
...@@ -12,16 +12,16 @@ class Pad(Module): ...@@ -12,16 +12,16 @@ class Pad(Module):
def __init__( def __init__(
self, self,
pad_witdth: Tuple[Tuple[int, int], ...], pad_width: Tuple[Tuple[int, int], ...],
mode: str = "constant", mode: str = "constant",
constant_val: float = 0.0, constant_val: float = 0.0,
): ):
super().__init__() super().__init__()
self.pad_width = pad_witdth self.pad_width = pad_width
self.mode = mode self.mode = mode
self.pad_val = constant_val self.pad_val = constant_val
def forward(self, src): def forward(self, src):
return nn.pad( return nn.pad(
src, pad_witdth=self.pad_width, mode=self.mode, constant_value=self.pad_val src, pad_width=self.pad_width, mode=self.mode, constant_value=self.pad_val
) )
...@@ -162,3 +162,11 @@ def tensor_gen_func_loader(expr): ...@@ -162,3 +162,11 @@ def tensor_gen_func_loader(expr):
else: else:
device = None device = None
expr.set_args_kwargs(shape, dtype=dtype, device=device) expr.set_args_kwargs(shape, dtype=dtype, device=device)
@register_functional_loader(("megengine.functional.nn", "pad"))
def pad_func_loader(expr):
if "pad_witdth" in expr.kwargs:
kwargs = expr.kwargs
kwargs["pad_width"] = kwargs.pop("pad_witdth")
expr.set_args_kwargs(*expr.args, **kwargs)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册