Commit 43cb305e authored by Megvii Engine Team

feat(mge): add imperative pad

GitOrigin-RevId: de79de536fe1376f4a52dc7b15fae4b8f493d61f
Parent 567586a0
@@ -1592,6 +1592,57 @@ def sliding_window_transpose(
    return output

def pad(
    src: Tensor,
    pad_width: Tuple[Tuple[int, int], ...],
    mode: str = "CONSTANT",
    constant_value: float = 0.0,
) -> Tensor:
    """
    Pads the input tensor.

    :param src: input tensor.
    :param pad_width: one ``(front, back)`` pair of padding sizes per
        dimension, starting from the first dimension; at most 7 pairs.
    :param mode: padding mode, one of ``"CONSTANT"``, ``"EDGE"``,
        ``"REPLICATE"`` or ``"REFLECT"``; ``"EDGE"`` is an alias of
        ``"REPLICATE"``.
    :param constant_value: fill value used when ``mode`` is ``"CONSTANT"``.
    """
    # Flattened (front, back) offsets for up to 7 dimensions.
    p_offsets = [0] * 14
    assert mode in [
        "constant",
        "CONSTANT",
        "edge",
        "EDGE",
        "replicate",
        "REPLICATE",
        "reflect",
        "REFLECT",
    ]
    # "edge" is an alias of "replicate" in the underlying op.
    if mode.lower() == "edge":
        mode = "replicate"
    assert len(pad_width) <= 7, "pad supports at most 7 dimensions"
    for i in range(0, len(pad_width)):
        p_offsets[i * 2] = pad_width[i][0]
        p_offsets[i * 2 + 1] = pad_width[i][1]
    op = builtin.Padding(
        front_offset_dim0=p_offsets[0],
        front_offset_dim1=p_offsets[2],
        front_offset_dim2=p_offsets[4],
        front_offset_dim3=p_offsets[6],
        front_offset_dim4=p_offsets[8],
        front_offset_dim5=p_offsets[10],
        front_offset_dim6=p_offsets[12],
        back_offset_dim0=p_offsets[1],
        back_offset_dim1=p_offsets[3],
        back_offset_dim2=p_offsets[5],
        back_offset_dim3=p_offsets[7],
        back_offset_dim4=p_offsets[9],
        back_offset_dim5=p_offsets[11],
        back_offset_dim6=p_offsets[13],
        padding_val=constant_value,
        padding_mode=mode.upper(),
    )
    (output,) = apply(op, src)
    return output
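
A minimal usage sketch of the new functional API, assuming a MegEngine build that includes this commit (it mirrors the test added further down):

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.arange(9, dtype=np.float32).reshape(3, 3))
# Pad one row on top/bottom and two columns on left/right with zeros.
y = F.nn.pad(x, ((1, 1), (2, 2)), mode="CONSTANT", constant_value=0.0)
print(y.shape)  # (5, 7)
```
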
interpolate = deprecated_func("1.3", "megengine.functional.vision", "interpolate", True)
roi_pooling = deprecated_func("1.3", "megengine.functional.vision", "roi_pooling", True)
roi_align = deprecated_func("1.3", "megengine.functional.vision", "roi_align", True)
@@ -31,6 +31,7 @@ from .identity import Identity
from .linear import Linear
from .module import Module
from .normalization import GroupNorm, InstanceNorm, LayerNorm
from .padding import Pad
from .pooling import AvgPool2d, MaxPool2d
from .quant_dequant import DequantStub, QuantStub
from .sequential import Sequential
from typing import Tuple
from ..functional import nn
from .module import Module
class Pad(Module):
    """
    Module wrapper of :func:`~.functional.nn.pad`.
    """

    def __init__(
        self,
        pad_width: Tuple[Tuple[int, int], ...],
        mode: str = "CONSTANT",
        constant_val: float = 0.0,
    ):
        super().__init__()
        self.pad_width = pad_width
        self.mode = mode
        self.pad_val = constant_val

    def forward(self, src):
        return nn.pad(
            src, pad_width=self.pad_width, mode=self.mode, constant_value=self.pad_val
        )
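
A minimal sketch of the module wrapper, assuming it is exported as `megengine.module.Pad` (consistent with the `from .padding import Pad` line added above):

```python
import numpy as np
import megengine.module as M
from megengine import tensor

# REFLECT padding of width 1 on both dimensions of a 4x4 input.
pad = M.Pad(((1, 1), (1, 1)), mode="REFLECT")
out = pad(tensor(np.ones((4, 4), dtype="float32")))
print(out.shape)  # (6, 6)
```
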
@@ -1062,3 +1062,22 @@ def test_sliding_window_transpose():
        dilation=(dh, dw),
    )
    np.testing.assert_equal(gt_out, out.numpy())

def test_pad():
    src = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32)

    dst = np.pad(src, ((2, 2), (2, 2)), "constant")
    res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "CONSTANT")
    np.testing.assert_allclose(res.numpy(), dst, atol=1e-5)

    dst = np.pad(src, ((2, 2), (2, 2)), "constant", constant_values=3)
    res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "CONSTANT", constant_value=3)
    np.testing.assert_allclose(res.numpy(), dst, atol=1e-5)

    dst = np.pad(src, ((2, 2), (2, 2)), "edge")
    res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "EDGE")
    np.testing.assert_allclose(res.numpy(), dst, atol=1e-5)

    dst = np.pad(src, ((2, 2), (2, 2)), "reflect")
    res = F.nn.pad(tensor(src), ((2, 2), (2, 2)), "REFLECT")
    np.testing.assert_allclose(res.numpy(), dst, atol=1e-5)
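
One detail the tests exercise only indirectly: `pad` lowers `"EDGE"` to `REPLICATE` (the `mode.lower() == "edge"` branch above), so the two spellings are interchangeable. A quick check under the same assumptions as the test:

```python
import numpy as np
import megengine.functional as F
from megengine import tensor

x = tensor(np.arange(4, dtype=np.float32).reshape(2, 2))
np.testing.assert_allclose(
    F.nn.pad(x, ((1, 1), (1, 1)), "EDGE").numpy(),
    F.nn.pad(x, ((1, 1), (1, 1)), "REPLICATE").numpy(),
)
```
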
@@ -660,4 +660,12 @@ OP_TRAIT_REG(Cumsum, Cumsum).apply_on_var_node(apply_on_var_node).fallback();
} // namespace cumsum
} // namespace
namespace padding {
auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
    auto&& op = static_cast<const Padding&>(def);
    mgb_assert(inputs.size() == 1);
    return opr::Padding::make(inputs[0], op.param());
}
OP_TRAIT_REG(Padding, Padding).apply_on_var_node(apply_on_var_node).fallback();
} // namespace padding
} // namespace mgb::imperative
@@ -389,4 +389,6 @@ def Split: MgbHashableOp<"Split", [EmptyParam]> {
);
}
def Padding: MgbHashableOp<"Padding", [PaddingParam]>;
#endif // MGB_OPS