提交 5b7ae268 编写于 作者: M Megvii Engine Team

feat(mge): do not export F.nn.roi_pooling

GitOrigin-RevId: 8a07dd1da4af3a1322455cf4fcb49a9a4bbe018e
上级 d22d1676
...@@ -45,7 +45,6 @@ __all__ = [ ...@@ -45,7 +45,6 @@ __all__ = [
"max_pool2d", "max_pool2d",
"one_hot", "one_hot",
"prelu", "prelu",
"roi_pooling",
"softmax", "softmax",
"softplus", "softplus",
"svd", "svd",
...@@ -1324,7 +1323,7 @@ def roi_pooling( ...@@ -1324,7 +1323,7 @@ def roi_pooling(
np.random.seed(42) np.random.seed(42)
inp = tensor(np.random.randn(1, 1, 128, 128)) inp = tensor(np.random.randn(1, 1, 128, 128))
rois = tensor(np.random.random((4, 5))) rois = tensor(np.random.random((4, 5)))
y = F.roi_pooling(inp, rois, (2, 2)) y = F.nn.roi_pooling(inp, rois, (2, 2))
print(y.numpy()[0]) print(y.numpy()[0])
Outputs: Outputs:
......
...@@ -193,7 +193,7 @@ def test_roi_pooling(): ...@@ -193,7 +193,7 @@ def test_roi_pooling():
inp_feat, rois = _gen_roi_inp() inp_feat, rois = _gen_roi_inp()
grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat)) grad = Grad().wrt(inp_feat, callback=_save_to(inp_feat))
output_shape = (7, 7) output_shape = (7, 7)
out_feat = F.roi_pooling( out_feat = F.nn.roi_pooling(
inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4, inp_feat, rois, output_shape=output_shape, mode="max", scale=1.0 / 4,
) )
assert make_shape_tuple(out_feat.shape) == ( assert make_shape_tuple(out_feat.shape) == (
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册